diff --git a/internal/server/sv1/handle.go b/internal/server/sv1/handle.go index a81d284..3cb9424 100644 --- a/internal/server/sv1/handle.go +++ b/internal/server/sv1/handle.go @@ -24,6 +24,16 @@ func (h *HandlerV1) Handle(_ context.Context, sid string, r *http.Request, req * return rpc.NewError(rpc.ErrMethodNotFound, rpc.ErrMethodNotFoundS, nil, req.ID) } } - - return h.handleLUA(sid, r, req, method) + switch req.Params.(type) { + case map[string]any, []any, nil: + return h.handleLUA(sid, r, req, method) + default: + // JSON-RPC 2.0 Specification: + // https://www.jsonrpc.org/specification#parameter_structures + // + // "params" MUST be either an *array* or an *object* if included. + // Any other type (e.g., a number, string, null, or boolean) is INVALID. + h.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrInvalidParamsS)) + return rpc.NewError(rpc.ErrInvalidParams, rpc.ErrInvalidParamsS, nil, req.ID) + } } diff --git a/internal/server/sv1/lua_handler.go b/internal/server/sv1/lua_handler.go index be4b6ae..2cc3a7e 100644 --- a/internal/server/sv1/lua_handler.go +++ b/internal/server/sv1/lua_handler.go @@ -100,6 +100,9 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, fetchedParamsTable := L.NewTable() switch params := req.Params.(type) { + case nil: + print(1) + fetchedParamsTable.RawSetInt(1, lua.LNil) case map[string]any: for k, v := range params { L.SetField(fetchedParamsTable, k, ConvertGolangTypesToLua(L, v)) @@ -108,6 +111,8 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, for i, v := range params { fetchedParamsTable.RawSetInt(i+1, ConvertGolangTypesToLua(L, v)) } + default: + fetchedParamsTable.RawSetInt(1, lua.LNil) } @@ -116,6 +121,9 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, def := L.Get(2) get := func(tbl *lua.LTable, path string) lua.LValue { + if tbl.RawGetInt(1) == lua.LNil { + return lua.LNil + } if path == "" { return tbl } @@ -521,9 +529,8 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID) } - resultVal := outTbl.RawGetString("result") - if resultVal != lua.LNil { - return rpc.NewResponse(ConvertLuaTypesToGolang(resultVal), req.ID) + if resultVal := outTbl.RawGetString("result"); resultVal != lua.LNil { + return rpc.NewResponse(ConvertLuaTypesToGolang(resultVal), req.ID) } return rpc.NewResponse(nil, req.ID) }