diff --git a/hooks/run.go b/hooks/run.go index 2bd331a..cff646b 100644 --- a/hooks/run.go +++ b/hooks/run.go @@ -19,6 +19,7 @@ import ( "github.com/akyaiy/GoSally-mvp/internal/engine/config" "github.com/akyaiy/GoSally-mvp/internal/engine/logs" "github.com/akyaiy/GoSally-mvp/internal/server/gateway" + "github.com/akyaiy/GoSally-mvp/internal/server/session" "github.com/akyaiy/GoSally-mvp/internal/server/sv1" "github.com/go-chi/chi/v5" "github.com/go-chi/cors" @@ -62,7 +63,10 @@ func RunHook(ctx context.Context, cs *corestate.CoreState, x *app.AppX) error { Ver: "v1", }) + session_manager := session.New(*x.Config.Conf.HTTPServer.SessionTTL) + s := gateway.InitGateway(&gateway.GatewayServerInit{ + SM: session_manager, CS: cs, X: x, }, serverv1) @@ -137,6 +141,8 @@ func RunHook(ctx context.Context, cs *corestate.CoreState, x *app.AppX) error { } }() + session_manager.StartCleanup(5 * time.Minute) + if *x.Config.Conf.Updates.UpdatesEnabled { go func() { defer utils.CatchPanicWithCancel(cancelMain) diff --git a/internal/engine/config/compositor.go b/internal/engine/config/compositor.go index ba9c188..227dee1 100644 --- a/internal/engine/config/compositor.go +++ b/internal/engine/config/compositor.go @@ -48,6 +48,7 @@ func (c *Compositor) LoadConf(path string) error { v.SetDefault("node.com_dir", "./com/") v.SetDefault("http_server.address", "0.0.0.0") v.SetDefault("http_server.port", "8080") + v.SetDefault("http_server.session_ttl", "30m") v.SetDefault("http_server.timeout", "5s") v.SetDefault("http_server.idle_timeout", "60s") v.SetDefault("tls.enabled", false) diff --git a/internal/engine/config/config.go b/internal/engine/config/config.go index 3419aa1..fb2924e 100644 --- a/internal/engine/config/config.go +++ b/internal/engine/config/config.go @@ -35,6 +35,7 @@ type Node struct { type HTTPServer struct { Address *string `mapstructure:"address"` Port *string `mapstructure:"port"` + SessionTTL *time.Duration `mapstructure:"session_ttl"` Timeout *time.Duration `mapstructure:"timeout"` IdleTimeout *time.Duration `mapstructure:"idle_timeout"` } diff --git a/internal/server/gateway/general_types.go b/internal/server/gateway/general_types.go index 9472e60..15e50a4 100644 --- a/internal/server/gateway/general_types.go +++ b/internal/server/gateway/general_types.go @@ -6,6 +6,7 @@ import ( "github.com/akyaiy/GoSally-mvp/internal/core/corestate" "github.com/akyaiy/GoSally-mvp/internal/engine/app" "github.com/akyaiy/GoSally-mvp/internal/server/rpc" + "github.com/akyaiy/GoSally-mvp/internal/server/session" ) // serversApiVer is a type alias for string, used to represent API version strings in the GeneralServer. @@ -13,7 +14,7 @@ type serversApiVer string type ServerApiContract interface { GetVersion() string - Handle(r *http.Request, req *rpc.RPCRequest) *rpc.RPCResponse + Handle(sid string, r *http.Request, req *rpc.RPCRequest) *rpc.RPCResponse } // GeneralServer implements the GeneralServerApiContract and serves as a router for different API versions. @@ -22,6 +23,7 @@ type GatewayServer struct { // The key is the version string, and the value is the server implementing GeneralServerApi servers map[serversApiVer]ServerApiContract + sm *session.SessionManager cs *corestate.CoreState x *app.AppX } diff --git a/internal/server/gateway/init.go b/internal/server/gateway/init.go index e28713e..75005b4 100644 --- a/internal/server/gateway/init.go +++ b/internal/server/gateway/init.go @@ -5,10 +5,12 @@ import ( "github.com/akyaiy/GoSally-mvp/internal/core/corestate" "github.com/akyaiy/GoSally-mvp/internal/engine/app" + "github.com/akyaiy/GoSally-mvp/internal/server/session" ) // GeneralServerInit structure only for initialization general server. type GatewayServerInit struct { + SM *session.SessionManager CS *corestate.CoreState X *app.AppX } @@ -17,6 +19,7 @@ type GatewayServerInit struct { func InitGateway(o *GatewayServerInit, servers ...ServerApiContract) *GatewayServer { general := &GatewayServer{ servers: make(map[serversApiVer]ServerApiContract), + sm: o.SM, cs: o.CS, x: o.X, } diff --git a/internal/server/gateway/route.go b/internal/server/gateway/route.go index 68739d3..4ec923b 100644 --- a/internal/server/gateway/route.go +++ b/internal/server/gateway/route.go @@ -9,10 +9,28 @@ import ( "github.com/akyaiy/GoSally-mvp/internal/core/utils" "github.com/akyaiy/GoSally-mvp/internal/server/rpc" + "github.com/google/uuid" ) func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") + sessionUUID := r.Header.Get("X-Session-UUID") + if sessionUUID == "" { + sessionUUID = uuid.New().String() + } + + w.Header().Set("X-Session-UUID", sessionUUID) + if !gs.sm.Add(sessionUUID) { + rpc.WriteError(w, &rpc.RPCResponse{ + Error: map[string]any{ + "code": rpc.ErrSessionIsTaken, + "message": rpc.ErrSessionIsTakenS, + }, + }) + return + } + defer gs.sm.Delete(sessionUUID) + body, err := io.ReadAll(r.Body) if err != nil { w.WriteHeader(http.StatusBadRequest) @@ -46,7 +64,7 @@ func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) { gs.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrParseErrorS)) return } - resp := gs.Route(r, &single) + resp := gs.Route(sessionUUID, r, &single) rpc.WriteResponse(w, resp) return } @@ -58,7 +76,7 @@ func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) { wg.Add(1) go func(req rpc.RPCRequest) { defer wg.Done() - res := gs.Route(r, &req) + res := gs.Route(sessionUUID, r, &req) if res != nil { responses <- *res } @@ -76,7 +94,7 @@ func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) { } } -func (gs *GatewayServer) Route(r *http.Request, req *rpc.RPCRequest) (resp *rpc.RPCResponse) { +func (gs *GatewayServer) Route(sid string, r *http.Request, req *rpc.RPCRequest) (resp *rpc.RPCResponse) { defer utils.CatchPanicWithFallback(func(rec any) { gs.x.SLog.Error("panic caught in handler", slog.Any("error", rec)) resp = rpc.NewError(rpc.ErrInternalError, "Internal server error (panic)", req.ID) @@ -92,7 +110,7 @@ func (gs *GatewayServer) Route(r *http.Request, req *rpc.RPCRequest) (resp *rpc. return rpc.NewError(rpc.ErrContextVersion, rpc.ErrContextVersionS, req.ID) } - resp = server.Handle(r, req) + resp = server.Handle(sid, r, req) // checks if request is notification if req.ID == nil { return nil diff --git a/internal/server/rpc/errors.go b/internal/server/rpc/errors.go index 5ab1e4b..13150d7 100644 --- a/internal/server/rpc/errors.go +++ b/internal/server/rpc/errors.go @@ -24,4 +24,7 @@ const ( ErrMethodIsMissing = -32020 ErrMethodIsMissingS = "Method is missing" + + ErrSessionIsTaken = -32030 + ErrSessionIsTakenS = "The session is already taken" ) diff --git a/internal/server/session/manager.go b/internal/server/session/manager.go new file mode 100644 index 0000000..9ad3ec3 --- /dev/null +++ b/internal/server/session/manager.go @@ -0,0 +1,47 @@ +package session + +import ( + "sync" + "time" +) + +type SessionManagerContract interface { + Add(uuid string) bool + Delete(uuid string) + StartCleanup(interval time.Duration) +} + +type SessionManager struct { + sessions sync.Map + ttl time.Duration +} + +func New(ttl time.Duration) *SessionManager { + return &SessionManager{ + ttl: ttl, + } +} + +func (sm *SessionManager) Add(uuid string) bool { + _, loaded := sm.sessions.LoadOrStore(uuid, time.Now().Add(sm.ttl)) + return !loaded +} + +func (sm *SessionManager) Delete(uuid string) { + sm.sessions.Delete(uuid) +} + +func (sm *SessionManager) StartCleanup(interval time.Duration) { + go func() { + ticker := time.NewTicker(interval) + for range ticker.C { + sm.sessions.Range(func(key, value any) bool { + expiry := value.(time.Time) + if time.Now().After(expiry) { + sm.sessions.Delete(key) + } + return true + }) + } + }() +} diff --git a/internal/server/sv1/handle.go b/internal/server/sv1/handle.go index 6794820..60bbacf 100644 --- a/internal/server/sv1/handle.go +++ b/internal/server/sv1/handle.go @@ -7,7 +7,7 @@ import ( "github.com/akyaiy/GoSally-mvp/internal/server/rpc" ) -func (h *HandlerV1) Handle(r *http.Request, req *rpc.RPCRequest) *rpc.RPCResponse { +func (h *HandlerV1) Handle(sid string, r *http.Request, req *rpc.RPCRequest) *rpc.RPCResponse { if req.Method == "" { h.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrMethodNotFoundS), slog.String("requested-method", req.Method)) return rpc.NewError(rpc.ErrMethodIsMissing, rpc.ErrMethodIsMissingS, req.ID) @@ -24,5 +24,5 @@ func (h *HandlerV1) Handle(r *http.Request, req *rpc.RPCRequest) *rpc.RPCRespons } } - return h.handleLUA(r, req, method) + return h.handleLUA(sid, r, req, method) } diff --git a/internal/server/sv1/lua_handler.go b/internal/server/sv1/lua_handler.go index 6e2772d..b156340 100644 --- a/internal/server/sv1/lua_handler.go +++ b/internal/server/sv1/lua_handler.go @@ -15,13 +15,13 @@ import ( lua "github.com/yuin/gopher-lua" ) -func addInitiatorHeaders(req *http.Request, headers http.Header) { +func addInitiatorHeaders(sid string, req *http.Request, headers http.Header) { clientIP := req.RemoteAddr if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" { clientIP = forwardedFor } headers.Set("X-Initiator-IP", clientIP) - + headers.Set("X-Session-UUID", sid) headers.Set("X-Initiator-Host", req.Host) headers.Set("X-Initiator-User-Agent", req.UserAgent()) headers.Set("X-Initiator-Referer", req.Referer()) @@ -34,7 +34,7 @@ func addInitiatorHeaders(req *http.Request, headers http.Header) { // contribute to the development of the code, // I will be only glad. // TODO: make this huge function more harmonious by dividing responsibilities -func (h *HandlerV1) handleLUA(r *http.Request, req *rpc.RPCRequest, path string) *rpc.RPCResponse { +func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, path string) *rpc.RPCResponse { L := lua.NewState() defer L.Close() @@ -104,7 +104,7 @@ func (h *HandlerV1) handleLUA(r *http.Request, req *rpc.RPCRequest, path string) return 2 } - addInitiatorHeaders(r, req.Header) + addInitiatorHeaders(sid, r, req.Header) client := &http.Client{} resp, err := client.Do(req) @@ -165,7 +165,7 @@ func (h *HandlerV1) handleLUA(r *http.Request, req *rpc.RPCRequest, path string) req.Header.Set("Content-Type", contentType) - addInitiatorHeaders(r, req.Header) + addInitiatorHeaders(sid, r, req.Header) client := &http.Client{} resp, err := client.Do(req)