Compare commits

..

9 Commits

Author SHA1 Message Date
87694f6654 make ConvertGolangTypesToLua simplier 2025-08-06 14:01:45 +03:00
fe628e0f7f develop jwt auth for methods 2025-08-06 14:01:27 +03:00
3898e2833b fmt 2025-08-06 11:32:04 +03:00
e4db8505a0 add sid 2025-08-06 11:31:42 +03:00
0c25d00171 add github.com/golang-jwt/jwt/v5 to the project 2025-08-06 11:31:33 +03:00
b5a6de0b62 add jwt support 2025-08-06 11:31:14 +03:00
1d3d74846e rename database-sqlite to database.sqlite 2025-08-06 10:42:26 +03:00
0141427bfe add print to not allowed functions 2025-08-06 10:04:22 +03:00
866946646b delete some io.* writing functions 2025-08-06 09:58:40 +03:00
10 changed files with 208 additions and 115 deletions

View File

@@ -1,13 +1,14 @@
-- com/DeleteUnit.lua -- com/DeleteUnit.lua
---@diagnostic disable: redefined-local ---@diagnostic disable: redefined-local
local db = require("internal.database-sqlite").connect("db/user-database.db", {log = true}) local db = require("internal.database.sqlite").connect("db/user-database.db", {log = true})
local log = require("internal.log") local log = require("internal.log")
local session = require("internal.session") local session = require("internal.session")
local crypt = require("internal.crypt.bcrypt") local crypt = require("internal.crypt.bcrypt")
local jwt = require("internal.crypt.jwt")
local params = session.request.params.get() local params = session.request.params.get()
local token = session.request.headers.get("x-session-token") local token = session.request.headers.get("authorization")
local function close_db() local function close_db()
if db then if db then
@@ -25,12 +26,32 @@ local function error_response(message, code, data)
close_db() close_db()
end end
if not params then if not token or type(token) ~= "string" then
return error_response("no params provided") return error_response("Access denied")
end end
if not (token and token == require("_config").token()) then local prefix = "Bearer "
return error_response("access denied") if token:sub(1, #prefix) ~= prefix then
return error_response("Invalid Authorization scheme")
end
local access_token = token:sub(#prefix + 1)
local err, data = jwt.decode(access_token, { secret = require("_config").token() })
if err or not data then
session.response.error = {
message = err
}
return
end
if data.session_uuid ~= session.id then
return error_response("Access denied")
end
if not params then
return error_response("no params provided")
end end
if not (params.username and params.email and params.password) then if not (params.username and params.email and params.password) then

View File

@@ -1,13 +1,14 @@
-- com/GetAccess -- com/GetAccess
---@diagnostic disable: redefined-local ---@diagnostic disable: redefined-local
local db = require("internal.database-sqlite").connect("db/user-database.db", {log = true}) local db = require("internal.database.sqlite").connect("db/user-database.db", {log = true})
local log = require("internal.log") local log = require("internal.log")
local session = require("internal.session") local session = require("internal.session")
local crypt = require("internal.crypt.bcrypt") local crypt = require("internal.crypt.bcrypt")
local jwt = require("internal.crypt.jwt")
local params = session.request.params.get() local params = session.request.params.get()
local token = session.request.headers.get("x-session-token") local secret = require("_config").token()
local function close_db() local function close_db()
if db then if db then
@@ -29,10 +30,6 @@ if not params then
return error_response("No params provided") return error_response("No params provided")
end end
if not (token and token == require("_config").token()) then
return error_response("access denied")
end
if not (params.username and params.email and params.password) then if not (params.username and params.email and params.password) then
return error_response("Missing username, email or password") return error_response("Missing username, email or password")
end end
@@ -62,13 +59,14 @@ if not ok then
return error_response("Invalid password") return error_response("Invalid password")
end end
local token = jwt.encode({
secret = secret,
payload = { session_uuid = session.id, admin_user = params.username },
expires_in = 3600
})
session.response.result = { session.response.result = {
user = { access_token = token
id = unit.id,
username = unit.username,
email = unit.email,
created_at = unit.created_at
}
} }
close_db() close_db()

View File

@@ -1,13 +1,14 @@
-- com/PutNewUnit.lua -- com/PutNewUnit.lua
---@diagnostic disable: redefined-local ---@diagnostic disable: redefined-local
local db = require("internal.database-sqlite").connect("db/user-database.db", {log = true}) local db = require("internal.database.sqlite").connect("db/user-database.db", {log = true})
local log = require("internal.log") local log = require("internal.log")
local session = require("internal.session") local session = require("internal.session")
local crypt = require("internal.crypt.bcrypt") local crypt = require("internal.crypt.bcrypt")
local jwt = require("internal.crypt.jwt")
local params = session.request.params.get() local params = session.request.params.get()
local token = session.request.headers.get("x-session-token") local token = session.request.headers.get("authorization")
local function close_db() local function close_db()
if db then if db then
@@ -25,12 +26,32 @@ local function error_response(message, code, data)
close_db() close_db()
end end
if not params then if not token or type(token) ~= "string" then
return error_response("no params provided") return error_response("Access denied")
end end
if not (token and token == require("_config").token()) then local prefix = "Bearer "
return error_response("access denied") if token:sub(1, #prefix) ~= prefix then
return error_response("Invalid Authorization scheme")
end
local access_token = token:sub(#prefix + 1)
local err, data = jwt.decode(access_token, { secret = require("_config").token() })
if err or not data then
session.response.error = {
message = err
}
return
end
if data.session_uuid ~= session.id then
return error_response("Access denied")
end
if not params then
return error_response("no params provided")
end end
if not (params.username and params.email and params.password) then if not (params.username and params.email and params.password) then

View File

@@ -2,15 +2,3 @@
package.path = package.path .. ";/usr/lib64/lua/5.1/?.lua;/usr/local/share/lua/5.1/?.lua" .. ";./com/?.lua;" package.path = package.path .. ";/usr/lib64/lua/5.1/?.lua;/usr/local/share/lua/5.1/?.lua" .. ";./com/?.lua;"
package.cpath = package.cpath .. ";/usr/lib64/lua/5.1/?.so;/usr/local/lib/lua/5.1/?.so" package.cpath = package.cpath .. ";/usr/lib64/lua/5.1/?.so;/usr/local/lib/lua/5.1/?.so"
print = function() end
io.write = function(...) end
io.stdout = function() return nil end
io.stderr = function() return nil end
io.read = function(...) return nil end
---@type table<string, any>
Status = {
ok = "ok",
error = "error",
invalid = "invalid",
}

2
go.mod
View File

@@ -19,6 +19,7 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect
@@ -40,5 +41,6 @@ require (
require ( require (
github.com/go-chi/cors v1.2.2 github.com/go-chi/cors v1.2.2
github.com/golang-jwt/jwt v3.2.2+incompatible
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

4
go.sum
View File

@@ -13,6 +13,10 @@ github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE=
github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=

View File

@@ -44,7 +44,7 @@ func getDBMutex(dbPath string) *sync.RWMutex {
return mtx return mtx
} }
func loadDBMod(llog *slog.Logger) func(*lua.LState) int { func loadDBMod(llog *slog.Logger, sid string) func(*lua.LState) int {
return func(L *lua.LState) int { return func(L *lua.LState) int {
llog.Debug("import module db-sqlite") llog.Debug("import module db-sqlite")
dbMod := L.NewTable() dbMod := L.NewTable()
@@ -85,7 +85,7 @@ func loadDBMod(llog *slog.Logger) func(*lua.LState) int {
"close": dbClose, "close": dbClose,
})) }))
L.SetField(dbMod, "__gosally_internal", lua.LString("0")) L.SetField(dbMod, "__gosally_internal", lua.LString(sid))
L.Push(dbMod) L.Push(dbMod)
return 1 return 1
} }

View File

@@ -0,0 +1,85 @@
package sv1
import (
"log/slog"
"time"
"github.com/golang-jwt/jwt"
lua "github.com/yuin/gopher-lua"
)
func loadJWTMod(llog *slog.Logger, sid string) func(*lua.LState) int {
return func(L *lua.LState) int {
llog.Debug("import module jwt")
jwtMod := L.NewTable()
L.SetField(jwtMod, "encode", L.NewFunction(jwtEncode))
L.SetField(jwtMod, "decode", L.NewFunction(jwtDecode))
L.SetField(jwtMod, "__gosally_internal", lua.LString(sid))
L.Push(jwtMod)
return 1
}
}
func jwtEncode(L *lua.LState) int {
payloadTbl := L.CheckTable(1)
secret := L.GetField(payloadTbl, "secret").String()
payload := L.GetField(payloadTbl, "payload").(*lua.LTable)
expiresIn := L.GetField(payloadTbl, "expires_in")
expDuration := time.Hour
if expiresIn.Type() == lua.LTNumber {
floatVal := ConvertLuaTypesToGolang(expiresIn).(float64)
expDuration = time.Duration(floatVal) * time.Second
}
claims := jwt.MapClaims{}
payload.ForEach(func(key, value lua.LValue) {
claims[key.String()] = ConvertLuaTypesToGolang(value)
})
claims["exp"] = time.Now().Add(expDuration).Unix()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
L.Push(lua.LNil)
L.Push(lua.LString(err.Error()))
return 2
}
L.Push(lua.LString(signedToken))
return 1
}
func jwtDecode(L *lua.LState) int {
tokenString := L.CheckString(1)
optsTbl := L.OptTable(2, L.NewTable())
secret := L.GetField(optsTbl, "secret").String()
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) {
return []byte(secret), nil
})
if err != nil || !token.Valid {
L.Push(lua.LString("Invalid token: " + err.Error()))
L.Push(lua.LNil)
return 2
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
L.Push(lua.LString("Invalid claims"))
L.Push(lua.LNil)
return 2
}
luaTable := L.NewTable()
for k, v := range claims {
luaTable.RawSetString(k, ConvertGolangTypesToLua(L, v))
}
L.Push(lua.LNil)
L.Push(luaTable)
return 2
}

View File

@@ -46,6 +46,19 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
L := lua.NewState() L := lua.NewState()
defer L.Close() defer L.Close()
ioMod := L.GetGlobal("io").(*lua.LTable)
for _, k := range []string{"write", "output", "flush", "read", "input"} {
ioMod.RawSetString(k, lua.LNil)
}
L.Env.RawSetString("print", lua.LNil)
for _, name := range []string{"stdout", "stderr", "stdin"} {
stream := ioMod.RawGetString(name)
if t, ok := stream.(*lua.LUserData); ok {
t.Metatable = lua.LNil
}
}
seed := rand.Int() seed := rand.Int()
loadSessionMod := func(L *lua.LState) int { loadSessionMod := func(L *lua.LState) int {
@@ -59,7 +72,7 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
for k, v := range r.Header { for k, v := range r.Header {
L.SetField(fetchedHeadersTable, k, ConvertGolangTypesToLua(L, v)) L.SetField(fetchedHeadersTable, k, ConvertGolangTypesToLua(L, v))
} }
headersGetter := L.NewFunction(func(L *lua.LState) int { headersGetter := L.NewFunction(func(L *lua.LState) int {
path := L.OptString(1, "") path := L.OptString(1, "")
def := L.Get(2) def := L.Get(2)
@@ -82,14 +95,14 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
} }
return 1 return 1
}) })
fetchedParamsTable := L.NewTable() fetchedParamsTable := L.NewTable()
if fetchedParams, ok := req.Params.(map[string]any); ok { if fetchedParams, ok := req.Params.(map[string]any); ok {
for k, v := range fetchedParams { for k, v := range fetchedParams {
L.SetField(fetchedParamsTable, k, ConvertGolangTypesToLua(L, v)) L.SetField(fetchedParamsTable, k, ConvertGolangTypesToLua(L, v))
} }
} }
paramsGetter := L.NewFunction(func(L *lua.LState) int { paramsGetter := L.NewFunction(func(L *lua.LState) int {
path := L.OptString(1, "") path := L.OptString(1, "")
def := L.Get(2) def := L.Get(2)
@@ -106,7 +119,7 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
if tblVal, ok := val.(*lua.LTable); ok { if tblVal, ok := val.(*lua.LTable); ok {
current = tblVal current = tblVal
} else { } else {
if index == size - 1 { if index == size-1 {
return val return val
} }
return lua.LNil return lua.LNil
@@ -114,7 +127,7 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
} }
return lua.LNil return lua.LNil
} }
val := get(fetchedParamsTable, path) val := get(fetchedParamsTable, path)
if val == lua.LNil && def != lua.LNil { if val == lua.LNil && def != lua.LNil {
L.Push(def) L.Push(def)
@@ -399,8 +412,9 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
L.PreloadModule("internal.session", loadSessionMod) L.PreloadModule("internal.session", loadSessionMod)
L.PreloadModule("internal.log", loadLogMod) L.PreloadModule("internal.log", loadLogMod)
L.PreloadModule("internal.net", loadNetMod) L.PreloadModule("internal.net", loadNetMod)
L.PreloadModule("internal.database-sqlite", loadDBMod(llog)) L.PreloadModule("internal.database.sqlite", loadDBMod(llog, fmt.Sprint(seed)))
L.PreloadModule("internal.crypt.bcrypt", loadCryptbcryptMod) L.PreloadModule("internal.crypt.bcrypt", loadCryptbcryptMod)
L.PreloadModule("internal.crypt.jwt", loadJWTMod(llog, fmt.Sprint(seed)))
llog.Debug("preparing environment") llog.Debug("preparing environment")
prep := filepath.Join(*h.x.Config.Conf.Node.ComDir, "_prepare.lua") prep := filepath.Join(*h.x.Config.Conf.Node.ComDir, "_prepare.lua")

View File

@@ -2,6 +2,7 @@ package sv1
import ( import (
"fmt" "fmt"
"reflect"
lua "github.com/yuin/gopher-lua" lua "github.com/yuin/gopher-lua"
) )
@@ -44,85 +45,44 @@ func ConvertLuaTypesToGolang(value lua.LValue) any {
} }
func ConvertGolangTypesToLua(L *lua.LState, val any) lua.LValue { func ConvertGolangTypesToLua(L *lua.LState, val any) lua.LValue {
switch v := val.(type) { if val == nil {
case nil:
return lua.LNil return lua.LNil
}
case string: rv := reflect.ValueOf(val)
return lua.LString(v) rt := rv.Type()
case bool:
return lua.LBool(v)
case int:
return lua.LNumber(v)
case int8:
return lua.LNumber(v)
case int16:
return lua.LNumber(v)
case int32:
return lua.LNumber(v)
case int64:
return lua.LNumber(v)
case uint:
return lua.LNumber(v)
case uint8:
return lua.LNumber(v)
case uint16:
return lua.LNumber(v)
case uint32:
return lua.LNumber(v)
case uint64:
return lua.LNumber(v)
case float32:
return lua.LNumber(v)
case float64:
return lua.LNumber(v)
case []string: switch rt.Kind() {
case reflect.String:
return lua.LString(rv.String())
case reflect.Bool:
return lua.LBool(rv.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return lua.LNumber(rv.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return lua.LNumber(rv.Uint())
case reflect.Float32, reflect.Float64:
return lua.LNumber(rv.Float())
case reflect.Slice, reflect.Array:
tbl := L.NewTable() tbl := L.NewTable()
for i, s := range v { for i := 0; i < rv.Len(); i++ {
tbl.RawSetInt(i+1, lua.LString(s)) tbl.RawSetInt(i+1, ConvertGolangTypesToLua(L, rv.Index(i).Interface()))
}
return tbl
case []int:
tbl := L.NewTable()
for i, n := range v {
tbl.RawSetInt(i+1, lua.LNumber(n))
}
return tbl
case []float64:
tbl := L.NewTable()
for i, f := range v {
tbl.RawSetInt(i+1, lua.LNumber(f))
}
return tbl
case []any:
tbl := L.NewTable()
for i, item := range v {
tbl.RawSetInt(i+1, ConvertGolangTypesToLua(L, item))
} }
return tbl return tbl
case map[string]string: case reflect.Map:
tbl := L.NewTable() if rt.Key().Kind() == reflect.String {
for k, s := range v { tbl := L.NewTable()
tbl.RawSetString(k, lua.LString(s)) for _, key := range rv.MapKeys() {
val := rv.MapIndex(key)
tbl.RawSetString(key.String(), ConvertGolangTypesToLua(L, val.Interface()))
}
return tbl
} }
return tbl
case map[string]int:
tbl := L.NewTable()
for k, n := range v {
tbl.RawSetString(k, lua.LNumber(n))
}
return tbl
case map[string]any:
tbl := L.NewTable()
for k, val := range v {
tbl.RawSetString(k, ConvertGolangTypesToLua(L, val))
}
return tbl
default: default:
return lua.LString(fmt.Sprintf("%v", v)) return lua.LString(fmt.Sprintf("%v", val))
} }
} return lua.LString(fmt.Sprintf("%v", val))
}