diff --git a/internal/server/gateway/route.go b/internal/server/gateway/route.go index bff4069..4f91e45 100644 --- a/internal/server/gateway/route.go +++ b/internal/server/gateway/route.go @@ -5,16 +5,17 @@ import ( "io" "log/slog" "net/http" - "net/http/httptest" + "sync" "github.com/akyaiy/GoSally-mvp/internal/server/rpc" ) func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) { - var req rpc.RPCRequest + w.Header().Set("Content-Type", "application/json") body, err := io.ReadAll(r.Body) if err != nil { - rpc.WriteRouterError(w, http.StatusBadRequest, &rpc.RPCError{ + w.WriteHeader(http.StatusBadRequest) + rpc.WriteError(w, &rpc.RPCResponse{ JSONRPC: rpc.JSONRPCVersion, ID: nil, Error: map[string]any{ @@ -26,55 +27,70 @@ func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) { return } - if err := json.Unmarshal(body, &req); err != nil { - rpc.WriteRouterError(w, http.StatusBadRequest, &rpc.RPCError{ - JSONRPC: rpc.JSONRPCVersion, - ID: nil, - Error: map[string]any{ - "code": rpc.ErrParseError, - "message": rpc.ErrParseErrorS, - }, - }) - gs.log.Info("invalid request received", slog.String("issue", rpc.ErrParseErrorS)) + // determine if the JSON-RPC request is a batch + var batch []rpc.RPCRequest + json.Unmarshal(body, &batch) + var single rpc.RPCRequest + if batch == nil { + if err := json.Unmarshal(body, &single); err != nil { + w.WriteHeader(http.StatusBadRequest) + rpc.WriteError(w, &rpc.RPCResponse{ + JSONRPC: rpc.JSONRPCVersion, + ID: nil, + Error: map[string]any{ + "code": rpc.ErrParseError, + "message": rpc.ErrParseErrorS, + }, + }) + gs.log.Info("invalid request received", slog.String("issue", rpc.ErrParseErrorS)) + return + } + resp := gs.Route(r, &single) + rpc.WriteResponse(w, resp) return } - if req.JSONRPC != rpc.JSONRPCVersion { - rpc.WriteRouterError(w, http.StatusBadRequest, &rpc.RPCError{ - JSONRPC: rpc.JSONRPCVersion, - ID: req.ID, - Error: map[string]any{ - "code": rpc.ErrInvalidRequest, - "message": rpc.ErrInvalidRequestS, - }, - }) - gs.log.Info("invalid request received", slog.String("issue", rpc.ErrInvalidRequestS), slog.String("requested-version", req.JSONRPC)) - return + // handle batch + responses := make(chan rpc.RPCResponse, len(batch)) + var wg sync.WaitGroup + for _, m := range batch { + wg.Add(1) + go func(req rpc.RPCRequest) { + defer wg.Done() + res := gs.Route(r, &req) + if res != nil { + responses <- *res + } + }(m) } + wg.Wait() + close(responses) - gs.Route(w, r, req) + var result []rpc.RPCResponse + for res := range responses { + result = append(result, res) + } + if len(result) > 0 { + json.NewEncoder(w).Encode(result) + } } -func (gs *GatewayServer) Route(w http.ResponseWriter, r *http.Request, req rpc.RPCRequest) { - server, ok := gs.servers[serversApiVer(req.Params.ContextVersion)] - if !ok { - rpc.WriteRouterError(w, http.StatusBadRequest, &rpc.RPCError{ - JSONRPC: rpc.JSONRPCVersion, - ID: req.ID, - Error: map[string]any{ - "code": rpc.ErrContextVersion, - "message": rpc.ErrContextVersionS, - }, - }) - gs.log.Info("invalid request received", slog.String("issue", rpc.ErrContextVersionS), slog.String("requested-version", req.Params.ContextVersion)) - return +func (gs *GatewayServer) Route(r *http.Request, req *rpc.RPCRequest) *rpc.RPCResponse { + if req.JSONRPC != rpc.JSONRPCVersion { + gs.log.Info("invalid request received", slog.String("issue", rpc.ErrInvalidRequestS), slog.String("requested-version", req.JSONRPC)) + return rpc.NewError(rpc.ErrInvalidRequest, rpc.ErrInvalidRequestS, req.ID) } + server, ok := gs.servers[serversApiVer(req.ContextVersion)] + if !ok { + gs.log.Info("invalid request received", slog.String("issue", rpc.ErrContextVersionS), slog.String("requested-version", req.ContextVersion)) + return rpc.NewError(rpc.ErrContextVersion, rpc.ErrContextVersionS, req.ID) + } + + resp := server.Handle(r, req) // checks if request is notification if req.ID == nil { - rr := httptest.NewRecorder() - server.Handle(rr, r, req) - return + return nil } - server.Handle(w, r, req) + return resp }