mirror of
https://github.com/akyaiy/GoSally-mvp.git
synced 2026-01-07 23:32:24 +00:00
move go files to src/
This commit is contained in:
30
src/internal/server/gateway/general_types.go
Normal file
30
src/internal/server/gateway/general_types.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/core/corestate"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/engine/app"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/server/rpc"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/server/session"
|
||||
)
|
||||
|
||||
// serversApiVer is a type alias for string, used to represent API version strings in the GeneralServer.
|
||||
type serversApiVer string
|
||||
|
||||
type ServerApiContract interface {
|
||||
GetVersion() string
|
||||
Handle(ctx context.Context, sid string, r *http.Request, req *rpc.RPCRequest) *rpc.RPCResponse
|
||||
}
|
||||
|
||||
// GeneralServer implements the GeneralServerApiContract and serves as a router for different API versions.
|
||||
type GatewayServer struct {
|
||||
// servers holds the registered servers by their API version.
|
||||
// The key is the version string, and the value is the server implementing GeneralServerApi
|
||||
servers map[serversApiVer]ServerApiContract
|
||||
|
||||
sm *session.SessionManager
|
||||
cs *corestate.CoreState
|
||||
x *app.AppX
|
||||
}
|
||||
47
src/internal/server/gateway/init.go
Normal file
47
src/internal/server/gateway/init.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/core/corestate"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/engine/app"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/server/session"
|
||||
)
|
||||
|
||||
// GeneralServerInit structure only for initialization general server.
|
||||
type GatewayServerInit struct {
|
||||
SM *session.SessionManager
|
||||
CS *corestate.CoreState
|
||||
X *app.AppX
|
||||
}
|
||||
|
||||
// InitGeneral initializes a new GeneralServer with the provided configuration and registered servers.
|
||||
func InitGateway(o *GatewayServerInit, servers ...ServerApiContract) *GatewayServer {
|
||||
general := &GatewayServer{
|
||||
servers: make(map[serversApiVer]ServerApiContract),
|
||||
sm: o.SM,
|
||||
cs: o.CS,
|
||||
x: o.X,
|
||||
}
|
||||
|
||||
// register the provided servers
|
||||
// s is each server implementing GeneralServerApiContract, this is not a general server
|
||||
for _, s := range servers {
|
||||
general.servers[serversApiVer(s.GetVersion())] = s
|
||||
}
|
||||
return general
|
||||
}
|
||||
|
||||
// GetVersion returns the API version of the GeneralServer, which is "general".
|
||||
func (s *GatewayServer) GetVersion() string {
|
||||
return "general"
|
||||
}
|
||||
|
||||
// AppendToArray adds a new server to the GeneralServer's internal map.
|
||||
func (s *GatewayServer) AppendToArray(server ServerApiContract) error {
|
||||
if _, exist := s.servers[serversApiVer(server.GetVersion())]; !exist {
|
||||
s.servers[serversApiVer(server.GetVersion())] = server
|
||||
return nil
|
||||
}
|
||||
return errors.New("server with this version is already exist")
|
||||
}
|
||||
114
src/internal/server/gateway/route.go
Normal file
114
src/internal/server/gateway/route.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/core/utils"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/server/rpc"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context() // TODO
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
sessionUUID := r.Header.Get("X-Session-UUID")
|
||||
if sessionUUID == "" {
|
||||
sessionUUID = uuid.New().String()
|
||||
|
||||
}
|
||||
gs.x.SLog.Debug("new request", slog.String("session-uuid", sessionUUID), slog.Group("connection", slog.String("ip", r.RemoteAddr)))
|
||||
|
||||
w.Header().Set("X-Session-UUID", sessionUUID)
|
||||
if !gs.sm.Add(sessionUUID) {
|
||||
gs.x.SLog.Debug("session is busy", slog.String("session-uuid", sessionUUID))
|
||||
rpc.WriteError(w, rpc.NewError(rpc.ErrSessionIsBusy, rpc.ErrSessionIsBusyS, nil, nil))
|
||||
return
|
||||
}
|
||||
defer gs.sm.Delete(sessionUUID)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
gs.x.SLog.Debug("failed to read body", slog.String("err", err.Error()))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
rpc.WriteError(w, rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, nil))
|
||||
gs.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrInternalErrorS))
|
||||
return
|
||||
}
|
||||
|
||||
// determine if the JSON-RPC request is a batch
|
||||
var batch []rpc.RPCRequest
|
||||
json.Unmarshal(body, &batch)
|
||||
var single rpc.RPCRequest
|
||||
if batch == nil {
|
||||
if err := json.Unmarshal(body, &single); err != nil {
|
||||
gs.x.SLog.Debug("failed to parse json", slog.String("err", err.Error()))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
rpc.WriteError(w, rpc.NewError(rpc.ErrParseError, rpc.ErrParseErrorS, nil, nil))
|
||||
gs.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrParseErrorS))
|
||||
return
|
||||
}
|
||||
resp := gs.Route(ctx, sessionUUID, r, &single)
|
||||
if resp == nil {
|
||||
w.Write([]byte(""))
|
||||
return
|
||||
}
|
||||
rpc.WriteResponse(w, resp)
|
||||
return
|
||||
}
|
||||
|
||||
// handle batch
|
||||
responses := make(chan rpc.RPCResponse, len(batch))
|
||||
var wg sync.WaitGroup
|
||||
for _, m := range batch {
|
||||
wg.Add(1)
|
||||
go func(req rpc.RPCRequest) {
|
||||
defer wg.Done()
|
||||
res := gs.Route(ctx, sessionUUID, r, &req)
|
||||
if res != nil {
|
||||
responses <- *res
|
||||
}
|
||||
}(m)
|
||||
}
|
||||
wg.Wait()
|
||||
close(responses)
|
||||
|
||||
var result []rpc.RPCResponse
|
||||
for res := range responses {
|
||||
result = append(result, res)
|
||||
}
|
||||
if len(result) > 0 {
|
||||
json.NewEncoder(w).Encode(result)
|
||||
} else {
|
||||
w.Write([]byte("[]"))
|
||||
}
|
||||
}
|
||||
|
||||
func (gs *GatewayServer) Route(ctx context.Context, sid string, r *http.Request, req *rpc.RPCRequest) (resp *rpc.RPCResponse) {
|
||||
defer utils.CatchPanicWithFallback(func(rec any) {
|
||||
gs.x.SLog.Error("panic caught in handler", slog.Any("error", rec))
|
||||
resp = rpc.NewError(rpc.ErrInternalError, "Internal server error (panic)", nil, req.ID)
|
||||
})
|
||||
if req.JSONRPC != rpc.JSONRPCVersion {
|
||||
gs.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrInvalidRequestS), slog.String("requested-version", req.JSONRPC))
|
||||
return rpc.NewError(rpc.ErrInvalidRequest, rpc.ErrInvalidRequestS, nil, req.ID)
|
||||
}
|
||||
|
||||
server, ok := gs.servers[serversApiVer(req.ContextVersion)]
|
||||
if !ok {
|
||||
gs.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrContextVersionS), slog.String("requested-version", req.ContextVersion))
|
||||
return rpc.NewError(rpc.ErrContextVersion, rpc.ErrContextVersionS, nil, req.ID)
|
||||
}
|
||||
|
||||
// checks if request is notification
|
||||
if req.ID == nil {
|
||||
go server.Handle(ctx, sid, r, req)
|
||||
return nil
|
||||
}
|
||||
return server.Handle(ctx, sid, r, req)
|
||||
}
|
||||
30
src/internal/server/rpc/definition.go
Normal file
30
src/internal/server/rpc/definition.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package rpc
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type RPCRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID *json.RawMessage `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params,omitempty"`
|
||||
ContextVersion string `json:"context-version,omitempty"`
|
||||
}
|
||||
|
||||
type RPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID *json.RawMessage `json:"id"`
|
||||
Result any `json:"result,omitzero"`
|
||||
Error any `json:"error,omitzero"`
|
||||
Data *RPCData `json:"data,omitzero"`
|
||||
}
|
||||
|
||||
type RPCData struct {
|
||||
ResponsibleNode string `json:"responsible-node,omitempty"`
|
||||
Salt string `json:"salt,omitempty"`
|
||||
Checksum string `json:"checksum-md5,omitempty"`
|
||||
NewSessionUUID string `json:"new-session-uuid,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
JSONRPCVersion = "2.0"
|
||||
)
|
||||
30
src/internal/server/rpc/errors.go
Normal file
30
src/internal/server/rpc/errors.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package rpc
|
||||
|
||||
const (
|
||||
ErrParseError = -32700
|
||||
ErrParseErrorS = "Parse error"
|
||||
|
||||
ErrInvalidRequest = -32600
|
||||
ErrInvalidRequestS = "Invalid Request"
|
||||
|
||||
ErrMethodNotFound = -32601
|
||||
ErrMethodNotFoundS = "Method not found"
|
||||
|
||||
ErrInvalidParams = -32602
|
||||
ErrInvalidParamsS = "Invalid params"
|
||||
|
||||
ErrInternalError = -32603
|
||||
ErrInternalErrorS = "Internal error"
|
||||
|
||||
ErrContextVersion = -32010
|
||||
ErrContextVersionS = "Invalid context version"
|
||||
|
||||
ErrInvalidMethodFormat = -32020
|
||||
ErrInvalidMethodFormatS = "Invalid method format"
|
||||
|
||||
ErrMethodIsMissing = -32020
|
||||
ErrMethodIsMissingS = "Method is missing"
|
||||
|
||||
ErrSessionIsBusy = -32030
|
||||
ErrSessionIsBusyS = "The session is busy"
|
||||
)
|
||||
60
src/internal/server/rpc/responsers.go
Normal file
60
src/internal/server/rpc/responsers.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/core/corestate"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func generateChecksum(result any) string {
|
||||
if result == nil {
|
||||
return ""
|
||||
}
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%x", md5.Sum(data))
|
||||
}
|
||||
|
||||
func generateSalt() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
func GetData(data any) *RPCData {
|
||||
return &RPCData{
|
||||
Salt: generateSalt(),
|
||||
ResponsibleNode: corestate.NODE_UUID,
|
||||
Checksum: generateChecksum(data),
|
||||
}
|
||||
}
|
||||
|
||||
func NewError(code int, message string, data any, id *json.RawMessage) *RPCResponse {
|
||||
Error := make(map[string]any)
|
||||
Error = map[string]any{
|
||||
"code": code,
|
||||
"message": message,
|
||||
}
|
||||
if data != nil {
|
||||
Error["data"] = data
|
||||
}
|
||||
|
||||
return &RPCResponse{
|
||||
JSONRPC: JSONRPCVersion,
|
||||
ID: id,
|
||||
Error: Error,
|
||||
Data: GetData(Error),
|
||||
}
|
||||
}
|
||||
|
||||
func NewResponse(result any, id *json.RawMessage) *RPCResponse {
|
||||
return &RPCResponse{
|
||||
JSONRPC: JSONRPCVersion,
|
||||
ID: id,
|
||||
Result: result,
|
||||
Data: GetData(result),
|
||||
}
|
||||
}
|
||||
23
src/internal/server/rpc/writers.go
Normal file
23
src/internal/server/rpc/writers.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func write(w http.ResponseWriter, msg *RPCResponse) error {
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func WriteError(w http.ResponseWriter, errm *RPCResponse) error {
|
||||
return write(w, errm)
|
||||
}
|
||||
|
||||
func WriteResponse(w http.ResponseWriter, response *RPCResponse) error {
|
||||
return write(w, response)
|
||||
}
|
||||
47
src/internal/server/session/manager.go
Normal file
47
src/internal/server/session/manager.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SessionManagerContract interface {
|
||||
Add(uuid string) bool
|
||||
Delete(uuid string)
|
||||
StartCleanup(interval time.Duration)
|
||||
}
|
||||
|
||||
type SessionManager struct {
|
||||
sessions sync.Map
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func New(ttl time.Duration) *SessionManager {
|
||||
return &SessionManager{
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Add(uuid string) bool {
|
||||
_, loaded := sm.sessions.LoadOrStore(uuid, time.Now().Add(sm.ttl))
|
||||
return !loaded
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Delete(uuid string) {
|
||||
sm.sessions.Delete(uuid)
|
||||
}
|
||||
|
||||
func (sm *SessionManager) StartCleanup(interval time.Duration) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
for range ticker.C {
|
||||
sm.sessions.Range(func(key, value any) bool {
|
||||
expiry := value.(time.Time)
|
||||
if time.Now().After(expiry) {
|
||||
sm.sessions.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
415
src/internal/server/sv1/db_sqlite.go
Normal file
415
src/internal/server/sv1/db_sqlite.go
Normal file
@@ -0,0 +1,415 @@
|
||||
package sv1
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
type DBConnection struct {
|
||||
dbPath string
|
||||
log bool
|
||||
logger *slog.Logger
|
||||
writeChan chan *dbWriteRequest
|
||||
closeChan chan struct{}
|
||||
}
|
||||
|
||||
type dbWriteRequest struct {
|
||||
query string
|
||||
args []interface{}
|
||||
resCh chan *dbWriteResult
|
||||
}
|
||||
|
||||
type dbWriteResult struct {
|
||||
rowsAffected int64
|
||||
err error
|
||||
}
|
||||
|
||||
var dbMutexMap = make(map[string]*sync.RWMutex)
|
||||
var dbGlobalMutex sync.Mutex
|
||||
|
||||
func getDBMutex(dbPath string) *sync.RWMutex {
|
||||
dbGlobalMutex.Lock()
|
||||
defer dbGlobalMutex.Unlock()
|
||||
|
||||
if mtx, ok := dbMutexMap[dbPath]; ok {
|
||||
return mtx
|
||||
}
|
||||
|
||||
mtx := &sync.RWMutex{}
|
||||
dbMutexMap[dbPath] = mtx
|
||||
return mtx
|
||||
}
|
||||
|
||||
func loadDBMod(llog *slog.Logger, sid string) func(*lua.LState) int {
|
||||
return func(L *lua.LState) int {
|
||||
llog.Debug("import module db-sqlite")
|
||||
dbMod := L.NewTable()
|
||||
|
||||
L.SetField(dbMod, "connect", L.NewFunction(func(L *lua.LState) int {
|
||||
dbPath := L.CheckString(1)
|
||||
|
||||
logQueries := false
|
||||
if L.GetTop() >= 2 {
|
||||
opts := L.CheckTable(2)
|
||||
if val := opts.RawGetString("log"); val != lua.LNil {
|
||||
logQueries = lua.LVAsBool(val)
|
||||
}
|
||||
}
|
||||
|
||||
conn := &DBConnection{
|
||||
dbPath: dbPath,
|
||||
log: logQueries,
|
||||
logger: llog,
|
||||
writeChan: make(chan *dbWriteRequest, 100),
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
go conn.processWrites()
|
||||
|
||||
ud := L.NewUserData()
|
||||
ud.Value = conn
|
||||
L.SetMetatable(ud, L.GetTypeMetatable("gosally_db"))
|
||||
|
||||
L.Push(ud)
|
||||
return 1
|
||||
}))
|
||||
|
||||
mt := L.NewTypeMetatable("gosally_db")
|
||||
L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
|
||||
"exec": dbExec,
|
||||
"query": dbQuery,
|
||||
"query_row": dbQueryRow,
|
||||
"close": dbClose,
|
||||
}))
|
||||
|
||||
L.SetField(dbMod, "__seed", lua.LString(sid))
|
||||
L.Push(dbMod)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *DBConnection) processWrites() {
|
||||
for {
|
||||
select {
|
||||
case req := <-conn.writeChan:
|
||||
mtx := getDBMutex(conn.dbPath)
|
||||
mtx.Lock()
|
||||
|
||||
db, err := sql.Open("sqlite", conn.dbPath+"?_busy_timeout=5000&_journal_mode=WAL&_sync=NORMAL&_cache_size=-10000")
|
||||
if err == nil {
|
||||
_, err = db.Exec("PRAGMA journal_mode=WAL;")
|
||||
if err == nil {
|
||||
res, execErr := db.Exec(req.query, req.args...)
|
||||
if execErr == nil {
|
||||
rows, _ := res.RowsAffected()
|
||||
req.resCh <- &dbWriteResult{rowsAffected: rows}
|
||||
} else {
|
||||
req.resCh <- &dbWriteResult{err: execErr}
|
||||
}
|
||||
}
|
||||
db.Close()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
req.resCh <- &dbWriteResult{err: err}
|
||||
}
|
||||
|
||||
mtx.Unlock()
|
||||
case <-conn.closeChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dbExec(L *lua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
conn, ok := ud.Value.(*DBConnection)
|
||||
if !ok {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("invalid database connection"))
|
||||
return 2
|
||||
}
|
||||
|
||||
query := L.CheckString(2)
|
||||
|
||||
var args []any
|
||||
if L.GetTop() >= 3 {
|
||||
params := L.CheckTable(3)
|
||||
params.ForEach(func(k lua.LValue, v lua.LValue) {
|
||||
args = append(args, ConvertLuaTypesToGolang(v))
|
||||
})
|
||||
}
|
||||
|
||||
if conn.log {
|
||||
conn.logger.Info("DB Exec",
|
||||
slog.String("query", query),
|
||||
slog.Any("params", args))
|
||||
}
|
||||
|
||||
resCh := make(chan *dbWriteResult, 1)
|
||||
conn.writeChan <- &dbWriteRequest{
|
||||
query: query,
|
||||
args: args,
|
||||
resCh: resCh,
|
||||
}
|
||||
|
||||
ctx := L.NewTable()
|
||||
L.SetField(ctx, "done", lua.LBool(false))
|
||||
|
||||
var result lua.LValue = lua.LNil
|
||||
var errorMsg lua.LValue = lua.LNil
|
||||
|
||||
L.SetField(ctx, "wait", L.NewFunction(func(L *lua.LState) int {
|
||||
res := <-resCh
|
||||
L.SetField(ctx, "done", lua.LBool(true))
|
||||
|
||||
if res.err != nil {
|
||||
errorMsg = lua.LString(res.err.Error())
|
||||
result = lua.LNil
|
||||
} else {
|
||||
result = lua.LNumber(res.rowsAffected)
|
||||
errorMsg = lua.LNil
|
||||
}
|
||||
|
||||
if res.err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(res.err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(lua.LNumber(res.rowsAffected))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}))
|
||||
|
||||
L.SetField(ctx, "check", L.NewFunction(func(L *lua.LState) int {
|
||||
select {
|
||||
case res := <-resCh:
|
||||
L.SetField(ctx, "done", lua.LBool(true))
|
||||
if res.err != nil {
|
||||
errorMsg = lua.LString(res.err.Error())
|
||||
result = lua.LNil
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(res.err.Error()))
|
||||
return 2
|
||||
} else {
|
||||
result = lua.LNumber(res.rowsAffected)
|
||||
errorMsg = lua.LNil
|
||||
L.Push(lua.LNumber(res.rowsAffected))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}
|
||||
default:
|
||||
L.Push(result)
|
||||
L.Push(errorMsg)
|
||||
return 2
|
||||
}
|
||||
}))
|
||||
|
||||
L.Push(ctx)
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}
|
||||
|
||||
func dbQueryRow(L *lua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
conn, ok := ud.Value.(*DBConnection)
|
||||
if !ok {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("invalid database connection"))
|
||||
return 2
|
||||
}
|
||||
|
||||
query := L.CheckString(2)
|
||||
|
||||
var args []any
|
||||
if L.GetTop() >= 3 {
|
||||
params := L.CheckTable(3)
|
||||
params.ForEach(func(k lua.LValue, v lua.LValue) {
|
||||
args = append(args, ConvertLuaTypesToGolang(v))
|
||||
})
|
||||
}
|
||||
|
||||
if conn.log {
|
||||
conn.logger.Info("DB QueryRow",
|
||||
slog.String("query", query),
|
||||
slog.Any("params", args))
|
||||
}
|
||||
|
||||
mtx := getDBMutex(conn.dbPath)
|
||||
mtx.RLock()
|
||||
defer mtx.RUnlock()
|
||||
|
||||
db, err := sql.Open("sqlite", conn.dbPath+"?_busy_timeout=5000&_journal_mode=WAL&_sync=NORMAL&_cache_size=-10000")
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
row := db.QueryRow(query, args...)
|
||||
|
||||
columns := []string{}
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("prepare failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err := stmt.Query(args...)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("query failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
defer rows.Close()
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("get columns failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
for _, c := range cols {
|
||||
columns = append(columns, c)
|
||||
}
|
||||
|
||||
colCount := len(columns)
|
||||
values := make([]any, colCount)
|
||||
valuePtrs := make([]any, colCount)
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
err = row.Scan(valuePtrs...)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
L.Push(lua.LNil)
|
||||
return 1
|
||||
}
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("scan failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
|
||||
rowTable := L.NewTable()
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if val == nil {
|
||||
L.SetField(rowTable, col, lua.LNil)
|
||||
} else {
|
||||
L.SetField(rowTable, col, ConvertGolangTypesToLua(L, val))
|
||||
}
|
||||
}
|
||||
|
||||
L.Push(rowTable)
|
||||
return 1
|
||||
}
|
||||
|
||||
func dbQuery(L *lua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
conn, ok := ud.Value.(*DBConnection)
|
||||
if !ok {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("invalid database connection"))
|
||||
return 2
|
||||
}
|
||||
|
||||
query := L.CheckString(2)
|
||||
|
||||
var args []any
|
||||
if L.GetTop() >= 3 {
|
||||
params := L.CheckTable(3)
|
||||
params.ForEach(func(k lua.LValue, v lua.LValue) {
|
||||
args = append(args, ConvertLuaTypesToGolang(v))
|
||||
})
|
||||
}
|
||||
|
||||
if conn.log {
|
||||
conn.logger.Info("DB Query",
|
||||
slog.String("query", query),
|
||||
slog.Any("params", args))
|
||||
}
|
||||
|
||||
mtx := getDBMutex(conn.dbPath)
|
||||
mtx.RLock()
|
||||
defer mtx.RUnlock()
|
||||
|
||||
db, err := sql.Open("sqlite", conn.dbPath+"?_busy_timeout=5000&_journal_mode=WAL&_sync=NORMAL&_cache_size=-10000")
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("query failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("get columns failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
|
||||
result := L.NewTable()
|
||||
colCount := len(columns)
|
||||
values := make([]any, colCount)
|
||||
valuePtrs := make([]any, colCount)
|
||||
|
||||
for rows.Next() {
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("scan failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
|
||||
rowTable := L.NewTable()
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if val == nil {
|
||||
L.SetField(rowTable, col, lua.LNil)
|
||||
} else {
|
||||
L.SetField(rowTable, col, ConvertGolangTypesToLua(L, val))
|
||||
}
|
||||
}
|
||||
result.Append(rowTable)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("rows iteration failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
|
||||
L.Push(result)
|
||||
return 1
|
||||
}
|
||||
|
||||
func dbClose(L *lua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
conn, ok := ud.Value.(*DBConnection)
|
||||
if !ok {
|
||||
L.Push(lua.LFalse)
|
||||
L.Push(lua.LString("invalid database connection"))
|
||||
return 2
|
||||
}
|
||||
|
||||
close(conn.closeChan)
|
||||
L.Push(lua.LTrue)
|
||||
return 1
|
||||
}
|
||||
39
src/internal/server/sv1/handle.go
Normal file
39
src/internal/server/sv1/handle.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package sv1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/server/rpc"
|
||||
)
|
||||
|
||||
func (h *HandlerV1) Handle(_ context.Context, sid string, r *http.Request, req *rpc.RPCRequest) *rpc.RPCResponse {
|
||||
if req.Method == "" {
|
||||
h.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrMethodNotFoundS), slog.String("requested-method", req.Method))
|
||||
return rpc.NewError(rpc.ErrMethodIsMissing, rpc.ErrMethodIsMissingS, nil, req.ID)
|
||||
}
|
||||
|
||||
method, err := h.resolveMethodPath(req.Method)
|
||||
if err != nil {
|
||||
if err.Error() == rpc.ErrInvalidMethodFormatS {
|
||||
h.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrInvalidMethodFormatS), slog.String("requested-method", req.Method))
|
||||
return rpc.NewError(rpc.ErrInvalidMethodFormat, rpc.ErrInvalidMethodFormatS, nil, req.ID)
|
||||
} else if err.Error() == rpc.ErrMethodNotFoundS {
|
||||
h.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrMethodNotFoundS), slog.String("requested-method", req.Method))
|
||||
return rpc.NewError(rpc.ErrMethodNotFound, rpc.ErrMethodNotFoundS, nil, req.ID)
|
||||
}
|
||||
}
|
||||
switch req.Params.(type) {
|
||||
case map[string]any, []any, nil:
|
||||
return h.handleLUA(sid, r, req, method)
|
||||
default:
|
||||
// JSON-RPC 2.0 Specification:
|
||||
// https://www.jsonrpc.org/specification#parameter_structures
|
||||
//
|
||||
// "params" MUST be either an *array* or an *object* if included.
|
||||
// Any other type (e.g., a number, string, or boolean) is INVALID.
|
||||
h.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrInvalidParamsS))
|
||||
return rpc.NewError(rpc.ErrInvalidParams, rpc.ErrInvalidParamsS, nil, req.ID)
|
||||
}
|
||||
}
|
||||
86
src/internal/server/sv1/jwt.go
Normal file
86
src/internal/server/sv1/jwt.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package sv1
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
func loadJWTMod(llog *slog.Logger, sid string) func(*lua.LState) int {
|
||||
return func(L *lua.LState) int {
|
||||
llog.Debug("import module jwt")
|
||||
jwtMod := L.NewTable()
|
||||
|
||||
L.SetField(jwtMod, "encode", L.NewFunction(jwtEncode))
|
||||
L.SetField(jwtMod, "decode", L.NewFunction(jwtDecode))
|
||||
|
||||
L.SetField(jwtMod, "__seed", lua.LString(sid))
|
||||
L.Push(jwtMod)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func jwtEncode(L *lua.LState) int {
|
||||
payloadTbl := L.CheckTable(1)
|
||||
secret := L.GetField(payloadTbl, "secret").String()
|
||||
payload := L.GetField(payloadTbl, "payload").(*lua.LTable)
|
||||
expiresIn := L.GetField(payloadTbl, "expires_in")
|
||||
expDuration := time.Hour
|
||||
|
||||
if expiresIn.Type() == lua.LTNumber {
|
||||
floatVal := ConvertLuaTypesToGolang(expiresIn).(float64)
|
||||
expDuration = time.Duration(floatVal) * time.Second
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{}
|
||||
payload.ForEach(func(key, value lua.LValue) {
|
||||
claims[key.String()] = ConvertLuaTypesToGolang(value)
|
||||
})
|
||||
claims["iat"] = time.Now().Unix()
|
||||
claims["exp"] = time.Now().Add(expDuration).Unix()
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
L.Push(lua.LString(signedToken))
|
||||
return 1
|
||||
}
|
||||
|
||||
func jwtDecode(L *lua.LState) int {
|
||||
tokenString := L.CheckString(1)
|
||||
optsTbl := L.OptTable(2, L.NewTable())
|
||||
secret := L.GetField(optsTbl, "secret").String()
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
L.Push(lua.LString("Invalid token: " + err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
L.Push(lua.LString("Invalid claims"))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}
|
||||
|
||||
luaTable := L.NewTable()
|
||||
for k, v := range claims {
|
||||
luaTable.RawSetString(k, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
|
||||
L.Push(lua.LNil)
|
||||
L.Push(luaTable)
|
||||
return 2
|
||||
}
|
||||
636
src/internal/server/sv1/lua_handler.go
Normal file
636
src/internal/server/sv1/lua_handler.go
Normal file
@@ -0,0 +1,636 @@
|
||||
package sv1
|
||||
|
||||
// TODO: make a lua state pool using sync.Pool
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/colors"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/server/rpc"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func addInitiatorHeaders(sid string, req *http.Request, headers http.Header) {
|
||||
clientIP := req.RemoteAddr
|
||||
if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" {
|
||||
clientIP = forwardedFor
|
||||
}
|
||||
headers.Set("X-Initiator-IP", clientIP)
|
||||
headers.Set("X-Session-UUID", sid)
|
||||
headers.Set("X-Initiator-Host", req.Host)
|
||||
headers.Set("X-Initiator-User-Agent", req.UserAgent())
|
||||
headers.Set("X-Initiator-Referer", req.Referer())
|
||||
}
|
||||
|
||||
// A small reminder: this code is only at the MVP stage,
|
||||
// and some parts of the code may cause shock from the
|
||||
// incompetence of the developer. But, in the end,
|
||||
// this code is just an idea. If there is a desire to
|
||||
// contribute to the development of the code,
|
||||
// I will be only glad.
|
||||
// TODO: make this huge function more harmonious by dividing responsibilities
|
||||
func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, path string) *rpc.RPCResponse {
|
||||
var __exit = -1
|
||||
|
||||
llog := h.x.SLog.With(slog.String("session-id", sid))
|
||||
llog.Debug("handling LUA")
|
||||
L := lua.NewState()
|
||||
defer L.Close()
|
||||
|
||||
osMod := L.GetGlobal("os").(*lua.LTable)
|
||||
L.SetField(osMod, "exit", lua.LNil)
|
||||
|
||||
ioMod := L.GetGlobal("io").(*lua.LTable)
|
||||
for _, k := range []string{"write", "output", "flush", "read", "input"} {
|
||||
ioMod.RawSetString(k, lua.LNil)
|
||||
}
|
||||
L.Env.RawSetString("print", lua.LNil)
|
||||
|
||||
for _, name := range []string{"stdout", "stderr", "stdin"} {
|
||||
stream := ioMod.RawGetString(name)
|
||||
if t, ok := stream.(*lua.LUserData); ok {
|
||||
t.Metatable = lua.LNil
|
||||
}
|
||||
}
|
||||
|
||||
seed := rand.Int()
|
||||
|
||||
loadSessionMod := func(L *lua.LState) int {
|
||||
llog.Debug("import module session", slog.String("script", path))
|
||||
sessionMod := L.NewTable()
|
||||
inTable := L.NewTable()
|
||||
paramsTable := L.NewTable()
|
||||
headersTable := L.NewTable()
|
||||
|
||||
fetchedHeadersTable := L.NewTable()
|
||||
for k, v := range r.Header {
|
||||
L.SetField(fetchedHeadersTable, k, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
|
||||
headersGetter := L.NewFunction(func(L *lua.LState) int {
|
||||
path := L.OptString(1, "")
|
||||
def := L.Get(2)
|
||||
|
||||
get := func(path string) lua.LValue {
|
||||
if path == "" {
|
||||
return fetchedHeadersTable
|
||||
}
|
||||
fetched := r.Header.Get(path)
|
||||
if fetched == "" {
|
||||
return lua.LNil
|
||||
}
|
||||
return lua.LString(fetched)
|
||||
}
|
||||
val := get(path)
|
||||
if val == lua.LNil && def != lua.LNil {
|
||||
L.Push(def)
|
||||
} else {
|
||||
L.Push(val)
|
||||
}
|
||||
return 1
|
||||
})
|
||||
|
||||
L.SetField(headersTable, "__fetched", fetchedHeadersTable)
|
||||
|
||||
L.SetField(headersTable, "get", headersGetter)
|
||||
L.SetField(inTable, "headers", headersTable)
|
||||
|
||||
fetchedParamsTable := L.NewTable()
|
||||
switch params := req.Params.(type) {
|
||||
case map[string]any:
|
||||
for k, v := range params {
|
||||
L.SetField(fetchedParamsTable, k, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
case []any:
|
||||
for i, v := range params {
|
||||
fetchedParamsTable.RawSetInt(i+1, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
}
|
||||
|
||||
paramsGetter := L.NewFunction(func(L *lua.LState) int {
|
||||
path := L.OptString(1, "")
|
||||
def := L.Get(2)
|
||||
|
||||
get := func(tbl *lua.LTable, path string) lua.LValue {
|
||||
if path == "" {
|
||||
return tbl
|
||||
}
|
||||
current := tbl
|
||||
parts := strings.Split(path, ".")
|
||||
size := len(parts)
|
||||
for index, key := range parts {
|
||||
val := current.RawGetString(key)
|
||||
if tblVal, ok := val.(*lua.LTable); ok {
|
||||
current = tblVal
|
||||
} else {
|
||||
if index == size-1 {
|
||||
return val
|
||||
}
|
||||
return lua.LNil
|
||||
}
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
paramsTbl := L.GetField(paramsTable, "__fetched") //
|
||||
val := get(paramsTbl.(*lua.LTable), path) //
|
||||
if val == lua.LNil && def != lua.LNil {
|
||||
L.Push(def)
|
||||
} else {
|
||||
L.Push(val)
|
||||
}
|
||||
return 1
|
||||
})
|
||||
L.SetField(paramsTable, "__fetched", fetchedParamsTable)
|
||||
|
||||
L.SetField(paramsTable, "get", paramsGetter)
|
||||
L.SetField(inTable, "params", paramsTable)
|
||||
|
||||
outTable := L.NewTable()
|
||||
scriptDataTable := L.NewTable()
|
||||
L.SetField(outTable, "__script_data", scriptDataTable)
|
||||
|
||||
L.SetField(inTable, "address", lua.LString(r.RemoteAddr))
|
||||
|
||||
L.SetField(sessionMod, "throw_error", L.NewFunction(func(L *lua.LState) int {
|
||||
arg := L.Get(1)
|
||||
var msg string
|
||||
switch arg.Type() {
|
||||
case lua.LTString:
|
||||
msg = arg.String()
|
||||
case lua.LTNumber:
|
||||
msg = strconv.FormatFloat(float64(arg.(lua.LNumber)), 'f', -1, 64)
|
||||
default:
|
||||
L.ArgError(1, "expected string or number")
|
||||
return 0
|
||||
}
|
||||
|
||||
L.RaiseError("%s", msg)
|
||||
return 0
|
||||
}))
|
||||
|
||||
resTable := L.NewTable()
|
||||
L.SetField(scriptDataTable, "result", resTable)
|
||||
L.SetField(outTable, "send", L.NewFunction(func(L *lua.LState) int {
|
||||
res := L.Get(1)
|
||||
|
||||
resFTable := scriptDataTable.RawGetString("result")
|
||||
if resPTable, ok := res.(*lua.LTable); ok {
|
||||
resPTable.ForEach(func(key, value lua.LValue) {
|
||||
L.SetField(resFTable, key.String(), value)
|
||||
})
|
||||
} else {
|
||||
L.SetField(scriptDataTable, "result", res)
|
||||
}
|
||||
|
||||
__exit = 0
|
||||
L.RaiseError("__successfull")
|
||||
return 0
|
||||
}))
|
||||
|
||||
L.SetField(outTable, "set", L.NewFunction(func(L *lua.LState) int {
|
||||
res := L.Get(1)
|
||||
if res == lua.LNil {
|
||||
return 0
|
||||
}
|
||||
|
||||
resFTable := scriptDataTable.RawGetString("result")
|
||||
if resPTable, ok := res.(*lua.LTable); ok {
|
||||
resPTable.ForEach(func(key, value lua.LValue) {
|
||||
L.SetField(resFTable, key.String(), value)
|
||||
})
|
||||
} else {
|
||||
L.SetField(scriptDataTable, "result", res)
|
||||
}
|
||||
return 0
|
||||
}))
|
||||
|
||||
errTable := L.NewTable()
|
||||
L.SetField(scriptDataTable, "error", errTable)
|
||||
L.SetField(outTable, "send_error", L.NewFunction(func(L *lua.LState) int {
|
||||
var params [3]lua.LValue
|
||||
for i := range 3 {
|
||||
params[i] = L.Get(i + 1)
|
||||
}
|
||||
if errTable, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok {
|
||||
for _, v := range params {
|
||||
switch v.Type() {
|
||||
case lua.LTNumber:
|
||||
if n, ok := v.(lua.LNumber); ok {
|
||||
L.SetField(errTable, "code", n)
|
||||
}
|
||||
case lua.LTString:
|
||||
if s, ok := v.(lua.LString); ok {
|
||||
L.SetField(errTable, "message", s)
|
||||
}
|
||||
case lua.LTTable:
|
||||
if tbl, ok := v.(*lua.LTable); ok {
|
||||
L.SetField(errTable, "data", tbl)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__exit = 1
|
||||
L.RaiseError("__unsuccessfull")
|
||||
return 0
|
||||
}))
|
||||
|
||||
L.SetField(outTable, "set_error", L.NewFunction(func(L *lua.LState) int {
|
||||
var params [3]lua.LValue
|
||||
for i := range 3 {
|
||||
params[i] = L.Get(i + 1)
|
||||
}
|
||||
if errTable, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok {
|
||||
for _, v := range params {
|
||||
switch v.Type() {
|
||||
case lua.LTNumber:
|
||||
if n, ok := v.(lua.LNumber); ok {
|
||||
L.SetField(errTable, "code", n)
|
||||
}
|
||||
case lua.LTString:
|
||||
if s, ok := v.(lua.LString); ok {
|
||||
L.SetField(errTable, "message", s)
|
||||
}
|
||||
case lua.LTTable:
|
||||
if tbl, ok := v.(*lua.LTable); ok {
|
||||
L.SetField(errTable, "data", tbl)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}))
|
||||
|
||||
L.SetField(sessionMod, "request", inTable)
|
||||
L.SetField(sessionMod, "response", outTable)
|
||||
|
||||
L.SetField(sessionMod, "id", lua.LString(sid))
|
||||
|
||||
L.SetField(sessionMod, "__seed", lua.LString(fmt.Sprint(seed)))
|
||||
L.Push(sessionMod)
|
||||
return 1
|
||||
}
|
||||
|
||||
loadLogMod := func(L *lua.LState) int {
|
||||
llog.Debug("import module log", slog.String("script", path))
|
||||
logMod := L.NewTable()
|
||||
|
||||
logFuncs := map[string]func(string, ...any){
|
||||
"info": llog.Info,
|
||||
"debug": llog.Debug,
|
||||
"error": llog.Error,
|
||||
"warn": llog.Warn,
|
||||
}
|
||||
|
||||
for name, logFunc := range logFuncs {
|
||||
fun := logFunc
|
||||
L.SetField(logMod, name, L.NewFunction(func(L *lua.LState) int {
|
||||
msg := L.Get(1)
|
||||
converted := ConvertLuaTypesToGolang(msg)
|
||||
fun(fmt.Sprintf("the script says: %s", converted), slog.String("script", path))
|
||||
return 0
|
||||
}))
|
||||
}
|
||||
|
||||
for _, fn := range []struct {
|
||||
field string
|
||||
pfunc func(string, ...any)
|
||||
color func() string
|
||||
}{
|
||||
{"event", h.x.Log.Printf, nil},
|
||||
{"event_error", h.x.Log.Printf, colors.PrintError},
|
||||
{"event_warn", h.x.Log.Printf, colors.PrintWarn},
|
||||
} {
|
||||
L.SetField(logMod, fn.field, L.NewFunction(func(L *lua.LState) int {
|
||||
msg := L.Get(1)
|
||||
converted := ConvertLuaTypesToGolang(msg)
|
||||
if fn.color != nil {
|
||||
h.x.Log.Printf("%s: %s: %s", fn.color(), path, converted)
|
||||
} else {
|
||||
h.x.Log.Printf("%s: %s", path, converted)
|
||||
}
|
||||
return 0
|
||||
}))
|
||||
}
|
||||
|
||||
L.SetField(logMod, "__seed", lua.LString(fmt.Sprint(seed)))
|
||||
L.Push(logMod)
|
||||
return 1
|
||||
}
|
||||
|
||||
loadNetMod := func(L *lua.LState) int {
|
||||
llog.Debug("import module net", slog.String("script", path))
|
||||
netMod := L.NewTable()
|
||||
netModhttp := L.NewTable()
|
||||
|
||||
L.SetField(netModhttp, "get_request", L.NewFunction(func(L *lua.LState) int {
|
||||
logRequest := L.ToBool(1)
|
||||
url := L.ToString(2)
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
addInitiatorHeaders(sid, r, req.Header)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
if logRequest {
|
||||
llog.Info("HTTP GET request",
|
||||
slog.String("script", path),
|
||||
slog.String("url", url),
|
||||
slog.Int("status", resp.StatusCode),
|
||||
slog.String("status_text", resp.Status),
|
||||
slog.String("initiator_ip", req.Header.Get("X-Initiator-IP")),
|
||||
)
|
||||
}
|
||||
|
||||
result := L.NewTable()
|
||||
L.SetField(result, "status", lua.LNumber(resp.StatusCode))
|
||||
L.SetField(result, "status_text", lua.LString(resp.Status))
|
||||
L.SetField(result, "body", lua.LString(body))
|
||||
L.SetField(result, "content_length", lua.LNumber(resp.ContentLength))
|
||||
|
||||
headers := L.NewTable()
|
||||
for k, v := range resp.Header {
|
||||
L.SetField(headers, k, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
L.SetField(result, "headers", headers)
|
||||
|
||||
L.Push(result)
|
||||
return 1
|
||||
}))
|
||||
|
||||
L.SetField(netModhttp, "post_request", L.NewFunction(func(L *lua.LState) int {
|
||||
logRequest := L.ToBool(1)
|
||||
url := L.ToString(2)
|
||||
contentType := L.ToString(3)
|
||||
payload := L.ToString(4)
|
||||
|
||||
body := strings.NewReader(payload)
|
||||
|
||||
req, err := http.NewRequest("POST", url, body)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", contentType)
|
||||
|
||||
addInitiatorHeaders(sid, r, req.Header)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
if logRequest {
|
||||
llog.Info("HTTP POST request",
|
||||
slog.String("script", path),
|
||||
slog.String("url", url),
|
||||
slog.String("content_type", contentType),
|
||||
slog.Int("status", resp.StatusCode),
|
||||
slog.String("status_text", resp.Status),
|
||||
slog.String("initiator_ip", req.Header.Get("X-Initiator-IP")),
|
||||
)
|
||||
}
|
||||
|
||||
result := L.NewTable()
|
||||
L.SetField(result, "status", lua.LNumber(resp.StatusCode))
|
||||
L.SetField(result, "status_text", lua.LString(resp.Status))
|
||||
L.SetField(result, "body", lua.LString(respBody))
|
||||
L.SetField(result, "content_length", lua.LNumber(resp.ContentLength))
|
||||
|
||||
headers := L.NewTable()
|
||||
for k, v := range resp.Header {
|
||||
L.SetField(headers, k, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
L.SetField(result, "headers", headers)
|
||||
|
||||
L.Push(result)
|
||||
return 1
|
||||
}))
|
||||
|
||||
L.SetField(netMod, "http", netModhttp)
|
||||
|
||||
L.SetField(netMod, "__seed", lua.LString(fmt.Sprint(seed)))
|
||||
L.Push(netMod)
|
||||
return 1
|
||||
}
|
||||
|
||||
loadCryptbcryptMod := func(L *lua.LState) int {
|
||||
llog.Debug("import module crypt.bcrypt", slog.String("script", path))
|
||||
bcryptMod := L.NewTable()
|
||||
|
||||
L.SetField(bcryptMod, "MinCost", lua.LNumber(bcrypt.MinCost))
|
||||
L.SetField(bcryptMod, "MaxCost", lua.LNumber(bcrypt.MaxCost))
|
||||
L.SetField(bcryptMod, "DefaultCost", lua.LNumber(bcrypt.DefaultCost))
|
||||
|
||||
L.SetField(bcryptMod, "generate", L.NewFunction(func(l *lua.LState) int {
|
||||
password := ConvertLuaTypesToGolang(L.Get(1))
|
||||
passwordStr, ok := password.(string)
|
||||
if !ok {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("error: password must be a string"))
|
||||
return 2
|
||||
}
|
||||
|
||||
cost := ConvertLuaTypesToGolang(L.Get(2))
|
||||
costInt := bcrypt.DefaultCost
|
||||
switch v := cost.(type) {
|
||||
case int:
|
||||
costInt = v
|
||||
case float64:
|
||||
costInt = int(v)
|
||||
case nil:
|
||||
// ok, use DefaultCost
|
||||
default:
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("error: cost must be an integer"))
|
||||
return 2
|
||||
}
|
||||
|
||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(passwordStr), costInt)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("error: " + err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
L.Push(lua.LString(string(hashBytes)))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}))
|
||||
|
||||
L.SetField(bcryptMod, "compare", L.NewFunction(func(l *lua.LState) int {
|
||||
hash := ConvertLuaTypesToGolang(L.Get(1))
|
||||
hashStr, ok := hash.(string)
|
||||
if !ok {
|
||||
L.Push(lua.LString("error: hash must be a string"))
|
||||
return 1
|
||||
}
|
||||
password := ConvertLuaTypesToGolang(L.Get(2))
|
||||
passwordStr, ok := password.(string)
|
||||
if !ok {
|
||||
L.Push(lua.LString("error: password must be a string"))
|
||||
return 1
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashStr), []byte(passwordStr))
|
||||
if err != nil {
|
||||
L.Push(lua.LFalse)
|
||||
return 1
|
||||
}
|
||||
L.Push(lua.LTrue)
|
||||
return 1
|
||||
}))
|
||||
|
||||
L.SetField(bcryptMod, "__seed", lua.LString(fmt.Sprint(seed)))
|
||||
L.Push(bcryptMod)
|
||||
return 1
|
||||
}
|
||||
|
||||
loadCryptbsha256Mod := func(L *lua.LState) int {
|
||||
llog.Debug("import module crypt.sha256", slog.String("script", path))
|
||||
sha265mod := L.NewTable()
|
||||
|
||||
L.SetField(sha265mod, "hash", L.NewFunction(func(l *lua.LState) int {
|
||||
data := ConvertLuaTypesToGolang(L.Get(1))
|
||||
var dataStr = fmt.Sprint(data)
|
||||
|
||||
hash := sha256.Sum256([]byte(dataStr))
|
||||
|
||||
L.Push(lua.LString(hex.EncodeToString(hash[:])))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}))
|
||||
|
||||
L.SetField(sha265mod, "__seed", lua.LString(fmt.Sprint(seed)))
|
||||
L.Push(sha265mod)
|
||||
return 1
|
||||
}
|
||||
|
||||
L.PreloadModule("internal.session", loadSessionMod)
|
||||
L.PreloadModule("internal.log", loadLogMod)
|
||||
L.PreloadModule("internal.net", loadNetMod)
|
||||
L.PreloadModule("internal.database.sqlite", loadDBMod(llog, fmt.Sprint(seed)))
|
||||
L.PreloadModule("internal.crypt.bcrypt", loadCryptbcryptMod)
|
||||
L.PreloadModule("internal.crypt.sha256", loadCryptbsha256Mod)
|
||||
L.PreloadModule("internal.crypt.jwt", loadJWTMod(llog, fmt.Sprint(seed)))
|
||||
|
||||
llog.Debug("preparing environment")
|
||||
prep := filepath.Join(*h.x.Config.Conf.Node.ComDir, "_prepare.lua")
|
||||
if _, err := os.Stat(prep); err == nil {
|
||||
if err := L.DoFile(prep); err != nil {
|
||||
llog.Error("script error", slog.String("script", path), slog.String("error", err.Error()))
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
}
|
||||
}
|
||||
llog.Debug("executing script", slog.String("script", path))
|
||||
err := L.DoFile(path)
|
||||
if err != nil && __exit != 0 && __exit != 1 {
|
||||
llog.Error("script error", slog.String("script", path), slog.String("error", err.Error()))
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
}
|
||||
|
||||
pkg := L.GetGlobal("package")
|
||||
pkgTbl, ok := pkg.(*lua.LTable)
|
||||
if !ok {
|
||||
llog.Error("script error", slog.String("script", path), slog.String("error", "package not found"))
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
}
|
||||
|
||||
loaded := pkgTbl.RawGetString("loaded")
|
||||
loadedTbl, ok := loaded.(*lua.LTable)
|
||||
if !ok {
|
||||
llog.Error("script error", slog.String("script", path), slog.String("error", "package.loaded not found"))
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
}
|
||||
|
||||
sessionVal := loadedTbl.RawGetString("internal.session")
|
||||
sessionTbl, ok := sessionVal.(*lua.LTable)
|
||||
if !ok {
|
||||
return rpc.NewResponse(nil, req.ID)
|
||||
}
|
||||
|
||||
tag := sessionTbl.RawGetString("__seed")
|
||||
if tag.Type() != lua.LTString || tag.String() != fmt.Sprint(seed) {
|
||||
llog.Debug("stock session module is not imported: wrong seed", slog.String("script", path))
|
||||
return rpc.NewResponse(nil, req.ID)
|
||||
}
|
||||
|
||||
outVal := sessionTbl.RawGetString("response")
|
||||
outTbl, ok := outVal.(*lua.LTable)
|
||||
if !ok {
|
||||
llog.Error("script error", slog.String("script", path), slog.String("error", "response is not a table"))
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
}
|
||||
|
||||
if scriptDataTable, ok := outTbl.RawGetString("__script_data").(*lua.LTable); ok {
|
||||
switch __exit {
|
||||
case 1:
|
||||
if errTbl, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok {
|
||||
llog.Debug("catch error table", slog.String("script", path))
|
||||
code := rpc.ErrInternalError
|
||||
message := rpc.ErrInternalErrorS
|
||||
if c := errTbl.RawGetString("code"); c.Type() == lua.LTNumber {
|
||||
code = int(c.(lua.LNumber))
|
||||
}
|
||||
if msg := errTbl.RawGetString("message"); msg.Type() == lua.LTString {
|
||||
message = msg.String()
|
||||
}
|
||||
data := ConvertLuaTypesToGolang(errTbl.RawGetString("data"))
|
||||
llog.Error("the script terminated with an error", slog.Int("code", code), slog.String("message", message), slog.Any("data", data))
|
||||
return rpc.NewError(code, message, data, req.ID)
|
||||
}
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
case 0:
|
||||
resVal := ConvertLuaTypesToGolang(scriptDataTable.RawGetString("result"))
|
||||
return rpc.NewResponse(resVal, req.ID)
|
||||
}
|
||||
}
|
||||
return rpc.NewResponse(nil, req.ID)
|
||||
}
|
||||
126
src/internal/server/sv1/lua_types.go
Normal file
126
src/internal/server/sv1/lua_types.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package sv1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
func ConvertLuaTypesToGolang(value lua.LValue) any {
|
||||
switch value.Type() {
|
||||
case lua.LTString:
|
||||
return value.String()
|
||||
case lua.LTNumber:
|
||||
return float64(value.(lua.LNumber))
|
||||
case lua.LTBool:
|
||||
return bool(value.(lua.LBool))
|
||||
case lua.LTTable:
|
||||
tbl := value.(*lua.LTable)
|
||||
|
||||
maxIdx := 0
|
||||
isArray := true
|
||||
|
||||
var isNumeric = false
|
||||
tbl.ForEach(func(key, _ lua.LValue) {
|
||||
var numKey lua.LValue
|
||||
var ok bool
|
||||
switch key.Type() {
|
||||
case lua.LTString:
|
||||
numKey, ok = key.(lua.LString)
|
||||
if !ok {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
case lua.LTNumber:
|
||||
numKey, ok = key.(lua.LNumber)
|
||||
if !ok {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
isNumeric = true
|
||||
}
|
||||
|
||||
num, err := strconv.Atoi(numKey.String())
|
||||
if err != nil {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
if num < 1 {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
if num > maxIdx {
|
||||
maxIdx = num
|
||||
}
|
||||
})
|
||||
|
||||
if isArray {
|
||||
arr := make([]any, maxIdx)
|
||||
if isNumeric {
|
||||
for i := 1; i <= maxIdx; i++ {
|
||||
arr[i-1] = ConvertLuaTypesToGolang(tbl.RawGetInt(i))
|
||||
}
|
||||
} else {
|
||||
for i := 1; i <= maxIdx; i++ {
|
||||
arr[i-1] = ConvertLuaTypesToGolang(tbl.RawGetString(strconv.Itoa(i)))
|
||||
}
|
||||
}
|
||||
return arr
|
||||
}
|
||||
result := make(map[string]any)
|
||||
tbl.ForEach(func(key, val lua.LValue) {
|
||||
result[key.String()] = ConvertLuaTypesToGolang(val)
|
||||
})
|
||||
return result
|
||||
|
||||
case lua.LTNil:
|
||||
return nil
|
||||
default:
|
||||
return value.String()
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertGolangTypesToLua(L *lua.LState, val any) lua.LValue {
|
||||
if val == nil {
|
||||
return lua.LNil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(val)
|
||||
rt := rv.Type()
|
||||
|
||||
switch rt.Kind() {
|
||||
case reflect.String:
|
||||
return lua.LString(rv.String())
|
||||
case reflect.Bool:
|
||||
return lua.LBool(rv.Bool())
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return lua.LNumber(rv.Int())
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return lua.LNumber(rv.Uint())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return lua.LNumber(rv.Float())
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
tbl := L.NewTable()
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
tbl.RawSetInt(i+1, ConvertGolangTypesToLua(L, rv.Index(i).Interface()))
|
||||
}
|
||||
return tbl
|
||||
|
||||
case reflect.Map:
|
||||
if rt.Key().Kind() == reflect.String {
|
||||
tbl := L.NewTable()
|
||||
for _, key := range rv.MapKeys() {
|
||||
val := rv.MapIndex(key)
|
||||
tbl.RawSetString(key.String(), ConvertGolangTypesToLua(L, val.Interface()))
|
||||
}
|
||||
return tbl
|
||||
}
|
||||
|
||||
default:
|
||||
return lua.LString(fmt.Sprintf("%v", val))
|
||||
}
|
||||
return lua.LString(fmt.Sprintf("%v", val))
|
||||
}
|
||||
28
src/internal/server/sv1/path.go
Normal file
28
src/internal/server/sv1/path.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package sv1
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/server/rpc"
|
||||
)
|
||||
|
||||
var RPCMethodSeparator = "."
|
||||
|
||||
func (h *HandlerV1) resolveMethodPath(method string) (string, error) {
|
||||
if !h.allowedCmd.MatchString(method) {
|
||||
return "", errors.New(rpc.ErrInvalidMethodFormatS)
|
||||
}
|
||||
|
||||
parts := strings.Split(method, RPCMethodSeparator)
|
||||
relPath := filepath.Join(parts...) + ".lua"
|
||||
fullPath := filepath.Join(*h.x.Config.Conf.Node.ComDir, relPath)
|
||||
|
||||
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||
return "", errors.New(rpc.ErrMethodNotFoundS)
|
||||
}
|
||||
|
||||
return fullPath, nil
|
||||
}
|
||||
47
src/internal/server/sv1/server.go
Normal file
47
src/internal/server/sv1/server.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Package sv1 provides the implementation of the Server V1 API handler.
|
||||
// It includes utilities for handling API requests, extracting descriptions, and managing UUIDs.
|
||||
package sv1
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/core/corestate"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/engine/app"
|
||||
)
|
||||
|
||||
// HandlerV1InitStruct structure is only for initialization
|
||||
type HandlerV1InitStruct struct {
|
||||
Ver string
|
||||
CS *corestate.CoreState
|
||||
X *app.AppX
|
||||
AllowedCmd *regexp.Regexp
|
||||
}
|
||||
|
||||
// HandlerV1 implements the ServerV1UtilsContract and serves as the main handler for API requests.
|
||||
type HandlerV1 struct {
|
||||
cs *corestate.CoreState
|
||||
x *app.AppX
|
||||
|
||||
// allowedCmd and listAllowedCmd are regular expressions used to validate command names.
|
||||
allowedCmd *regexp.Regexp
|
||||
|
||||
ver string
|
||||
}
|
||||
|
||||
// InitV1Server initializes a new HandlerV1 with the provided configuration and returns it.
|
||||
// Should be carefull with giving to this function invalid parameters,
|
||||
// because there is no validation of parameters in this function.
|
||||
func InitV1Server(o *HandlerV1InitStruct) *HandlerV1 {
|
||||
return &HandlerV1{
|
||||
cs: o.CS,
|
||||
x: o.X,
|
||||
allowedCmd: o.AllowedCmd,
|
||||
ver: o.Ver,
|
||||
}
|
||||
}
|
||||
|
||||
// GetVersion returns the API version of the HandlerV1, which is set during initialization.
|
||||
// This version is used to identify the API version in the request routing.
|
||||
func (h *HandlerV1) GetVersion() string {
|
||||
return h.ver
|
||||
}
|
||||
Reference in New Issue
Block a user