From b5a6de0b6282faf6fed8539a04810933d6bb2b05 Mon Sep 17 00:00:00 2001 From: Alexey Date: Wed, 6 Aug 2025 11:31:14 +0300 Subject: [PATCH] add jwt support --- internal/server/sv1/jwt.go | 85 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 internal/server/sv1/jwt.go diff --git a/internal/server/sv1/jwt.go b/internal/server/sv1/jwt.go new file mode 100644 index 0000000..56dde49 --- /dev/null +++ b/internal/server/sv1/jwt.go @@ -0,0 +1,85 @@ +package sv1 + +import ( + "log/slog" + "time" + + "github.com/golang-jwt/jwt" + lua "github.com/yuin/gopher-lua" +) + +func loadJWTMod(llog *slog.Logger, sid string) func(*lua.LState) int { + return func(L *lua.LState) int { + llog.Debug("import module jwt") + jwtMod := L.NewTable() + + L.SetField(jwtMod, "encode", L.NewFunction(jwtEncode)) + L.SetField(jwtMod, "decode", L.NewFunction(jwtDecode)) + + L.SetField(jwtMod, "__gosally_internal", lua.LString(sid)) + L.Push(jwtMod) + return 1 + } +} + +func jwtEncode(L *lua.LState) int { + payloadTbl := L.CheckTable(1) + secret := L.GetField(payloadTbl, "secret").String() + payload := L.GetField(payloadTbl, "payload").(*lua.LTable) + expiresIn := L.GetField(payloadTbl, "expires_in") + expDuration := time.Hour + + if expiresIn.Type() == lua.LTNumber { + floatVal := ConvertLuaTypesToGolang(expiresIn).(float64) + expDuration = time.Duration(floatVal) * time.Second + } + + claims := jwt.MapClaims{} + payload.ForEach(func(key, value lua.LValue) { + claims[key.String()] = ConvertLuaTypesToGolang(value) + }) + claims["exp"] = time.Now().Add(expDuration).Unix() + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signedToken, err := token.SignedString([]byte(secret)) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + L.Push(lua.LString(signedToken)) + return 1 +} + +func jwtDecode(L *lua.LState) int { + tokenString := L.CheckString(1) + optsTbl := L.OptTable(2, L.NewTable()) + secret := L.GetField(optsTbl, "secret").String() + + token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) { + return []byte(secret), nil + }) + + if err != nil || !token.Valid { + L.Push(lua.LString("Invalid token: " + err.Error())) + L.Push(lua.LNil) + return 2 + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + L.Push(lua.LString("Invalid claims")) + L.Push(lua.LNil) + return 2 + } + + luaTable := L.NewTable() + for k, v := range claims { + luaTable.RawSetString(k, ConvertGolangTypesToLua(L, v)) + } + + L.Push(lua.LNil) + L.Push(luaTable) + return 2 +}