233 lines
5.9 KiB
Lua
Executable File
233 lines
5.9 KiB
Lua
Executable File
-- Copyright (C) Yichun Zhang (agentzh)
|
|
|
|
|
|
local base = require "resty.core.base"
|
|
base.allows_subsystem('http', 'stream')
|
|
|
|
|
|
local ffi = require "ffi"
|
|
local bit = require "bit"
|
|
local bor = bit.bor
|
|
local C = ffi.C
|
|
local ffi_str = ffi.string
|
|
local get_request = base.get_request
|
|
local error = error
|
|
local errmsg = base.get_errmsg_ptr()
|
|
local get_size_ptr = base.get_size_ptr
|
|
local FFI_OK = base.FFI_OK
|
|
local subsystem = ngx.config.subsystem
|
|
local ngx_phase = ngx.get_phase
|
|
local byte = string.byte
|
|
local lshift = bit.lshift
|
|
local table_insert = table.insert
|
|
|
|
|
|
-- C entry points used by this module. They are resolved once at load time,
-- picking the http or stream flavour according to the current subsystem.
local ngx_lua_ffi_ssl_get_client_hello_server_name,
      ngx_lua_ffi_ssl_get_client_hello_ext,
      ngx_lua_ffi_ssl_set_protocols

if subsystem == 'http' then
    ffi.cdef[[
    int ngx_http_lua_ffi_ssl_get_client_hello_server_name(ngx_http_request_t *r,
        const char **name, size_t *namelen, char **err);

    int ngx_http_lua_ffi_ssl_get_client_hello_ext(ngx_http_request_t *r,
        unsigned int type, const unsigned char **out, size_t *outlen,
        char **err);

    int ngx_http_lua_ffi_ssl_set_protocols(ngx_http_request_t *r,
        int protocols, char **err);
    ]]

    ngx_lua_ffi_ssl_get_client_hello_server_name,
    ngx_lua_ffi_ssl_get_client_hello_ext,
    ngx_lua_ffi_ssl_set_protocols =
        C.ngx_http_lua_ffi_ssl_get_client_hello_server_name,
        C.ngx_http_lua_ffi_ssl_get_client_hello_ext,
        C.ngx_http_lua_ffi_ssl_set_protocols

elseif subsystem == 'stream' then
    ffi.cdef[[
    int ngx_stream_lua_ffi_ssl_get_client_hello_server_name(
        ngx_stream_lua_request_t *r, const char **name, size_t *namelen,
        char **err);

    int ngx_stream_lua_ffi_ssl_get_client_hello_ext(
        ngx_stream_lua_request_t *r, unsigned int type,
        const unsigned char **out, size_t *outlen, char **err);

    int ngx_stream_lua_ffi_ssl_set_protocols(ngx_stream_lua_request_t *r,
        int protocols, char **err);
    ]]

    ngx_lua_ffi_ssl_get_client_hello_server_name,
    ngx_lua_ffi_ssl_get_client_hello_ext,
    ngx_lua_ffi_ssl_set_protocols =
        C.ngx_stream_lua_ffi_ssl_get_client_hello_server_name,
        C.ngx_stream_lua_ffi_ssl_get_client_hello_ext,
        C.ngx_stream_lua_ffi_ssl_set_protocols
end
|
|
|
|
|
|
-- Output slots handed to the C functions below: the C side writes a data
-- pointer into slot [0], and every caller in this module copies the bytes out
-- with ffi.string right away, so one shared slot of each type is enough.
local ccharpp = ffi.new("const char*[1]")
local cucharpp = ffi.new("const unsigned char*[1]")

-- Public module table; the version tracks resty.core.base.
local _M = { version = base.version }
|
|
|
|
|
|
-- Fetches the SNI host name carried by the current ClientHello.
--
-- Only usable from the ssl_client_hello phase. Raises a Lua error when there
-- is no current request or when called from any other phase.
--
-- return server_name, err
--   * server_name (string) on success;
--   * nil (no err) when the ClientHello has no SNI extension;
--   * nil, err when the underlying C call reports a failure.
function _M.get_client_hello_server_name()
    local req = get_request()
    if not req then
        error("no request found")
    end

    if ngx_phase() ~= "ssl_client_hello" then
        error("API disabled in the current context")
    end

    local len_ptr = get_size_ptr()
    local rc = ngx_lua_ffi_ssl_get_client_hello_server_name(req, ccharpp,
                                                            len_ptr, errmsg)

    if rc ~= FFI_OK then
        if rc == -5 then
            -- NGX_DECLINED: the ClientHello has no SNI extension
            return nil
        end

        return nil, ffi_str(errmsg[0])
    end

    return ffi_str(ccharpp[0], len_ptr[0])
end
|
|
|
|
|
|
-- Fetches the data of one ClientHello extension, selected by its numeric
-- extension type (passed to C as an unsigned int).
--
-- Only usable from the ssl_client_hello phase. Raises a Lua error when there
-- is no current request or when called from any other phase.
--
-- return ext, err
--   * ext (string of raw bytes) when the extension is present;
--   * nil (no err) when the ClientHello lacks that extension;
--   * nil, err when the underlying C call reports a failure.
function _M.get_client_hello_ext(ext_type)
    local req = get_request()
    if not req then
        error("no request found")
    end

    if ngx_phase() ~= "ssl_client_hello" then
        error("API disabled in the current context")
    end

    local out_len = get_size_ptr()
    local rc = ngx_lua_ffi_ssl_get_client_hello_ext(req, ext_type, cucharpp,
                                                    out_len, errmsg)

    if rc == FFI_OK then
        return ffi_str(cucharpp[0], out_len[0])
    elseif rc == -5 then
        -- NGX_DECLINED: the requested extension is absent
        return nil
    end

    return nil, ffi_str(errmsg[0])
end
|
|
|
|
-- tls.handshake.extension.type supported_versions (RFC 8446, section 4.2.1)
local supported_versions_type = 43

-- 16-bit protocol version codes -> protocol names returned to callers.
local versions_map = {
    [0x002] = "SSLv2",
    [0x300] = "SSLv3",
    [0x301] = "TLSv1",
    [0x302] = "TLSv1.1",
    [0x303] = "TLSv1.2",
    [0x304] = "TLSv1.3",
}


-- Decodes the raw data of a ClientHello supported_versions extension.
--
-- The data is a one-byte list length L followed by L bytes forming a list of
-- big-endian 16-bit version codes. Codes missing from versions_map (e.g.
-- GREASE values) are skipped.
--
-- Returns a table of version names (possibly empty), or nil when the data is
-- empty or malformed.
local function parse_supported_versions(data)
    local total_len = #data
    if total_len == 0 then
        return nil
    end

    local list_len = byte(data, 1)

    -- The declared length must cover exactly the rest of the data and must be
    -- a whole number of 2-byte entries. With an odd length, the pairs read
    -- below would straddle entry boundaries and decode garbage codes.
    if list_len ~= total_len - 1 or list_len % 2 ~= 0 then
        return nil
    end

    local types = {}

    -- Entries occupy bytes 2 .. list_len + 1. Walk them from the last entry
    -- to the first, so the result lists versions in reverse wire order. This
    -- is the order the function has always produced.
    local pos = list_len
    while pos >= 2 do
        local type_hi = byte(data, pos)
        local type_lo = byte(data, pos + 1)
        local name = versions_map[lshift(type_hi, 8) + type_lo]
        if name ~= nil then
            table_insert(types, name)
        end
        pos = pos - 2
    end

    return types
end


-- Lists the protocol versions advertised in the ClientHello's
-- supported_versions extension.
--
-- Only usable from the ssl_client_hello phase. Raises a Lua error when there
-- is no current request or when called from any other phase.
--
-- return types, err
--   * types: table of version names such as "TLSv1.3", in reverse wire order;
--   * nil (no err) when the extension is absent, empty or malformed;
--   * nil, err when the underlying C call reports a failure.
function _M.get_supported_versions()
    local r = get_request()
    if not r then
        error("no request found")
    end

    if ngx_phase() ~= "ssl_client_hello" then
        error("API disabled in the current context")
    end

    local sizep = get_size_ptr()

    local rc = ngx_lua_ffi_ssl_get_client_hello_ext(r, supported_versions_type,
                                                    cucharpp, sizep, errmsg)
    if rc ~= FFI_OK then
        -- NGX_DECLINED: no extension
        if rc == -5 then
            return nil
        end

        return nil, ffi_str(errmsg[0])
    end

    return parse_supported_versions(ffi_str(cucharpp[0], sizep[0]))
end
|
|
|
|
|
|
-- Protocol name -> bit flag understood by the C set_protocols function.
local prot_map = {
    ["SSLv2"]   = 0x0002,
    ["SSLv3"]   = 0x0004,
    ["TLSv1"]   = 0x0008,
    ["TLSv1.1"] = 0x0010,
    ["TLSv1.2"] = 0x0020,
    ["TLSv1.3"] = 0x0040,
}


-- Restricts the SSL/TLS protocol versions for the current handshake.
--
-- `protocols` is an array of protocol names (keys of prot_map, e.g.
-- "TLSv1.2"). Their flags are OR-ed into a single mask and passed to C.
-- Only usable from the ssl_client_hello phase. Raises a Lua error when there
-- is no current request or when called from any other phase.
--
-- return ok, err
--   * true on success;
--   * nil, err when a name is not recognized or the C call fails.
function _M.set_protocols(protocols)
    local req = get_request()
    if not req then
        error("no request found")
    end

    if ngx_phase() ~= "ssl_client_hello" then
        error("API disabled in the current context")
    end

    local mask = 0
    for _, name in ipairs(protocols) do
        local flag = prot_map[name]
        if not flag then
            return nil, "invalid protocols failed"
        end
        mask = bor(mask, flag)
    end

    if ngx_lua_ffi_ssl_set_protocols(req, mask, errmsg) ~= FFI_OK then
        return nil, ffi_str(errmsg[0])
    end

    return true
end
|
|
|
|
-- export the module table (what require returns for this module)
return _M
|