mirror of
https://github.com/akyaiy/GoSally-mvp.git
synced 2026-01-03 20:12:25 +00:00
Compare commits
9 Commits
251e580e8a
...
87694f6654
| Author | SHA1 | Date | |
|---|---|---|---|
| 87694f6654 | |||
| fe628e0f7f | |||
| 3898e2833b | |||
| e4db8505a0 | |||
| 0c25d00171 | |||
| b5a6de0b62 | |||
| 1d3d74846e | |||
| 0141427bfe | |||
| 866946646b |
@@ -1,13 +1,14 @@
|
||||
-- com/DeleteUnit.lua
|
||||
|
||||
---@diagnostic disable: redefined-local
|
||||
local db = require("internal.database-sqlite").connect("db/user-database.db", {log = true})
|
||||
local db = require("internal.database.sqlite").connect("db/user-database.db", {log = true})
|
||||
local log = require("internal.log")
|
||||
local session = require("internal.session")
|
||||
local crypt = require("internal.crypt.bcrypt")
|
||||
local jwt = require("internal.crypt.jwt")
|
||||
|
||||
local params = session.request.params.get()
|
||||
local token = session.request.headers.get("x-session-token")
|
||||
local token = session.request.headers.get("authorization")
|
||||
|
||||
local function close_db()
|
||||
if db then
|
||||
@@ -25,12 +26,32 @@ local function error_response(message, code, data)
|
||||
close_db()
|
||||
end
|
||||
|
||||
if not params then
|
||||
return error_response("no params provided")
|
||||
if not token or type(token) ~= "string" then
|
||||
return error_response("Access denied")
|
||||
end
|
||||
|
||||
if not (token and token == require("_config").token()) then
|
||||
return error_response("access denied")
|
||||
local prefix = "Bearer "
|
||||
if token:sub(1, #prefix) ~= prefix then
|
||||
return error_response("Invalid Authorization scheme")
|
||||
end
|
||||
|
||||
local access_token = token:sub(#prefix + 1)
|
||||
|
||||
local err, data = jwt.decode(access_token, { secret = require("_config").token() })
|
||||
|
||||
if err or not data then
|
||||
session.response.error = {
|
||||
message = err
|
||||
}
|
||||
return
|
||||
end
|
||||
|
||||
if data.session_uuid ~= session.id then
|
||||
return error_response("Access denied")
|
||||
end
|
||||
|
||||
if not params then
|
||||
return error_response("no params provided")
|
||||
end
|
||||
|
||||
if not (params.username and params.email and params.password) then
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
-- com/GetAccess
|
||||
|
||||
---@diagnostic disable: redefined-local
|
||||
local db = require("internal.database-sqlite").connect("db/user-database.db", {log = true})
|
||||
local db = require("internal.database.sqlite").connect("db/user-database.db", {log = true})
|
||||
local log = require("internal.log")
|
||||
local session = require("internal.session")
|
||||
local crypt = require("internal.crypt.bcrypt")
|
||||
local jwt = require("internal.crypt.jwt")
|
||||
|
||||
local params = session.request.params.get()
|
||||
local token = session.request.headers.get("x-session-token")
|
||||
local secret = require("_config").token()
|
||||
|
||||
local function close_db()
|
||||
if db then
|
||||
@@ -29,10 +30,6 @@ if not params then
|
||||
return error_response("No params provided")
|
||||
end
|
||||
|
||||
if not (token and token == require("_config").token()) then
|
||||
return error_response("access denied")
|
||||
end
|
||||
|
||||
if not (params.username and params.email and params.password) then
|
||||
return error_response("Missing username, email or password")
|
||||
end
|
||||
@@ -62,13 +59,14 @@ if not ok then
|
||||
return error_response("Invalid password")
|
||||
end
|
||||
|
||||
local token = jwt.encode({
|
||||
secret = secret,
|
||||
payload = { session_uuid = session.id, admin_user = params.username },
|
||||
expires_in = 3600
|
||||
})
|
||||
|
||||
session.response.result = {
|
||||
user = {
|
||||
id = unit.id,
|
||||
username = unit.username,
|
||||
email = unit.email,
|
||||
created_at = unit.created_at
|
||||
}
|
||||
access_token = token
|
||||
}
|
||||
|
||||
close_db()
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
-- com/PutNewUnit.lua
|
||||
|
||||
---@diagnostic disable: redefined-local
|
||||
local db = require("internal.database-sqlite").connect("db/user-database.db", {log = true})
|
||||
local db = require("internal.database.sqlite").connect("db/user-database.db", {log = true})
|
||||
local log = require("internal.log")
|
||||
local session = require("internal.session")
|
||||
local crypt = require("internal.crypt.bcrypt")
|
||||
local jwt = require("internal.crypt.jwt")
|
||||
|
||||
local params = session.request.params.get()
|
||||
local token = session.request.headers.get("x-session-token")
|
||||
local token = session.request.headers.get("authorization")
|
||||
|
||||
local function close_db()
|
||||
if db then
|
||||
@@ -25,12 +26,32 @@ local function error_response(message, code, data)
|
||||
close_db()
|
||||
end
|
||||
|
||||
if not params then
|
||||
return error_response("no params provided")
|
||||
if not token or type(token) ~= "string" then
|
||||
return error_response("Access denied")
|
||||
end
|
||||
|
||||
if not (token and token == require("_config").token()) then
|
||||
return error_response("access denied")
|
||||
local prefix = "Bearer "
|
||||
if token:sub(1, #prefix) ~= prefix then
|
||||
return error_response("Invalid Authorization scheme")
|
||||
end
|
||||
|
||||
local access_token = token:sub(#prefix + 1)
|
||||
|
||||
local err, data = jwt.decode(access_token, { secret = require("_config").token() })
|
||||
|
||||
if err or not data then
|
||||
session.response.error = {
|
||||
message = err
|
||||
}
|
||||
return
|
||||
end
|
||||
|
||||
if data.session_uuid ~= session.id then
|
||||
return error_response("Access denied")
|
||||
end
|
||||
|
||||
if not params then
|
||||
return error_response("no params provided")
|
||||
end
|
||||
|
||||
if not (params.username and params.email and params.password) then
|
||||
|
||||
@@ -2,15 +2,3 @@
|
||||
package.path = package.path .. ";/usr/lib64/lua/5.1/?.lua;/usr/local/share/lua/5.1/?.lua" .. ";./com/?.lua;"
|
||||
package.cpath = package.cpath .. ";/usr/lib64/lua/5.1/?.so;/usr/local/lib/lua/5.1/?.so"
|
||||
|
||||
print = function() end
|
||||
io.write = function(...) end
|
||||
io.stdout = function() return nil end
|
||||
io.stderr = function() return nil end
|
||||
io.read = function(...) return nil end
|
||||
|
||||
---@type table<string, any>
|
||||
Status = {
|
||||
ok = "ok",
|
||||
error = "error",
|
||||
invalid = "invalid",
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -19,6 +19,7 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
@@ -40,5 +41,6 @@ require (
|
||||
|
||||
require (
|
||||
github.com/go-chi/cors v1.2.2
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
4
go.sum
4
go.sum
@@ -13,6 +13,10 @@ github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE=
|
||||
github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
|
||||
@@ -44,7 +44,7 @@ func getDBMutex(dbPath string) *sync.RWMutex {
|
||||
return mtx
|
||||
}
|
||||
|
||||
func loadDBMod(llog *slog.Logger) func(*lua.LState) int {
|
||||
func loadDBMod(llog *slog.Logger, sid string) func(*lua.LState) int {
|
||||
return func(L *lua.LState) int {
|
||||
llog.Debug("import module db-sqlite")
|
||||
dbMod := L.NewTable()
|
||||
@@ -85,7 +85,7 @@ func loadDBMod(llog *slog.Logger) func(*lua.LState) int {
|
||||
"close": dbClose,
|
||||
}))
|
||||
|
||||
L.SetField(dbMod, "__gosally_internal", lua.LString("0"))
|
||||
L.SetField(dbMod, "__gosally_internal", lua.LString(sid))
|
||||
L.Push(dbMod)
|
||||
return 1
|
||||
}
|
||||
|
||||
85
internal/server/sv1/jwt.go
Normal file
85
internal/server/sv1/jwt.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package sv1
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
func loadJWTMod(llog *slog.Logger, sid string) func(*lua.LState) int {
|
||||
return func(L *lua.LState) int {
|
||||
llog.Debug("import module jwt")
|
||||
jwtMod := L.NewTable()
|
||||
|
||||
L.SetField(jwtMod, "encode", L.NewFunction(jwtEncode))
|
||||
L.SetField(jwtMod, "decode", L.NewFunction(jwtDecode))
|
||||
|
||||
L.SetField(jwtMod, "__gosally_internal", lua.LString(sid))
|
||||
L.Push(jwtMod)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func jwtEncode(L *lua.LState) int {
|
||||
payloadTbl := L.CheckTable(1)
|
||||
secret := L.GetField(payloadTbl, "secret").String()
|
||||
payload := L.GetField(payloadTbl, "payload").(*lua.LTable)
|
||||
expiresIn := L.GetField(payloadTbl, "expires_in")
|
||||
expDuration := time.Hour
|
||||
|
||||
if expiresIn.Type() == lua.LTNumber {
|
||||
floatVal := ConvertLuaTypesToGolang(expiresIn).(float64)
|
||||
expDuration = time.Duration(floatVal) * time.Second
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{}
|
||||
payload.ForEach(func(key, value lua.LValue) {
|
||||
claims[key.String()] = ConvertLuaTypesToGolang(value)
|
||||
})
|
||||
claims["exp"] = time.Now().Add(expDuration).Unix()
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
L.Push(lua.LString(signedToken))
|
||||
return 1
|
||||
}
|
||||
|
||||
func jwtDecode(L *lua.LState) int {
|
||||
tokenString := L.CheckString(1)
|
||||
optsTbl := L.OptTable(2, L.NewTable())
|
||||
secret := L.GetField(optsTbl, "secret").String()
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
L.Push(lua.LString("Invalid token: " + err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
L.Push(lua.LString("Invalid claims"))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}
|
||||
|
||||
luaTable := L.NewTable()
|
||||
for k, v := range claims {
|
||||
luaTable.RawSetString(k, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
|
||||
L.Push(lua.LNil)
|
||||
L.Push(luaTable)
|
||||
return 2
|
||||
}
|
||||
@@ -46,6 +46,19 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
L := lua.NewState()
|
||||
defer L.Close()
|
||||
|
||||
ioMod := L.GetGlobal("io").(*lua.LTable)
|
||||
for _, k := range []string{"write", "output", "flush", "read", "input"} {
|
||||
ioMod.RawSetString(k, lua.LNil)
|
||||
}
|
||||
L.Env.RawSetString("print", lua.LNil)
|
||||
|
||||
for _, name := range []string{"stdout", "stderr", "stdin"} {
|
||||
stream := ioMod.RawGetString(name)
|
||||
if t, ok := stream.(*lua.LUserData); ok {
|
||||
t.Metatable = lua.LNil
|
||||
}
|
||||
}
|
||||
|
||||
seed := rand.Int()
|
||||
|
||||
loadSessionMod := func(L *lua.LState) int {
|
||||
@@ -106,7 +119,7 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
if tblVal, ok := val.(*lua.LTable); ok {
|
||||
current = tblVal
|
||||
} else {
|
||||
if index == size - 1 {
|
||||
if index == size-1 {
|
||||
return val
|
||||
}
|
||||
return lua.LNil
|
||||
@@ -399,8 +412,9 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
L.PreloadModule("internal.session", loadSessionMod)
|
||||
L.PreloadModule("internal.log", loadLogMod)
|
||||
L.PreloadModule("internal.net", loadNetMod)
|
||||
L.PreloadModule("internal.database-sqlite", loadDBMod(llog))
|
||||
L.PreloadModule("internal.database.sqlite", loadDBMod(llog, fmt.Sprint(seed)))
|
||||
L.PreloadModule("internal.crypt.bcrypt", loadCryptbcryptMod)
|
||||
L.PreloadModule("internal.crypt.jwt", loadJWTMod(llog, fmt.Sprint(seed)))
|
||||
|
||||
llog.Debug("preparing environment")
|
||||
prep := filepath.Join(*h.x.Config.Conf.Node.ComDir, "_prepare.lua")
|
||||
|
||||
@@ -2,6 +2,7 @@ package sv1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
@@ -44,85 +45,44 @@ func ConvertLuaTypesToGolang(value lua.LValue) any {
|
||||
}
|
||||
|
||||
func ConvertGolangTypesToLua(L *lua.LState, val any) lua.LValue {
|
||||
switch v := val.(type) {
|
||||
|
||||
case nil:
|
||||
if val == nil {
|
||||
return lua.LNil
|
||||
}
|
||||
|
||||
case string:
|
||||
return lua.LString(v)
|
||||
case bool:
|
||||
return lua.LBool(v)
|
||||
case int:
|
||||
return lua.LNumber(v)
|
||||
case int8:
|
||||
return lua.LNumber(v)
|
||||
case int16:
|
||||
return lua.LNumber(v)
|
||||
case int32:
|
||||
return lua.LNumber(v)
|
||||
case int64:
|
||||
return lua.LNumber(v)
|
||||
case uint:
|
||||
return lua.LNumber(v)
|
||||
case uint8:
|
||||
return lua.LNumber(v)
|
||||
case uint16:
|
||||
return lua.LNumber(v)
|
||||
case uint32:
|
||||
return lua.LNumber(v)
|
||||
case uint64:
|
||||
return lua.LNumber(v)
|
||||
case float32:
|
||||
return lua.LNumber(v)
|
||||
case float64:
|
||||
return lua.LNumber(v)
|
||||
rv := reflect.ValueOf(val)
|
||||
rt := rv.Type()
|
||||
|
||||
case []string:
|
||||
switch rt.Kind() {
|
||||
case reflect.String:
|
||||
return lua.LString(rv.String())
|
||||
case reflect.Bool:
|
||||
return lua.LBool(rv.Bool())
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return lua.LNumber(rv.Int())
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return lua.LNumber(rv.Uint())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return lua.LNumber(rv.Float())
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
tbl := L.NewTable()
|
||||
for i, s := range v {
|
||||
tbl.RawSetInt(i+1, lua.LString(s))
|
||||
}
|
||||
return tbl
|
||||
case []int:
|
||||
tbl := L.NewTable()
|
||||
for i, n := range v {
|
||||
tbl.RawSetInt(i+1, lua.LNumber(n))
|
||||
}
|
||||
return tbl
|
||||
case []float64:
|
||||
tbl := L.NewTable()
|
||||
for i, f := range v {
|
||||
tbl.RawSetInt(i+1, lua.LNumber(f))
|
||||
}
|
||||
return tbl
|
||||
case []any:
|
||||
tbl := L.NewTable()
|
||||
for i, item := range v {
|
||||
tbl.RawSetInt(i+1, ConvertGolangTypesToLua(L, item))
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
tbl.RawSetInt(i+1, ConvertGolangTypesToLua(L, rv.Index(i).Interface()))
|
||||
}
|
||||
return tbl
|
||||
|
||||
case map[string]string:
|
||||
tbl := L.NewTable()
|
||||
for k, s := range v {
|
||||
tbl.RawSetString(k, lua.LString(s))
|
||||
case reflect.Map:
|
||||
if rt.Key().Kind() == reflect.String {
|
||||
tbl := L.NewTable()
|
||||
for _, key := range rv.MapKeys() {
|
||||
val := rv.MapIndex(key)
|
||||
tbl.RawSetString(key.String(), ConvertGolangTypesToLua(L, val.Interface()))
|
||||
}
|
||||
return tbl
|
||||
}
|
||||
return tbl
|
||||
case map[string]int:
|
||||
tbl := L.NewTable()
|
||||
for k, n := range v {
|
||||
tbl.RawSetString(k, lua.LNumber(n))
|
||||
}
|
||||
return tbl
|
||||
case map[string]any:
|
||||
tbl := L.NewTable()
|
||||
for k, val := range v {
|
||||
tbl.RawSetString(k, ConvertGolangTypesToLua(L, val))
|
||||
}
|
||||
return tbl
|
||||
|
||||
default:
|
||||
return lua.LString(fmt.Sprintf("%v", v))
|
||||
return lua.LString(fmt.Sprintf("%v", val))
|
||||
}
|
||||
return lua.LString(fmt.Sprintf("%v", val))
|
||||
}
|
||||
Reference in New Issue
Block a user