diff --git a/internal/server/sv1/handle.go b/internal/server/sv1/handle.go index 9716a9d..fc26307 100644 --- a/internal/server/sv1/handle.go +++ b/internal/server/sv1/handle.go @@ -30,7 +30,7 @@ func (h *HandlerV1) Handle(_ context.Context, sid string, r *http.Request, req * default: // JSON-RPC 2.0 Specification: // https://www.jsonrpc.org/specification#parameter_structures - // + // // "params" MUST be either an *array* or an *object* if included. // Any other type (e.g., a number, string, or boolean) is INVALID. h.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrInvalidParamsS)) diff --git a/internal/server/sv1/jwt.go b/internal/server/sv1/jwt.go index 4e635d5..0343fc7 100644 --- a/internal/server/sv1/jwt.go +++ b/internal/server/sv1/jwt.go @@ -4,7 +4,7 @@ import ( "log/slog" "time" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" lua "github.com/yuin/gopher-lua" ) diff --git a/internal/server/sv1/lua_handler.go b/internal/server/sv1/lua_handler.go index d05b5b8..da5aa11 100644 --- a/internal/server/sv1/lua_handler.go +++ b/internal/server/sv1/lua_handler.go @@ -43,6 +43,8 @@ func addInitiatorHeaders(sid string, req *http.Request, headers http.Header) { // I will be only glad. // TODO: make this huge function more harmonious by dividing responsibilities func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, path string) *rpc.RPCResponse { + var __exit = -1 + llog := h.x.SLog.With(slog.String("session-id", sid)) llog.Debug("handling LUA") L := lua.NewState() @@ -50,7 +52,7 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, osMod := L.GetGlobal("os").(*lua.LTable) L.SetField(osMod, "exit", lua.LNil) - + ioMod := L.GetGlobal("io").(*lua.LTable) for _, k := range []string{"write", "output", "flush", "read", "input"} { ioMod.RawSetString(k, lua.LNil) @@ -101,10 +103,13 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, return 1 }) + L.SetField(headersTable, "__fetched", fetchedHeadersTable) + + L.SetField(headersTable, "get", headersGetter) + L.SetField(inTable, "headers", headersTable) + fetchedParamsTable := L.NewTable() switch params := req.Params.(type) { - case nil: - fetchedParamsTable.RawSetInt(1, lua.LNil) case map[string]any: for k, v := range params { L.SetField(fetchedParamsTable, k, ConvertGolangTypesToLua(L, v)) @@ -113,19 +118,13 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, for i, v := range params { fetchedParamsTable.RawSetInt(i+1, ConvertGolangTypesToLua(L, v)) } - default: - fetchedParamsTable.RawSetInt(1, lua.LNil) } - paramsGetter := L.NewFunction(func(L *lua.LState) int { path := L.OptString(1, "") def := L.Get(2) get := func(tbl *lua.LTable, path string) lua.LValue { - if tbl.RawGetInt(1) == lua.LNil { - return lua.LNil - } if path == "" { return tbl } @@ -146,7 +145,8 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, return current } - val := get(fetchedParamsTable, path) + paramsTbl := L.GetField(paramsTable, "__fetched") // + val := get(paramsTbl.(*lua.LTable), path) // if val == lua.LNil && def != lua.LNil { L.Push(def) } else { @@ -154,19 +154,95 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, } return 1 }) - L.SetField(headersTable, "__fetched", fetchedHeadersTable) - - L.SetField(headersTable, "get", headersGetter) - L.SetField(inTable, "headers", headersTable) - L.SetField(paramsTable, "__fetched", fetchedParamsTable) L.SetField(paramsTable, "get", paramsGetter) L.SetField(inTable, "params", paramsTable) outTable := L.NewTable() + scriptDataTable := L.NewTable() + L.SetField(outTable, "__script_data", scriptDataTable) L.SetField(inTable, "address", lua.LString(r.RemoteAddr)) + + L.SetField(sessionMod, "throw_error", L.NewFunction(func(L *lua.LState) int { + arg := L.Get(1) + var msg string + switch arg.Type() { + case lua.LTString: + msg = arg.String() + case lua.LTNumber: + msg = strconv.FormatFloat(float64(arg.(lua.LNumber)), 'f', -1, 64) + default: + L.ArgError(1, "expected string or number") + return 0 + } + + L.RaiseError("%s", msg) + return 0 + })) + + resTable := L.NewTable() + L.SetField(scriptDataTable, "result", resTable) + L.SetField(outTable, "send", L.NewFunction(func(L *lua.LState) int { + res := L.Get(1) + if res == lua.LNil { + __exit = 0 + L.RaiseError("__successfull") + return 0 + } + + resFTable := scriptDataTable.RawGetString("result") + + // switch resTable.Type() { + // case lua.LTTable: + if resPTable, ok := res.(*lua.LTable); ok { + resPTable.ForEach(func(key, value lua.LValue) { + L.SetField(resFTable, key.String(), value) + }) + } else { + L.SetField(scriptDataTable, "result", res) + } + // default: + // L.SetField(resTable, key.String(), value) + // } + + __exit = 0 + L.RaiseError("__successfull") + return 0 + })) + + errTable := L.NewTable() + L.SetField(scriptDataTable, "error", errTable) + L.SetField(outTable, "send_error", L.NewFunction(func(L *lua.LState) int { + var params [3]lua.LValue + for i := range 3 { + params[i] = L.Get(i + 1) + } + if errTable, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok { + for _, v := range params { + switch v.Type() { + case lua.LTNumber: + if n, ok := v.(lua.LNumber); ok { + L.SetField(errTable, "code", n) + } + case lua.LTString: + if s, ok := v.(lua.LString); ok { + L.SetField(errTable, "message", s) + } + case lua.LTTable: + if tbl, ok := v.(*lua.LTable); ok { + L.SetField(errTable, "data", tbl) + } + } + } + } + + __exit = 1 + L.RaiseError("__unsuccessfull") + return 0 + })) + L.SetField(sessionMod, "request", inTable) L.SetField(sessionMod, "response", outTable) @@ -467,10 +543,11 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, } } llog.Debug("executing script", slog.String("script", path)) - if err := L.DoFile(path); err != nil { + if err := L.DoFile(path); err != nil && __exit == -1 { llog.Error("script error", slog.String("script", path), slog.String("error", err.Error())) return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID) - } + } + pkg := L.GetGlobal("package") pkgTbl, ok := pkg.(*lua.LTable) @@ -505,34 +582,28 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID) } - if errVal := outTbl.RawGetString("error"); errVal != lua.LNil { - llog.Debug("catch error table", slog.String("script", path)) - if errTbl, ok := errVal.(*lua.LTable); ok { - code := rpc.ErrInternalError - message := rpc.ErrInternalErrorS - data := make(map[string]any) - if c := errTbl.RawGetString("code"); c.Type() == lua.LTNumber { - code = int(c.(lua.LNumber)) + if scriptDataTable, ok := outTbl.RawGetString("__script_data").(*lua.LTable); ok { + switch __exit { + case 1: + if errTbl, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok { + llog.Debug("catch error table", slog.String("script", path)) + code := rpc.ErrInternalError + message := rpc.ErrInternalErrorS + if c := errTbl.RawGetString("code"); c.Type() == lua.LTNumber { + code = int(c.(lua.LNumber)) + } + if msg := errTbl.RawGetString("message"); msg.Type() == lua.LTString { + message = msg.String() + } + data := ConvertLuaTypesToGolang(errTbl.RawGetString("data")) + llog.Error("the script terminated with an error", slog.Int("code", code), slog.String("message", message), slog.Any("data", data)) + return rpc.NewError(code, message, data, req.ID) } - if msg := errTbl.RawGetString("message"); msg.Type() == lua.LTString { - message = msg.String() - } - rawData := errTbl.RawGetString("data") - - if tbl, ok := rawData.(*lua.LTable); ok { - tbl.ForEach(func(k, v lua.LValue) { data[k.String()] = ConvertLuaTypesToGolang(v) }) - } else { - llog.Error("the script terminated with an error", slog.String("code", strconv.Itoa(code)), slog.String("message", message)) - return rpc.NewError(code, message, ConvertLuaTypesToGolang(rawData), req.ID) - } - llog.Error("the script terminated with an error", slog.String("code", strconv.Itoa(code)), slog.String("message", message)) - return rpc.NewError(code, message, data, req.ID) + return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID) + case 0: + resVal := ConvertLuaTypesToGolang(scriptDataTable.RawGetString("result")) + return rpc.NewResponse(resVal, req.ID) } - return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID) - } - - if resultVal := outTbl.RawGetString("result"); resultVal != lua.LNil { - return rpc.NewResponse(ConvertLuaTypesToGolang(resultVal), req.ID) } return rpc.NewResponse(nil, req.ID) } diff --git a/internal/server/sv1/lua_types.go b/internal/server/sv1/lua_types.go index 83cfc95..70b5eb3 100644 --- a/internal/server/sv1/lua_types.go +++ b/internal/server/sv1/lua_types.go @@ -3,6 +3,7 @@ package sv1 import ( "fmt" "reflect" + "strconv" lua "github.com/yuin/gopher-lua" ) @@ -18,19 +19,56 @@ func ConvertLuaTypesToGolang(value lua.LValue) any { case lua.LTTable: tbl := value.(*lua.LTable) - var arr []any + maxIdx := 0 isArray := true - tbl.ForEach(func(key, val lua.LValue) { - if key.Type() != lua.LTNumber { - isArray = false + + var isNumeric = false + tbl.ForEach(func(key, _ lua.LValue) { + var numKey lua.LValue + var ok bool + switch key.Type() { + case lua.LTString: + numKey, ok = key.(lua.LString) + if !ok { + isArray = false + return + } + case lua.LTNumber: + numKey, ok = key.(lua.LNumber) + if !ok { + isArray = false + return + } + isNumeric = true + } + + num, err := strconv.Atoi(numKey.String()) + if err != nil { + isArray = false + return + } + if num < 1 { + isArray = false + return + } + if num > maxIdx { + maxIdx = num } - arr = append(arr, ConvertLuaTypesToGolang(val)) }) if isArray { + arr := make([]any, maxIdx) + if isNumeric { + for i := 1; i <= maxIdx; i++ { + arr[i-1] = ConvertLuaTypesToGolang(tbl.RawGetInt(i)) + } + } else { + for i := 1; i <= maxIdx; i++ { + arr[i-1] = ConvertLuaTypesToGolang(tbl.RawGetString(strconv.Itoa(i))) + } + } return arr } - result := make(map[string]any) tbl.ForEach(func(key, val lua.LValue) { result[key.String()] = ConvertLuaTypesToGolang(val)