diff --git a/lib/resty/libjwt/init.lua b/lib/resty/libjwt/init.lua index a86ef72..6a7f172 100644 --- a/lib/resty/libjwt/init.lua +++ b/lib/resty/libjwt/init.lua @@ -19,7 +19,7 @@ end local TOKEN_VALID = 0 local JWKS_CACHE_TTL = 300 -local function _validate(params) +local function _authenticate(params) local headers = ngx.req.get_headers() local token, err @@ -75,13 +75,21 @@ local function _validate(params) return nil, "invalid token" end -local function _response_error(error_message, return_unauthorized_default) +local function _extract_claims(token, params) + for _, claim in ipairs(params.extract_claims) do + if token.claim[claim] ~= nil then + ngx.var["jwt_"..claim] = token.claim[claim] + end + end +end + +local function _response_error(error_message, return_unauthorized_default, status) if return_unauthorized_default == true then ngx.header.content_type = "application/json; charset=utf-8" local response = { message = error_message } - ngx.status = ngx.HTTP_UNAUTHORIZED + ngx.status = status ngx.say(cjson.encode(response)) ngx.exit(ngx.status) end @@ -92,14 +100,21 @@ end function _M.validate(user_params) local params, err = utils.get_params(user_params) if params == nil then - return nil, _response_error(err, true) + return nil, _response_error(err, true, ngx.HTTP_UNAUTHORIZED) end local parsed_token - parsed_token, err = _validate(params) + parsed_token, err = _authenticate(params) if err ~= "" then - return nil, _response_error(err, params.return_unauthorized_default) + return nil, _response_error(err, params.return_unauthorized_default, ngx.HTTP_UNAUTHORIZED) end + + local claims_extracted; + claims_extracted, err = pcall(_extract_claims, parsed_token, params) + if not claims_extracted then + return nil, _response_error(err, params.return_unauthorized_default, ngx.HTTP_INTERNAL_SERVER_ERROR) + end + return parsed_token, "" end diff --git a/lib/resty/libjwt/utils.lua b/lib/resty/libjwt/utils.lua index ac267a7..5f34ecb 100644 --- a/lib/resty/libjwt/utils.lua +++ b/lib/resty/libjwt/utils.lua @@ -4,7 +4,8 @@ function _M.get_params(params) local result = { header_token = "Authorization", jwks_files = {}, - return_unauthorized_default = true + return_unauthorized_default = true, + extract_claims = {}, } if params == nil then return nil, "params is required" @@ -19,6 +20,14 @@ function _M.get_params(params) if params["return_unauthorized_default"] ~= nil then result.return_unauthorized_default = params["return_unauthorized_default"] end + + if params["extract_claims"] ~= nil then + if type(params["extract_claims"]) ~= "table" then + return nil, "extract_claims is not an array" + end + result.extract_claims = params["extract_claims"] + end + if type(params["jwks_files"]) ~= "table" then return nil, "jwks_files is not an array" end diff --git a/nginx.conf b/nginx.conf index 00dda4a..1b603dc 100644 --- a/nginx.conf +++ b/nginx.conf @@ -5,10 +5,15 @@ events { } http { + log_format mylog '$remote_addr - "$request"\tStatus: $status JWT-Subject: $jwt_sub JWT-Email: $jwt_email'; + access_log /dev/stdout mylog; server { listen 8888; server_name localhost; + set $jwt_sub ""; + set $jwt_email ""; + location /public { default_type application/json; return 200 '{"message": "Hello, World!"}'; @@ -18,11 +23,12 @@ http { access_by_lua_block { local libjwt = require("resty.libjwt") local cjson = require("cjson.safe") - local token, err = libjwt.validate({ - ["jwks_files"] = {"/usr/share/tokens/jwks.json"}, + local token = libjwt.validate({ + jwks_files = {"/usr/share/tokens/jwks.json"}, + extract_claims = {"sub", "email"}, }) if token then - local claim_str = cjson.encode(claim) or "Invalid Token" + local claim_str = cjson.encode(token) or "Invalid Token" ngx.status = ngx.HTTP_OK return ngx.say(claim_str) end diff --git a/test/params_test.lua b/test/params_test.lua index 86a7a69..e39d60c 100644 --- a/test/params_test.lua +++ b/test/params_test.lua @@ -40,7 +40,8 @@ function TestShouldReturnValidatedParams() lu.assertEquals(result, { header_token = "token", jwks_files = { "files" }, - return_unauthorized_default = true + return_unauthorized_default = true, + extract_claims = {}, }) end