some small changes, and add send, send_error, throw_error and some field

This commit is contained in:
2025-08-09 10:41:50 +03:00
parent 811403a0a2
commit 2ceb236a53
4 changed files with 160 additions and 51 deletions

View File

@@ -4,7 +4,7 @@ import (
"log/slog" "log/slog"
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt/v5"
lua "github.com/yuin/gopher-lua" lua "github.com/yuin/gopher-lua"
) )

View File

@@ -43,6 +43,8 @@ func addInitiatorHeaders(sid string, req *http.Request, headers http.Header) {
// I will be only glad. // I will be only glad.
// TODO: make this huge function more harmonious by dividing responsibilities // TODO: make this huge function more harmonious by dividing responsibilities
func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, path string) *rpc.RPCResponse { func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, path string) *rpc.RPCResponse {
var __exit = -1
llog := h.x.SLog.With(slog.String("session-id", sid)) llog := h.x.SLog.With(slog.String("session-id", sid))
llog.Debug("handling LUA") llog.Debug("handling LUA")
L := lua.NewState() L := lua.NewState()
@@ -101,10 +103,13 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
return 1 return 1
}) })
L.SetField(headersTable, "__fetched", fetchedHeadersTable)
L.SetField(headersTable, "get", headersGetter)
L.SetField(inTable, "headers", headersTable)
fetchedParamsTable := L.NewTable() fetchedParamsTable := L.NewTable()
switch params := req.Params.(type) { switch params := req.Params.(type) {
case nil:
fetchedParamsTable.RawSetInt(1, lua.LNil)
case map[string]any: case map[string]any:
for k, v := range params { for k, v := range params {
L.SetField(fetchedParamsTable, k, ConvertGolangTypesToLua(L, v)) L.SetField(fetchedParamsTable, k, ConvertGolangTypesToLua(L, v))
@@ -113,19 +118,13 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
for i, v := range params { for i, v := range params {
fetchedParamsTable.RawSetInt(i+1, ConvertGolangTypesToLua(L, v)) fetchedParamsTable.RawSetInt(i+1, ConvertGolangTypesToLua(L, v))
} }
default:
fetchedParamsTable.RawSetInt(1, lua.LNil)
} }
paramsGetter := L.NewFunction(func(L *lua.LState) int { paramsGetter := L.NewFunction(func(L *lua.LState) int {
path := L.OptString(1, "") path := L.OptString(1, "")
def := L.Get(2) def := L.Get(2)
get := func(tbl *lua.LTable, path string) lua.LValue { get := func(tbl *lua.LTable, path string) lua.LValue {
if tbl.RawGetInt(1) == lua.LNil {
return lua.LNil
}
if path == "" { if path == "" {
return tbl return tbl
} }
@@ -146,7 +145,8 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
return current return current
} }
val := get(fetchedParamsTable, path) paramsTbl := L.GetField(paramsTable, "__fetched") //
val := get(paramsTbl.(*lua.LTable), path) //
if val == lua.LNil && def != lua.LNil { if val == lua.LNil && def != lua.LNil {
L.Push(def) L.Push(def)
} else { } else {
@@ -154,19 +154,95 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
} }
return 1 return 1
}) })
L.SetField(headersTable, "__fetched", fetchedHeadersTable)
L.SetField(headersTable, "get", headersGetter)
L.SetField(inTable, "headers", headersTable)
L.SetField(paramsTable, "__fetched", fetchedParamsTable) L.SetField(paramsTable, "__fetched", fetchedParamsTable)
L.SetField(paramsTable, "get", paramsGetter) L.SetField(paramsTable, "get", paramsGetter)
L.SetField(inTable, "params", paramsTable) L.SetField(inTable, "params", paramsTable)
outTable := L.NewTable() outTable := L.NewTable()
scriptDataTable := L.NewTable()
L.SetField(outTable, "__script_data", scriptDataTable)
L.SetField(inTable, "address", lua.LString(r.RemoteAddr)) L.SetField(inTable, "address", lua.LString(r.RemoteAddr))
L.SetField(sessionMod, "throw_error", L.NewFunction(func(L *lua.LState) int {
arg := L.Get(1)
var msg string
switch arg.Type() {
case lua.LTString:
msg = arg.String()
case lua.LTNumber:
msg = strconv.FormatFloat(float64(arg.(lua.LNumber)), 'f', -1, 64)
default:
L.ArgError(1, "expected string or number")
return 0
}
L.RaiseError("%s", msg)
return 0
}))
resTable := L.NewTable()
L.SetField(scriptDataTable, "result", resTable)
L.SetField(outTable, "send", L.NewFunction(func(L *lua.LState) int {
res := L.Get(1)
if res == lua.LNil {
__exit = 0
L.RaiseError("__successfull")
return 0
}
resFTable := scriptDataTable.RawGetString("result")
// switch resTable.Type() {
// case lua.LTTable:
if resPTable, ok := res.(*lua.LTable); ok {
resPTable.ForEach(func(key, value lua.LValue) {
L.SetField(resFTable, key.String(), value)
})
} else {
L.SetField(scriptDataTable, "result", res)
}
// default:
// L.SetField(resTable, key.String(), value)
// }
__exit = 0
L.RaiseError("__successfull")
return 0
}))
errTable := L.NewTable()
L.SetField(scriptDataTable, "error", errTable)
L.SetField(outTable, "send_error", L.NewFunction(func(L *lua.LState) int {
var params [3]lua.LValue
for i := range 3 {
params[i] = L.Get(i + 1)
}
if errTable, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok {
for _, v := range params {
switch v.Type() {
case lua.LTNumber:
if n, ok := v.(lua.LNumber); ok {
L.SetField(errTable, "code", n)
}
case lua.LTString:
if s, ok := v.(lua.LString); ok {
L.SetField(errTable, "message", s)
}
case lua.LTTable:
if tbl, ok := v.(*lua.LTable); ok {
L.SetField(errTable, "data", tbl)
}
}
}
}
__exit = 1
L.RaiseError("__unsuccessfull")
return 0
}))
L.SetField(sessionMod, "request", inTable) L.SetField(sessionMod, "request", inTable)
L.SetField(sessionMod, "response", outTable) L.SetField(sessionMod, "response", outTable)
@@ -467,11 +543,12 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
} }
} }
llog.Debug("executing script", slog.String("script", path)) llog.Debug("executing script", slog.String("script", path))
if err := L.DoFile(path); err != nil { if err := L.DoFile(path); err != nil && __exit == -1 {
llog.Error("script error", slog.String("script", path), slog.String("error", err.Error())) llog.Error("script error", slog.String("script", path), slog.String("error", err.Error()))
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID) return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
} }
pkg := L.GetGlobal("package") pkg := L.GetGlobal("package")
pkgTbl, ok := pkg.(*lua.LTable) pkgTbl, ok := pkg.(*lua.LTable)
if !ok { if !ok {
@@ -505,34 +582,28 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID) return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
} }
if errVal := outTbl.RawGetString("error"); errVal != lua.LNil { if scriptDataTable, ok := outTbl.RawGetString("__script_data").(*lua.LTable); ok {
llog.Debug("catch error table", slog.String("script", path)) switch __exit {
if errTbl, ok := errVal.(*lua.LTable); ok { case 1:
code := rpc.ErrInternalError if errTbl, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok {
message := rpc.ErrInternalErrorS llog.Debug("catch error table", slog.String("script", path))
data := make(map[string]any) code := rpc.ErrInternalError
if c := errTbl.RawGetString("code"); c.Type() == lua.LTNumber { message := rpc.ErrInternalErrorS
code = int(c.(lua.LNumber)) if c := errTbl.RawGetString("code"); c.Type() == lua.LTNumber {
code = int(c.(lua.LNumber))
}
if msg := errTbl.RawGetString("message"); msg.Type() == lua.LTString {
message = msg.String()
}
data := ConvertLuaTypesToGolang(errTbl.RawGetString("data"))
llog.Error("the script terminated with an error", slog.Int("code", code), slog.String("message", message), slog.Any("data", data))
return rpc.NewError(code, message, data, req.ID)
} }
if msg := errTbl.RawGetString("message"); msg.Type() == lua.LTString { return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
message = msg.String() case 0:
} resVal := ConvertLuaTypesToGolang(scriptDataTable.RawGetString("result"))
rawData := errTbl.RawGetString("data") return rpc.NewResponse(resVal, req.ID)
if tbl, ok := rawData.(*lua.LTable); ok {
tbl.ForEach(func(k, v lua.LValue) { data[k.String()] = ConvertLuaTypesToGolang(v) })
} else {
llog.Error("the script terminated with an error", slog.String("code", strconv.Itoa(code)), slog.String("message", message))
return rpc.NewError(code, message, ConvertLuaTypesToGolang(rawData), req.ID)
}
llog.Error("the script terminated with an error", slog.String("code", strconv.Itoa(code)), slog.String("message", message))
return rpc.NewError(code, message, data, req.ID)
} }
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
}
if resultVal := outTbl.RawGetString("result"); resultVal != lua.LNil {
return rpc.NewResponse(ConvertLuaTypesToGolang(resultVal), req.ID)
} }
return rpc.NewResponse(nil, req.ID) return rpc.NewResponse(nil, req.ID)
} }

View File

@@ -3,6 +3,7 @@ package sv1
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strconv"
lua "github.com/yuin/gopher-lua" lua "github.com/yuin/gopher-lua"
) )
@@ -18,19 +19,56 @@ func ConvertLuaTypesToGolang(value lua.LValue) any {
case lua.LTTable: case lua.LTTable:
tbl := value.(*lua.LTable) tbl := value.(*lua.LTable)
var arr []any maxIdx := 0
isArray := true isArray := true
tbl.ForEach(func(key, val lua.LValue) {
if key.Type() != lua.LTNumber { var isNumeric = false
isArray = false tbl.ForEach(func(key, _ lua.LValue) {
var numKey lua.LValue
var ok bool
switch key.Type() {
case lua.LTString:
numKey, ok = key.(lua.LString)
if !ok {
isArray = false
return
}
case lua.LTNumber:
numKey, ok = key.(lua.LNumber)
if !ok {
isArray = false
return
}
isNumeric = true
}
num, err := strconv.Atoi(numKey.String())
if err != nil {
isArray = false
return
}
if num < 1 {
isArray = false
return
}
if num > maxIdx {
maxIdx = num
} }
arr = append(arr, ConvertLuaTypesToGolang(val))
}) })
if isArray { if isArray {
arr := make([]any, maxIdx)
if isNumeric {
for i := 1; i <= maxIdx; i++ {
arr[i-1] = ConvertLuaTypesToGolang(tbl.RawGetInt(i))
}
} else {
for i := 1; i <= maxIdx; i++ {
arr[i-1] = ConvertLuaTypesToGolang(tbl.RawGetString(strconv.Itoa(i)))
}
}
return arr return arr
} }
result := make(map[string]any) result := make(map[string]any)
tbl.ForEach(func(key, val lua.LValue) { tbl.ForEach(func(key, val lua.LValue) {
result[key.String()] = ConvertLuaTypesToGolang(val) result[key.String()] = ConvertLuaTypesToGolang(val)