diff --git a/src/api/oauth/oauth.lua b/src/api/oauth/oauth.lua index ce7d76b..1892641 100644 --- a/src/api/oauth/oauth.lua +++ b/src/api/oauth/oauth.lua @@ -16,13 +16,13 @@ local routes = { --获取授权码 { paths = { "/yum/v1/oauth/v2/authorize" }, - methods = { "POST" }, + methods = { "GET", "POST" }, handler = oauthService.authorize, }, --根据授权码获取Access-Token { paths = { "/yum/v1/oauth/v2/token" }, - methods = { "POST" }, + methods = { "GET", "POST" }, handler = oauthService.token, }, --通过用户名和密码进行验证 @@ -46,7 +46,7 @@ local routes = { --根据Refresh-Token刷新Access-Token { paths = { "/yum/v1/oauth/v2/refresh" }, - methods = { "POST" }, + methods = { "GET", "POST" }, handler = oauthService.refresh, }, --验证token是否有效 diff --git a/src/dao/oauth/oauth.lua b/src/dao/oauth/oauth.lua index f69ff0c..da120d8 100644 --- a/src/dao/oauth/oauth.lua +++ b/src/dao/oauth/oauth.lua @@ -61,6 +61,7 @@ function _M.getUser(userid) end function _M.getApplicationBy(client_id, redirect_uri) + --print("getApplicationBy client_id:", client_id, " redirect_uri:", redirect_uri) return applicationDao.getApplicationByClientId(client_id, redirect_uri) end diff --git a/src/dao/system/application.lua b/src/dao/system/application.lua index 994d20e..73a6f48 100644 --- a/src/dao/system/application.lua +++ b/src/dao/system/application.lua @@ -106,6 +106,7 @@ end --根据客户端id和重定向地址获取应用程序 function _M.getApplicationByClientId(client_id, redirect_uri) + --print("getApplicationByClientId client_id:", client_id, " redirect_uri:", redirect_uri) return applicationModel:where('app_id', '=', client_id):where('redirect_uris', '=', redirect_uri):get() end diff --git a/src/service/oauth/oauth.lua b/src/service/oauth/oauth.lua index 068bdeb..82088ec 100644 --- a/src/service/oauth/oauth.lua +++ b/src/service/oauth/oauth.lua @@ -4,42 +4,45 @@ --- DateTime: 2025/10/28 11:09 --- 用于 local resp = require("util.response") -local authDao = require("dao.oauth.oauth") +local oauthDao = require("dao.oauth.oauth") local validator = require("validator.oauth.oauth") local cjson = require("cjson.safe") local token = require("util.uuid") local jwt = require "resty.jwt" +local rsa = require("util.rsa") local _M = {} --获取授权码 function _M:authorize() - --读取请求体的数据 - ngx.req.read_body() local args = ngx.req.get_uri_args() - --获取请求数据 - local body_data = ngx.req.get_body_data() - -- 验证json数据是否正确 - local ok, data = pcall(cjson.decode, body_data) - if not ok then - return ngx.exit(ngx.HTTP_BAD_REQUEST) - end - -- 校验客户端请求参数 - ok = validator.validatorAuthorize(data) - --验证失败则返回 - if not ok then - return ngx.exit(ngx.HTTP_BAD_REQUEST) + if ngx.req.get_method() == "POST" then + --读取请求体的数据 + ngx.req.read_body() + --获取请求数据 + local body_data = ngx.req.get_body_data() + -- 验证json数据是否正确 + local ok, data = pcall(cjson.decode, body_data) + if not ok then + return ngx.exit(ngx.HTTP_BAD_REQUEST) + end + -- 校验客户端请求参数 + ok = validator.validatorAuthorize(data) + --验证失败则返回 + if not ok then + return ngx.exit(ngx.HTTP_BAD_REQUEST) + end end -- 校验 response_type 必须为 "code"(授权码模式) if args.response_type ~= "code" then return ngx.exit(ngx.HTTP_BAD_REQUEST) end - -- 校验客户端id和redirect_uri是否存在数据库 - local code, res = authDao.getApplicationBy(args.client_id, args.redirect_uri) + -- 1、校验客户端id和redirect_uri是否存在数据库 + local code, res = oauthDao.getApplicationBy(args.client_id, args.redirect_uri) if code ~= 0 or not res then return ngx.exit(ngx.HTTP_UNAUTHORIZED) end - -- 验证范围 + -- 2、验证范围 if args.scope then local requested_scopes = {} for scope in string.gmatch(args.scope, "%S+") do @@ -47,67 +50,47 @@ function _M:authorize() end -- 验证范围是否允许 todo end - -- 判断用户登录检查 用户已登录,直接展示授权确认页;未登录则重定向到登录页 - local user_logged_in = false + -- 3、判断用户登录检查 用户已登录,直接展示授权确认页;未登录则重定向到登录页 + local user_logged_in = true if not user_logged_in then -- 重定向到登录页,携带当前授权请求参数(登录后跳转回来) - local login_url = "/login?redirect=" .. ngx.escape_uri(ngx_var.request_uri) - ngx.redirect(login_url) + local login_url = "/login?redirect=" .. ngx.escape_uri(ngx.var.request_uri) + --print("authorize login_url:", login_url) + --ngx.redirect(login_url) + local result = resp:json(ngx.HTTP_MOVED_TEMPORARILY, login_url) + resp:send(result) return end - - -- 4. 处理用户授权确认(用户点击"同意"后提交 POST 请求) - if ngx_req.get_method() == "POST" then - local post_args = ngx_req.get_post_args() - if post_args.action == "allow" then - -- 5. 生成授权码(随机字符串,确保唯一性) - local function generate_code() - local str = require "resty.string" - local random = require "resty.random" - local bytes = random.bytes(16) - return str.to_hex(bytes) - end - local code = generate_code() - -- 存储授权码信息(用户ID、客户端ID、scope、生成时间) - local code_key = "auth_code:" .. code - local code_data = cjson.encode({ - user_id = "123456", - client_id = args.client_id, - scope = args.scope or "", - created_at = ngx.time() - }) - local shared_dict = ngx.shared.codeDict - shared_dict:set(code_key, code_data, 5 * 60) - - -- 7. 重定向到客户端回调地址,携带授权码和原始 state(防 CSRF) - local redirect_url = args.redirect_uri .. "?code=" .. code .. "&state=" .. args.state - ngx.redirect(redirect_url) - return - - elseif post_args.action == "deny" then - -- 用户拒绝授权,重定向到客户端并携带错误信息 - local redirect_url = args.redirect_uri .. "?error=access_denied&state=" .. args.state - ngx.redirect(redirect_url) - return - end + -- 4. 生成授权码(随机字符串,确保唯一性) + local function generate_code() + local str = require "resty.string" + local random = require "resty.random" + local bytes = random.bytes(16) + return str.to_hex(bytes) end + local code = generate_code() + print("authorize generate_code:", code) + -- 5、存储授权码信息(用户ID、客户端ID、scope、生成时间) + local code_key = "auth_code:" .. code + local code_data = cjson.encode({ + user_id = "123456", + client_id = args.client_id, + request_uri = ngx.var.request_uri, + scope = args.scope or "", + created_at = ngx.time() + }) + local shared_dict = ngx.shared.codeDict + shared_dict:set(code_key, code_data, 5 * 60) + + -- 6. 重定向到客户端回调地址,携带授权码和原始 state(防 CSRF) + local redirect_url = args.redirect_uri .. "?code=" .. code .. "&state=" .. args.state + local result = resp:json(ngx.HTTP_OK, redirect_url) + resp:send(result) + return end --根据授权码获取Access-Token function _M:token() - --读取请求体的数据 - ngx.req.read_body() - ----获取请求数据 - --local body_data = ngx.req.get_body_data() - ---- 验证数据是否符合json - --local ok = validator.validatorToken(body_data) - ----验证失败则返回 - --if not ok then - -- local result = resp:json(0x000001) - -- resp:send(result) - -- return - --end - -- 1. 解析请求参数(支持 form-data 和 json) local content_type = ngx.req.get_headers()["Content-Type"] or "" local args = {} @@ -121,118 +104,113 @@ function _M:token() -- 默认解析 form-urlencoded args = ngx.req.get_post_args() end - - -- 2. 校验必填参数 - local required = { - grant_type = "refresh_token", - refresh_token = true, - client_id = true - } - if args.grant_type ~= required.grant_type then - return ngx.say(cjson.encode({ - error = "unsupported_grant_type", - error_description = "grant_type must be 'refresh_token'" - })) - end - if not args.refresh_token or not args.client_id then - return ngx.say(cjson.encode({ - error = "invalid_request", - error_description = "missing required parameters" - })) + print("args:", args) + -- 2. 校验必填参数验证数据是否符合json + local ok = validator.validatorToken(args) + --验证失败则返回 + if not ok then + print("validatorToken failed") + local result = resp:json(0x000001) + resp:send(result) + return end - -- 4. 校验 Refresh Token 有效性 - local refresh_token = args.refresh_token - local client_id = args.client_id - - -- 4.1 检查是否在黑名单(已吊销) - local is_revoked, err = red:get("refresh_blacklist:" .. refresh_token) - if is_revoked == "1" then - return ngx.say(cjson.encode({ - error = "token_revoked", - error_description = "refresh_token has been revoked" - })) + -- 4. 校验 code 有效性 + local code = args.code + local code_key = "auth_code:" .. code + local shared_dict = ngx.shared.codeDict + local code_data = shared_dict:get(code_key) + if code_data ~= nil then + -- code 超出时效,需要重新获取code + --local result = resp:json(0x000001) + --resp:send(result) + local login_url = "/login?redirect=" .. ngx.escape_uri(ngx.var.request_uri) + local result = resp:json(ngx.HTTP_MOVED_TEMPORARILY, login_url) + resp:send(result) + return end - - -- 4.2 校验客户端合法性(client_id 与 client_secret 匹配,仅后端客户端需要 secret) - local client_secret = args.client_secret or "" - local stored_secret, err = red:get("client:" .. client_id) - if not stored_secret or stored_secret == ngx.null then - return ngx.say(cjson.encode({ - error = "invalid_client", - error_description = "client_id not found" - })) - end - -- 机密客户端(如后端服务)必须验证 secret - if stored_secret ~= "public" and stored_secret ~= client_secret then - return ngx.say(cjson.encode({ - error = "invalid_client", - error_description = "client_secret invalid" - })) + -- 5、验证redirect_url地址的正确性 + local request_uri = code_data["request_uri"] + if request_uri ~= args.redirect_url then + local login_url = "/login?redirect=" .. ngx.escape_uri(request_uri) + local result = resp:json(ngx.HTTP_MOVED_TEMPORARILY, login_url) + resp:send(result) end + -- 验证成功删除 + shared_dict:delete(code_key) - -- 4.3 验证 Refresh Token 签名(假设 Refresh Token 是 JWT 格式) - local refresh_jwt = jwt:verify(jwt_secret, refresh_token) - if not refresh_jwt.valid then - return ngx.say(cjson.encode({ - error = "invalid_grant", - error_description = "refresh_token invalid: " .. (refresh_jwt.reason or "unknown") - })) + -- 6. 生成密钥对 + local pub_key, priv_key, err = rsa.generate_rsa_keys(2048) + if err then + print("密钥生成失败: ", err) + local result = resp:json(0x00001) + resp:send(result) + return end - - -- 4.4 校验 Token 绑定关系(client_id 必须与 JWT 中一致) - if refresh_jwt.payload.client_id ~= client_id then - return ngx.say(cjson.encode({ - error = "invalid_grant", - error_description = "refresh_token not bound to client_id" - })) - end - - -- 5.1 吊销旧 Refresh Token(加入黑名单,设置与原有效期一致的过期时间) - local ttl = refresh_jwt.payload.exp - ngx.time() - if ttl > 0 then - red:setex("refresh_blacklist:" .. refresh_token, ttl, "1") - end - + local user_id = code_data["user_id"] + local client_id = code_data["client_id"] + local scope = code_data["scope"] local access_token_ttl = 10 * 60 --十分钟 local refresh_token_ttl = 7 * 24 * 3600 --7天 - -- 5.2 生成新 Access Token + -- 7 生成新 Access Token local access_payload = { - sub = refresh_jwt.payload.sub, -- 用户ID + sub = user_id, -- 用户ID client_id = client_id, - scope = refresh_jwt.payload.scope or "", + scope = scope or "", exp = ngx.time() + access_token_ttl, jti = ngx.md5(ngx.time() .. math.random() .. client_id) -- 唯一标识 } - local new_access_token = jwt:sign(jwt_secret, { + local new_access_token = jwt:sign(priv_key, { header = { typ = "JWT", alg = "HS256" }, payload = access_payload }) - -- 5.3 生成新 Refresh Token(滚动刷新) + -- 8 生成新 Refresh Token(滚动刷新) local refresh_payload = { - sub = refresh_jwt.payload.sub, + sub = user_id, client_id = client_id, - scope = refresh_jwt.payload.scope or "", + scope = scope or "", exp = ngx.time() + refresh_token_ttl, jti = ngx.md5(ngx.time() .. math.random() * 1000 .. client_id) } - local new_refresh_token = jwt:sign(jwt_secret, { + local new_refresh_token = jwt:sign(priv_key, { header = { typ = "JWT", alg = "HS256" }, payload = refresh_payload }) - -- 6. 返回结果 - ngx.say(cjson.encode({ - access_token = new_access_token, - token_type = "Bearer", - expires_in = access_token_ttl, - refresh_token = new_refresh_token, - refresh_expires_in = refresh_token_ttl, - scope = access_payload.scope, - issued_at = ngx.time(), - jti = access_payload.jti - })) + -- 9、生存id_token + -- 创建JWT的payload + local payload = { + iss = request_uri, + sub = user_id, + name = user_id, + iat = os.time(), + exp = os.time() + 3600 + } + + -- 使用私钥生成JWT + local jwt_obj = jwt:sign(priv_key, { + header = { + type = "JWT", + alg = "RS256" + }, + payload = payload + }) + if not jwt_obj then + local result = resp:json(0x00001) + resp:send(result) + return + end + --ngx.say("Generated JWT: ", jwt_obj) + -- 10. 返回结果 + local ret = {} + ret.access_token = new_access_token + ret.token_type = "Bearer" + ret.expires_in = access_token_ttl + ret.refresh_token = new_refresh_token + ret.id_token = jwt_obj + local result = resp:json(ngx.HTTP_OK, ret) + resp:send(result) end --用户进行登陆然后验证返回code @@ -260,7 +238,7 @@ function _M:login() resp:send(result) return end - local code, ret = loginDao.login(data) + local code, ret = oauthDao.login(data) --读取数据错误 if code ~= 0 or table.getn(ret) < 0 then local result = resp:json(0x000001) diff --git a/src/test/test.lua b/src/test/test.lua index bc4f2e1..d22ea86 100644 --- a/src/test/test.lua +++ b/src/test/test.lua @@ -92,6 +92,7 @@ local schema = { --} --]] +--[[ ngx.req.read_body() --获取请求数据 local jsonStr = ngx.req.get_body_data() @@ -127,6 +128,7 @@ else end do return end +--]] -- 生成RSA密钥对 local function generate_rsa_keys(length)