add headers lua runtime support

This commit is contained in:
2025-08-05 23:15:13 +03:00
parent c734779b69
commit 251e580e8a
6 changed files with 115 additions and 34 deletions

View File

@@ -6,6 +6,9 @@ local log = require("internal.log")
local session = require("internal.session") local session = require("internal.session")
local crypt = require("internal.crypt.bcrypt") local crypt = require("internal.crypt.bcrypt")
local params = session.request.params.get()
local token = session.request.headers.get("x-session-token")
local function close_db() local function close_db()
if db then if db then
db:close() db:close()
@@ -22,20 +25,14 @@ local function error_response(message, code, data)
close_db() close_db()
end end
if not session.request.params then if not params then
return error_response("no params provided") return error_response("no params provided")
end end
if not session.request.params.token then if not (token and token == require("_config").token()) then
return error_response("access denied") return error_response("access denied")
end end
if session.request.params.token ~= require("_config").token() then
return error_response("access denied")
end
local params = session.request.params
if not (params.username and params.email and params.password) then if not (params.username and params.email and params.password) then
return error_response("no username/email/password provided") return error_response("no username/email/password provided")
end end

View File

@@ -6,6 +6,9 @@ local log = require("internal.log")
local session = require("internal.session") local session = require("internal.session")
local crypt = require("internal.crypt.bcrypt") local crypt = require("internal.crypt.bcrypt")
local params = session.request.params.get()
local token = session.request.headers.get("x-session-token")
local function close_db() local function close_db()
if db then if db then
db:close() db:close()
@@ -22,16 +25,11 @@ local function error_response(message, code, data)
close_db() close_db()
end end
local params = session.request.params
if not params then if not params then
return error_response("No params provided") return error_response("No params provided")
end end
if not session.request.params.token then if not (token and token == require("_config").token()) then
return error_response("access denied")
end
if session.request.params.token ~= require("_config").token() then
return error_response("access denied") return error_response("access denied")
end end

View File

@@ -6,6 +6,9 @@ local log = require("internal.log")
local session = require("internal.session") local session = require("internal.session")
local crypt = require("internal.crypt.bcrypt") local crypt = require("internal.crypt.bcrypt")
local params = session.request.params.get()
local token = session.request.headers.get("x-session-token")
local function close_db() local function close_db()
if db then if db then
db:close() db:close()
@@ -22,19 +25,14 @@ local function error_response(message, code, data)
close_db() close_db()
end end
if not session.request.params then if not params then
return error_response("no params provided") return error_response("no params provided")
end end
if not session.request.params.token then if not (token and token == require("_config").token()) then
return error_response("access denied") return error_response("access denied")
end end
if session.request.params.token ~= require("_config").token() then
return error_response("access denied")
end
local params = session.request.params
if not (params.username and params.email and params.password) then if not (params.username and params.email and params.password) then
return error_response("no username/email/password provided") return error_response("no username/email/password provided")
end end

View File

@@ -2,7 +2,9 @@
local session = require("internal.session") local session = require("internal.session")
if session.request.params.about then local params = session.request.params.get()
if params.about then
session.response.result = { session.response.result = {
description = "Returns a list of available methods", description = "Returns a list of available methods",
params = { params = {
@@ -48,7 +50,7 @@ local function scanDirectory(basePath, targetPath)
end end
local basePath = "com" local basePath = "com"
local layer = session.request and session.request.params.layer and session.request.params.layer:gsub(">", "/") or nil local layer = params.layer and params.layer:gsub(">", "/") or nil
session.response.result = { session.response.result = {
answer = layer and scanDirectory(basePath, layer) or scanDirectory(basePath, "") answer = layer and scanDirectory(basePath, layer) or scanDirectory(basePath, "")

View File

@@ -53,7 +53,36 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
sessionMod := L.NewTable() sessionMod := L.NewTable()
inTable := L.NewTable() inTable := L.NewTable()
paramsTable := L.NewTable() paramsTable := L.NewTable()
headersTable := L.NewTable()
fetchedHeadersTable := L.NewTable()
for k, v := range r.Header {
L.SetField(fetchedHeadersTable, k, ConvertGolangTypesToLua(L, v))
}
headersGetter := L.NewFunction(func(L *lua.LState) int {
path := L.OptString(1, "")
def := L.Get(2)
get := func(path string) lua.LValue {
if path == "" {
return fetchedHeadersTable
}
fetched := r.Header.Get(path)
if fetched == "" {
return lua.LNil
}
return lua.LString(fetched)
}
val := get(path)
if val == lua.LNil && def != lua.LNil {
L.Push(def)
} else {
L.Push(val)
}
return 1
})
fetchedParamsTable := L.NewTable() fetchedParamsTable := L.NewTable()
if fetchedParams, ok := req.Params.(map[string]any); ok { if fetchedParams, ok := req.Params.(map[string]any); ok {
for k, v := range fetchedParams { for k, v := range fetchedParams {
@@ -61,7 +90,7 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
} }
} }
getter := L.NewFunction(func(L *lua.LState) int { paramsGetter := L.NewFunction(func(L *lua.LState) int {
path := L.OptString(1, "") path := L.OptString(1, "")
def := L.Get(2) def := L.Get(2)
@@ -94,8 +123,14 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
} }
return 1 return 1
}) })
L.SetField(headersTable, "__fetched", fetchedHeadersTable)
L.SetField(paramsTable, "get", getter) L.SetField(headersTable, "get", headersGetter)
L.SetField(inTable, "headers", headersTable)
L.SetField(paramsTable, "__fetched", fetchedParamsTable)
L.SetField(paramsTable, "get", paramsGetter)
L.SetField(inTable, "params", paramsTable) L.SetField(inTable, "params", paramsTable)
outTable := L.NewTable() outTable := L.NewTable()

View File

@@ -45,33 +45,84 @@ func ConvertLuaTypesToGolang(value lua.LValue) any {
func ConvertGolangTypesToLua(L *lua.LState, val any) lua.LValue { func ConvertGolangTypesToLua(L *lua.LState, val any) lua.LValue {
switch v := val.(type) { switch v := val.(type) {
case nil:
return lua.LNil
case string: case string:
return lua.LString(v) return lua.LString(v)
case bool: case bool:
return lua.LBool(v) return lua.LBool(v)
case int: case int:
return lua.LNumber(float64(v)) return lua.LNumber(v)
case int8:
return lua.LNumber(v)
case int16:
return lua.LNumber(v)
case int32:
return lua.LNumber(v)
case int64: case int64:
return lua.LNumber(float64(v)) return lua.LNumber(v)
case uint:
return lua.LNumber(v)
case uint8:
return lua.LNumber(v)
case uint16:
return lua.LNumber(v)
case uint32:
return lua.LNumber(v)
case uint64:
return lua.LNumber(v)
case float32: case float32:
return lua.LNumber(float64(v)) return lua.LNumber(v)
case float64: case float64:
return lua.LNumber(v) return lua.LNumber(v)
case []string:
tbl := L.NewTable()
for i, s := range v {
tbl.RawSetInt(i+1, lua.LString(s))
}
return tbl
case []int:
tbl := L.NewTable()
for i, n := range v {
tbl.RawSetInt(i+1, lua.LNumber(n))
}
return tbl
case []float64:
tbl := L.NewTable()
for i, f := range v {
tbl.RawSetInt(i+1, lua.LNumber(f))
}
return tbl
case []any: case []any:
tbl := L.NewTable() tbl := L.NewTable()
for i, item := range v { for i, item := range v {
tbl.RawSetInt(i+1, ConvertGolangTypesToLua(L, item)) tbl.RawSetInt(i+1, ConvertGolangTypesToLua(L, item))
} }
return tbl return tbl
case map[string]any:
case map[string]string:
tbl := L.NewTable() tbl := L.NewTable()
for key, value := range v { for k, s := range v {
tbl.RawSetString(key, ConvertGolangTypesToLua(L, value)) tbl.RawSetString(k, lua.LString(s))
} }
return tbl return tbl
case nil: case map[string]int:
return lua.LNil tbl := L.NewTable()
for k, n := range v {
tbl.RawSetString(k, lua.LNumber(n))
}
return tbl
case map[string]any:
tbl := L.NewTable()
for k, val := range v {
tbl.RawSetString(k, ConvertGolangTypesToLua(L, val))
}
return tbl
default: default:
return lua.LString(fmt.Sprintf("%v", v)) return lua.LString(fmt.Sprintf("%v", v))
} }
} }