mirror of
https://github.com/akyaiy/GoSally-mvp.git
synced 2026-01-03 19:52:25 +00:00
Compare commits
14 Commits
5c01eaad6f
...
auth-serve
| Author | SHA1 | Date | |
|---|---|---|---|
| 625e5daf71 | |||
| cc27843bb3 | |||
| 20fec82159 | |||
| 055b299ecb | |||
| 17bf207087 | |||
| 7ae8e12dc8 | |||
| 6e36db428a | |||
| 06103a3264 | |||
| c6da55ad65 | |||
| 20a1e3e7bb | |||
| e594d519a7 | |||
| 2ceb236a53 | |||
| 811403a0a2 | |||
| b451f2d3fc |
@@ -1,3 +1,5 @@
|
||||
// The cmd package is the main package where all the main hooks and methods are called.
|
||||
// GoSally uses spf13/cobra to organize all the calls.
|
||||
package cmd
|
||||
|
||||
import (
|
||||
@@ -22,6 +24,8 @@ scripts in a given directory. For more information, visit: https://gosally.oblat
|
||||
},
|
||||
}
|
||||
|
||||
// Execute prepares global log, loads cmdline args
|
||||
// and executes rootCmd.Execute()
|
||||
func Execute() {
|
||||
log.SetOutput(os.Stdout)
|
||||
log.SetPrefix(colors.SetBrightBlack(fmt.Sprintf("(%s) ", corestate.StageNotReady)))
|
||||
|
||||
@@ -11,6 +11,7 @@ var runCmd = &cobra.Command{
|
||||
Short: "Run node normally",
|
||||
Long: `
|
||||
"run" starts the node with settings depending on the configuration file`,
|
||||
// hooks.Run essentially the heart of the program
|
||||
Run: hooks.Run,
|
||||
}
|
||||
|
||||
|
||||
77
com/Access/GetMasterAccess.lua
Normal file
77
com/Access/GetMasterAccess.lua
Normal file
@@ -0,0 +1,77 @@
|
||||
local session = require("internal.session")
|
||||
local log = require("internal.log")
|
||||
local jwt = require("internal.crypt.jwt")
|
||||
local bc = require("internal.crypt.bcrypt")
|
||||
local db = require("internal.database.sqlite").connect("db/root.db", {log = true})
|
||||
local sha256 = require("internal.crypt.sha256")
|
||||
|
||||
log.info("Someone at "..session.request.address.." trying to get master access")
|
||||
|
||||
local function close_db()
|
||||
if db then
|
||||
db:close()
|
||||
db = nil
|
||||
end
|
||||
end
|
||||
|
||||
local params = session.request.params.get()
|
||||
|
||||
local function check_missing(arr, p)
|
||||
local is_missing = {}
|
||||
local ok = true
|
||||
for _, key in ipairs(arr) do
|
||||
if p[key] == nil then
|
||||
table.insert(is_missing, key)
|
||||
ok = false
|
||||
end
|
||||
end
|
||||
return ok, is_missing
|
||||
end
|
||||
|
||||
local ok, mp = check_missing({"master_secret", "master_name", "my_key"}, params)
|
||||
if not ok then
|
||||
close_db()
|
||||
session.response.send_error(-32602, "Missing params", mp)
|
||||
end
|
||||
|
||||
if type(params.master_secret) ~= "string" then
|
||||
close_db()
|
||||
session.response.send_error(-32050, "Access denied")
|
||||
end
|
||||
|
||||
if type(params.master_name) ~= "string" then
|
||||
close_db()
|
||||
session.response.send_error(-32050, "Access denied")
|
||||
end
|
||||
|
||||
local master, err = db:query_row("SELECT * FROM master_units WHERE master_name = ?", {params.master_name})
|
||||
|
||||
if not master then
|
||||
log.event("DB query failed:", err)
|
||||
close_db()
|
||||
session.response.send_error(-32050, "Access denied")
|
||||
end
|
||||
|
||||
local ok = bc.compare(master.master_secret, params.master_secret)
|
||||
if not ok then
|
||||
log.warn("Login failed: wrong password")
|
||||
close_db()
|
||||
session.response.send_error(-32050, "Access denied")
|
||||
end
|
||||
|
||||
local token = jwt.encode({
|
||||
secret = require("_config").token(),
|
||||
payload = {
|
||||
session_uuid = session.id,
|
||||
master_id = master.id,
|
||||
key = sha256.sum(params.my_key)
|
||||
},
|
||||
expires_in = 3600
|
||||
})
|
||||
|
||||
close_db()
|
||||
session.response.send({
|
||||
token = token
|
||||
})
|
||||
|
||||
-- G7HgOgl72o7t7u7r
|
||||
11
com/Echo.lua
Normal file
11
com/Echo.lua
Normal file
@@ -0,0 +1,11 @@
|
||||
local s = require("internal.session")
|
||||
|
||||
if not s.request.params.__fetched.data then
|
||||
s.response.error = {
|
||||
code = 123,
|
||||
message = "params.data is missing"
|
||||
}
|
||||
return
|
||||
end
|
||||
|
||||
s.response.send(s.request.params.__fetched)
|
||||
@@ -52,6 +52,6 @@ end
|
||||
local basePath = "com"
|
||||
local layer = params.layer and params.layer:gsub(">", "/") or nil
|
||||
|
||||
session.response.result = {
|
||||
session.response.send({
|
||||
answer = layer and scanDirectory(basePath, layer) or scanDirectory(basePath, "")
|
||||
}
|
||||
})
|
||||
69
com/Zones/GetZoneInfo.lua
Normal file
69
com/Zones/GetZoneInfo.lua
Normal file
@@ -0,0 +1,69 @@
|
||||
local session = require("internal.session")
|
||||
local log = require("internal.log")
|
||||
local jwt = require("internal.crypt.jwt")
|
||||
local bc = require("internal.crypt.bcrypt")
|
||||
local sha256 = require("internal.crypt.sha256")
|
||||
local dbdriver = require("internal.database.sqlite")
|
||||
|
||||
local db_root = dbdriver.connect("db/root.db", {log = true})
|
||||
local db_zone = nil
|
||||
|
||||
local function close_db()
|
||||
if db_root then
|
||||
db_root:close()
|
||||
db_root = nil
|
||||
end
|
||||
if db_zone then
|
||||
db_zone:close()
|
||||
db_zone = nil
|
||||
end
|
||||
end
|
||||
|
||||
local token = session.request.headers.get("authorization")
|
||||
|
||||
if not token or type(token) ~= "string" then
|
||||
close_db()
|
||||
session.response.send_error(-32050, "Access denied")
|
||||
end
|
||||
|
||||
local prefix = "Bearer "
|
||||
if token:sub(1, #prefix) ~= prefix then
|
||||
close_db()
|
||||
session.response.send_error(-32052, "Invalid Authorization scheme")
|
||||
end
|
||||
|
||||
local access_token = token:sub(#prefix + 1)
|
||||
|
||||
local err, data = jwt.decode(access_token, { secret = require("_config").token() })
|
||||
|
||||
if err or not data then
|
||||
close_db()
|
||||
session.response.send_error(-32053, "Cannod parse JWT", {err})
|
||||
end
|
||||
|
||||
if data.master_id then
|
||||
|
||||
|
||||
end
|
||||
|
||||
local params = session.request.params.get()
|
||||
|
||||
local function check_missing(arr, p)
|
||||
local is_missing = {}
|
||||
local ok = true
|
||||
for _, key in ipairs(arr) do
|
||||
if p[key] == nil then
|
||||
table.insert(is_missing, key)
|
||||
ok = false
|
||||
end
|
||||
end
|
||||
return ok, is_missing
|
||||
end
|
||||
|
||||
local ok, mp = check_missing({"zone_name"}, params)
|
||||
if not ok then
|
||||
close_db()
|
||||
session.response.send_error(-32602, "Missing params", mp)
|
||||
end
|
||||
|
||||
close_db()
|
||||
3
go.mod
3
go.mod
@@ -19,7 +19,7 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
@@ -41,6 +41,5 @@ require (
|
||||
|
||||
require (
|
||||
github.com/go-chi/cors v1.2.2
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
// The config composer needs to be in the global scope
|
||||
var Compositor *config.Compositor = config.NewCompositor()
|
||||
|
||||
func InitGlobalLoggerHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
@@ -65,7 +66,8 @@ func InitConfigLoadHook(_ context.Context, cs *corestate.CoreState, x *app.AppX)
|
||||
}
|
||||
}
|
||||
|
||||
func InitUUUDHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
// The hook reads or prepares a persistent uuid for the node
|
||||
func InitUUIDHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
uuid32, err := corestate.GetNodeUUID(filepath.Join(cs.MetaDir, "uuid"))
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
if err := corestate.SetNodeUUID(filepath.Join(cs.NodePath, cs.MetaDir, cs.UUID32DirName)); err != nil {
|
||||
@@ -80,8 +82,11 @@ func InitUUUDHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
x.Log.Fatalf("uuid load error: %s", err)
|
||||
}
|
||||
cs.UUID32 = uuid32
|
||||
corestate.NODE_UUID = uuid32
|
||||
}
|
||||
|
||||
// The hook is responsible for checking the initialization stage
|
||||
// and restarting in some cases
|
||||
func InitRuntimeHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
if *x.Config.Env.ParentStagePID != os.Getpid() {
|
||||
// still pre-init stage
|
||||
@@ -133,6 +138,8 @@ func InitRuntimeHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
}
|
||||
|
||||
// post-init stage
|
||||
// The hook creates a run.lock file, which contains information
|
||||
// about the process and the node, in the runtime directory.
|
||||
func InitRunlockHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
NodeApp.Fallback(func(ctx context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
x.Log.Println("Cleaning up...")
|
||||
@@ -185,6 +192,8 @@ func InitRunlockHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
}
|
||||
}
|
||||
|
||||
// The hook reads the configuration and replaces special expressions
|
||||
// (%tmp% and so on) in string fields with the required data.
|
||||
func InitConfigReplHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
if !slices.Contains(*x.Config.Conf.DisableWarnings, "--WNonStdTmpDir") && os.TempDir() != "/tmp" {
|
||||
x.Log.Printf("%s: %s", colors.PrintWarn(), "Non-standard value specified for temporary directory")
|
||||
@@ -209,6 +218,8 @@ func InitConfigReplHook(_ context.Context, cs *corestate.CoreState, x *app.AppX)
|
||||
}
|
||||
}
|
||||
|
||||
// The hook is responsible for outputting the
|
||||
// final config and asking for confirmation.
|
||||
func InitConfigPrintHook(ctx context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
if *x.Config.Conf.Node.ShowConfig {
|
||||
fmt.Printf("Configuration from %s:\n", x.Config.CMDLine.Run.ConfigPath)
|
||||
@@ -239,6 +250,8 @@ func InitSLogHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
*x.SLog = *newSlog
|
||||
}
|
||||
|
||||
// The method goes through the entire config structure through
|
||||
// reflection and replaces string fields with the required ones.
|
||||
func processConfig(conf any, replacements map[string]any) error {
|
||||
val := reflect.ValueOf(conf)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
|
||||
@@ -33,7 +33,7 @@ var NodeApp = app.New()
|
||||
func Run(cmd *cobra.Command, args []string) {
|
||||
NodeApp.InitialHooks(
|
||||
InitGlobalLoggerHook, InitCorestateHook, InitConfigLoadHook,
|
||||
InitUUUDHook, InitRuntimeHook, InitRunlockHook,
|
||||
InitUUIDHook, InitRuntimeHook, InitRunlockHook,
|
||||
InitConfigReplHook, InitConfigPrintHook, InitSLogHook,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package corestate
|
||||
|
||||
var NODE_UUID string
|
||||
|
||||
type Stage string
|
||||
|
||||
const (
|
||||
|
||||
1
internal/engine/lua/handler.go
Normal file
1
internal/engine/lua/handler.go
Normal file
@@ -0,0 +1 @@
|
||||
package lua
|
||||
35
internal/engine/lua/pool.go
Normal file
35
internal/engine/lua/pool.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package lua
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
type LuaPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func NewLuaPool() *LuaPool {
|
||||
return &LuaPool{
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
L := lua.NewState()
|
||||
|
||||
return L
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (lp *LuaPool) Get() *lua.LState {
|
||||
return lp.pool.Get().(*lua.LState)
|
||||
}
|
||||
|
||||
func (lp *LuaPool) Put(L *lua.LState) {
|
||||
L.Close()
|
||||
|
||||
newL := lua.NewState()
|
||||
|
||||
lp.pool.Put(newL)
|
||||
}
|
||||
26
internal/engine/lua/types.go
Normal file
26
internal/engine/lua/types.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package lua
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/core/corestate"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/engine/app"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/server/rpc"
|
||||
)
|
||||
|
||||
type LuaEngineDeps struct {
|
||||
HttpRequest *http.Request
|
||||
JSONRPCRequest *rpc.RPCRequest
|
||||
SessionUUID string
|
||||
ScriptPath string
|
||||
}
|
||||
|
||||
type LuaEngineContract interface {
|
||||
Handle(deps *LuaEngineDeps) *rpc.RPCResponse
|
||||
|
||||
}
|
||||
|
||||
type LuaEngine struct {
|
||||
x *app.AppX
|
||||
cs *corestate.CoreState
|
||||
}
|
||||
@@ -20,18 +20,14 @@ func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) {
|
||||
sessionUUID := r.Header.Get("X-Session-UUID")
|
||||
if sessionUUID == "" {
|
||||
sessionUUID = uuid.New().String()
|
||||
|
||||
}
|
||||
gs.x.SLog.Debug("new request", slog.String("session-uuid", sessionUUID), slog.Group("connection", slog.String("ip", r.RemoteAddr)))
|
||||
|
||||
w.Header().Set("X-Session-UUID", sessionUUID)
|
||||
if !gs.sm.Add(sessionUUID) {
|
||||
gs.x.SLog.Debug("session is busy", slog.String("session-uuid", sessionUUID))
|
||||
rpc.WriteError(gs.cs.UUID32, w, &rpc.RPCResponse{
|
||||
Error: map[string]any{
|
||||
"code": rpc.ErrSessionIsBusy,
|
||||
"message": rpc.ErrSessionIsBusyS,
|
||||
},
|
||||
})
|
||||
rpc.WriteError(w, rpc.NewError(rpc.ErrSessionIsBusy, rpc.ErrSessionIsBusyS, nil, nil))
|
||||
return
|
||||
}
|
||||
defer gs.sm.Delete(sessionUUID)
|
||||
@@ -40,14 +36,7 @@ func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
gs.x.SLog.Debug("failed to read body", slog.String("err", err.Error()))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
rpc.WriteError(gs.cs.UUID32, w, &rpc.RPCResponse{
|
||||
JSONRPC: rpc.JSONRPCVersion,
|
||||
ID: nil,
|
||||
Error: map[string]any{
|
||||
"code": rpc.ErrInternalError,
|
||||
"message": rpc.ErrInternalErrorS,
|
||||
},
|
||||
})
|
||||
rpc.WriteError(w, rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, nil))
|
||||
gs.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrInternalErrorS))
|
||||
return
|
||||
}
|
||||
@@ -60,14 +49,7 @@ func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.Unmarshal(body, &single); err != nil {
|
||||
gs.x.SLog.Debug("failed to parse json", slog.String("err", err.Error()))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
rpc.WriteError(gs.cs.UUID32, w, &rpc.RPCResponse{
|
||||
JSONRPC: rpc.JSONRPCVersion,
|
||||
ID: nil,
|
||||
Error: map[string]any{
|
||||
"code": rpc.ErrParseError,
|
||||
"message": rpc.ErrParseErrorS,
|
||||
},
|
||||
})
|
||||
rpc.WriteError(w, rpc.NewError(rpc.ErrParseError, rpc.ErrParseErrorS, nil, nil))
|
||||
gs.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrParseErrorS))
|
||||
return
|
||||
}
|
||||
@@ -76,7 +58,7 @@ func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(""))
|
||||
return
|
||||
}
|
||||
rpc.WriteResponse(gs.cs.UUID32, w, resp)
|
||||
rpc.WriteResponse(w, resp)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -13,11 +13,16 @@ type RPCRequest struct {
|
||||
type RPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID *json.RawMessage `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
Result any `json:"result,omitzero"`
|
||||
Error any `json:"error,omitzero"`
|
||||
Data *RPCData `json:"data,omitzero"`
|
||||
}
|
||||
|
||||
type RPCData struct {
|
||||
ResponsibleNode string `json:"responsible-node,omitempty"`
|
||||
Salt string `json:"salt,omitempty"`
|
||||
Checksum string `json:"checksum-md5,omitempty"`
|
||||
NewSessionUUID string `json:"new-session-uuid,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -1,39 +1,58 @@
|
||||
package rpc
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/core/corestate"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func generateChecksum(result any) string {
|
||||
if result == nil {
|
||||
return ""
|
||||
}
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%x", md5.Sum(data))
|
||||
}
|
||||
|
||||
func generateSalt() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
func GetData(data any) *RPCData {
|
||||
return &RPCData{
|
||||
Salt: generateSalt(),
|
||||
ResponsibleNode: corestate.NODE_UUID,
|
||||
Checksum: generateChecksum(data),
|
||||
}
|
||||
}
|
||||
|
||||
func NewError(code int, message string, data any, id *json.RawMessage) *RPCResponse {
|
||||
if data != nil {
|
||||
return &RPCResponse{
|
||||
JSONRPC: JSONRPCVersion,
|
||||
ID: id,
|
||||
Error: map[string]any{
|
||||
Error := make(map[string]any)
|
||||
Error = map[string]any{
|
||||
"code": code,
|
||||
"message": message,
|
||||
"data": data,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &RPCResponse{
|
||||
JSONRPC: JSONRPCVersion,
|
||||
ID: id,
|
||||
Error: map[string]any{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
Error: Error,
|
||||
Data: GetData(Error),
|
||||
}
|
||||
}
|
||||
|
||||
func NewResponse(result any, id *json.RawMessage) *RPCResponse {
|
||||
if result == nil {
|
||||
return &RPCResponse{
|
||||
JSONRPC: JSONRPCVersion,
|
||||
ID: id,
|
||||
}
|
||||
}
|
||||
return &RPCResponse{
|
||||
JSONRPC: JSONRPCVersion,
|
||||
ID: id,
|
||||
Result: result,
|
||||
Data: GetData(result),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,40 +1,11 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func generateChecksum(result any) string {
|
||||
if result == nil {
|
||||
return ""
|
||||
}
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%x", md5.Sum(data))
|
||||
}
|
||||
|
||||
func generateSalt() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
func write(nid string, w http.ResponseWriter, msg *RPCResponse) error {
|
||||
msg.Salt = generateSalt()
|
||||
if msg.Result != nil {
|
||||
msg.Checksum = generateChecksum(msg.Result)
|
||||
} else if msg.Error != nil {
|
||||
msg.Checksum = generateChecksum(msg.Error)
|
||||
}
|
||||
|
||||
if nid != "" {
|
||||
msg.ResponsibleNode = nid
|
||||
}
|
||||
func write(w http.ResponseWriter, msg *RPCResponse) error {
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -43,10 +14,10 @@ func write(nid string, w http.ResponseWriter, msg *RPCResponse) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func WriteError(nid string, w http.ResponseWriter, errm *RPCResponse) error {
|
||||
return write(nid, w, errm)
|
||||
func WriteError(w http.ResponseWriter, errm *RPCResponse) error {
|
||||
return write(w, errm)
|
||||
}
|
||||
|
||||
func WriteResponse(nid string, w http.ResponseWriter, response *RPCResponse) error {
|
||||
return write(nid, w, response)
|
||||
func WriteResponse(w http.ResponseWriter, response *RPCResponse) error {
|
||||
return write(w, response)
|
||||
}
|
||||
|
||||
@@ -82,6 +82,7 @@ func loadDBMod(llog *slog.Logger, sid string) func(*lua.LState) int {
|
||||
L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
|
||||
"exec": dbExec,
|
||||
"query": dbQuery,
|
||||
"query_row": dbQueryRow,
|
||||
"close": dbClose,
|
||||
}))
|
||||
|
||||
@@ -213,6 +214,102 @@ func dbExec(L *lua.LState) int {
|
||||
return 2
|
||||
}
|
||||
|
||||
func dbQueryRow(L *lua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
conn, ok := ud.Value.(*DBConnection)
|
||||
if !ok {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("invalid database connection"))
|
||||
return 2
|
||||
}
|
||||
|
||||
query := L.CheckString(2)
|
||||
|
||||
var args []any
|
||||
if L.GetTop() >= 3 {
|
||||
params := L.CheckTable(3)
|
||||
params.ForEach(func(k lua.LValue, v lua.LValue) {
|
||||
args = append(args, ConvertLuaTypesToGolang(v))
|
||||
})
|
||||
}
|
||||
|
||||
if conn.log {
|
||||
conn.logger.Info("DB QueryRow",
|
||||
slog.String("query", query),
|
||||
slog.Any("params", args))
|
||||
}
|
||||
|
||||
mtx := getDBMutex(conn.dbPath)
|
||||
mtx.RLock()
|
||||
defer mtx.RUnlock()
|
||||
|
||||
db, err := sql.Open("sqlite", conn.dbPath+"?_busy_timeout=5000&_journal_mode=WAL&_sync=NORMAL&_cache_size=-10000")
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
row := db.QueryRow(query, args...)
|
||||
|
||||
columns := []string{}
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("prepare failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err := stmt.Query(args...)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("query failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
defer rows.Close()
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("get columns failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
for _, c := range cols {
|
||||
columns = append(columns, c)
|
||||
}
|
||||
|
||||
colCount := len(columns)
|
||||
values := make([]any, colCount)
|
||||
valuePtrs := make([]any, colCount)
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
err = row.Scan(valuePtrs...)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
L.Push(lua.LNil)
|
||||
return 1
|
||||
}
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("scan failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
|
||||
rowTable := L.NewTable()
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if val == nil {
|
||||
L.SetField(rowTable, col, lua.LNil)
|
||||
} else {
|
||||
L.SetField(rowTable, col, ConvertGolangTypesToLua(L, val))
|
||||
}
|
||||
}
|
||||
|
||||
L.Push(rowTable)
|
||||
return 1
|
||||
}
|
||||
|
||||
func dbQuery(L *lua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
conn, ok := ud.Value.(*DBConnection)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
|
||||
@@ -43,6 +43,8 @@ func addInitiatorHeaders(sid string, req *http.Request, headers http.Header) {
|
||||
// I will be only glad.
|
||||
// TODO: make this huge function more harmonious by dividing responsibilities
|
||||
func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, path string) *rpc.RPCResponse {
|
||||
var __exit = -1
|
||||
|
||||
llog := h.x.SLog.With(slog.String("session-id", sid))
|
||||
llog.Debug("handling LUA")
|
||||
L := lua.NewState()
|
||||
@@ -101,10 +103,13 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
return 1
|
||||
})
|
||||
|
||||
L.SetField(headersTable, "__fetched", fetchedHeadersTable)
|
||||
|
||||
L.SetField(headersTable, "get", headersGetter)
|
||||
L.SetField(inTable, "headers", headersTable)
|
||||
|
||||
fetchedParamsTable := L.NewTable()
|
||||
switch params := req.Params.(type) {
|
||||
case nil:
|
||||
fetchedParamsTable.RawSetInt(1, lua.LNil)
|
||||
case map[string]any:
|
||||
for k, v := range params {
|
||||
L.SetField(fetchedParamsTable, k, ConvertGolangTypesToLua(L, v))
|
||||
@@ -113,19 +118,13 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
for i, v := range params {
|
||||
fetchedParamsTable.RawSetInt(i+1, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
default:
|
||||
fetchedParamsTable.RawSetInt(1, lua.LNil)
|
||||
}
|
||||
|
||||
|
||||
paramsGetter := L.NewFunction(func(L *lua.LState) int {
|
||||
path := L.OptString(1, "")
|
||||
def := L.Get(2)
|
||||
|
||||
get := func(tbl *lua.LTable, path string) lua.LValue {
|
||||
if tbl.RawGetInt(1) == lua.LNil {
|
||||
return lua.LNil
|
||||
}
|
||||
if path == "" {
|
||||
return tbl
|
||||
}
|
||||
@@ -146,7 +145,8 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
return current
|
||||
}
|
||||
|
||||
val := get(fetchedParamsTable, path)
|
||||
paramsTbl := L.GetField(paramsTable, "__fetched") //
|
||||
val := get(paramsTbl.(*lua.LTable), path) //
|
||||
if val == lua.LNil && def != lua.LNil {
|
||||
L.Push(def)
|
||||
} else {
|
||||
@@ -154,19 +154,132 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
}
|
||||
return 1
|
||||
})
|
||||
L.SetField(headersTable, "__fetched", fetchedHeadersTable)
|
||||
|
||||
L.SetField(headersTable, "get", headersGetter)
|
||||
L.SetField(inTable, "headers", headersTable)
|
||||
|
||||
L.SetField(paramsTable, "__fetched", fetchedParamsTable)
|
||||
|
||||
L.SetField(paramsTable, "get", paramsGetter)
|
||||
L.SetField(inTable, "params", paramsTable)
|
||||
|
||||
outTable := L.NewTable()
|
||||
scriptDataTable := L.NewTable()
|
||||
L.SetField(outTable, "__script_data", scriptDataTable)
|
||||
|
||||
L.SetField(inTable, "address", lua.LString(r.RemoteAddr))
|
||||
|
||||
L.SetField(sessionMod, "throw_error", L.NewFunction(func(L *lua.LState) int {
|
||||
arg := L.Get(1)
|
||||
var msg string
|
||||
switch arg.Type() {
|
||||
case lua.LTString:
|
||||
msg = arg.String()
|
||||
case lua.LTNumber:
|
||||
msg = strconv.FormatFloat(float64(arg.(lua.LNumber)), 'f', -1, 64)
|
||||
default:
|
||||
L.ArgError(1, "expected string or number")
|
||||
return 0
|
||||
}
|
||||
|
||||
L.RaiseError("%s", msg)
|
||||
return 0
|
||||
}))
|
||||
|
||||
resTable := L.NewTable()
|
||||
L.SetField(scriptDataTable, "result", resTable)
|
||||
L.SetField(outTable, "send", L.NewFunction(func(L *lua.LState) int {
|
||||
res := L.Get(1)
|
||||
if res == lua.LNil {
|
||||
__exit = 0
|
||||
L.RaiseError("__successfull")
|
||||
return 0
|
||||
}
|
||||
|
||||
resFTable := scriptDataTable.RawGetString("result")
|
||||
if resPTable, ok := res.(*lua.LTable); ok {
|
||||
resPTable.ForEach(func(key, value lua.LValue) {
|
||||
L.SetField(resFTable, key.String(), value)
|
||||
})
|
||||
} else {
|
||||
L.SetField(scriptDataTable, "result", res)
|
||||
}
|
||||
|
||||
__exit = 0
|
||||
L.RaiseError("__successfull")
|
||||
return 0
|
||||
}))
|
||||
|
||||
L.SetField(outTable, "set", L.NewFunction(func(L *lua.LState) int {
|
||||
res := L.Get(1)
|
||||
if res == lua.LNil {
|
||||
return 0
|
||||
}
|
||||
|
||||
resFTable := scriptDataTable.RawGetString("result")
|
||||
if resPTable, ok := res.(*lua.LTable); ok {
|
||||
resPTable.ForEach(func(key, value lua.LValue) {
|
||||
L.SetField(resFTable, key.String(), value)
|
||||
})
|
||||
} else {
|
||||
L.SetField(scriptDataTable, "result", res)
|
||||
}
|
||||
return 0
|
||||
}))
|
||||
|
||||
errTable := L.NewTable()
|
||||
L.SetField(scriptDataTable, "error", errTable)
|
||||
L.SetField(outTable, "send_error", L.NewFunction(func(L *lua.LState) int {
|
||||
var params [3]lua.LValue
|
||||
for i := range 3 {
|
||||
params[i] = L.Get(i + 1)
|
||||
}
|
||||
if errTable, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok {
|
||||
for _, v := range params {
|
||||
switch v.Type() {
|
||||
case lua.LTNumber:
|
||||
if n, ok := v.(lua.LNumber); ok {
|
||||
L.SetField(errTable, "code", n)
|
||||
}
|
||||
case lua.LTString:
|
||||
if s, ok := v.(lua.LString); ok {
|
||||
L.SetField(errTable, "message", s)
|
||||
}
|
||||
case lua.LTTable:
|
||||
if tbl, ok := v.(*lua.LTable); ok {
|
||||
L.SetField(errTable, "data", tbl)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__exit = 1
|
||||
L.RaiseError("__unsuccessfull")
|
||||
return 0
|
||||
}))
|
||||
|
||||
L.SetField(outTable, "set_error", L.NewFunction(func(L *lua.LState) int {
|
||||
var params [3]lua.LValue
|
||||
for i := range 3 {
|
||||
params[i] = L.Get(i + 1)
|
||||
}
|
||||
if errTable, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok {
|
||||
for _, v := range params {
|
||||
switch v.Type() {
|
||||
case lua.LTNumber:
|
||||
if n, ok := v.(lua.LNumber); ok {
|
||||
L.SetField(errTable, "code", n)
|
||||
}
|
||||
case lua.LTString:
|
||||
if s, ok := v.(lua.LString); ok {
|
||||
L.SetField(errTable, "message", s)
|
||||
}
|
||||
case lua.LTTable:
|
||||
if tbl, ok := v.(*lua.LTable); ok {
|
||||
L.SetField(errTable, "data", tbl)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}))
|
||||
|
||||
L.SetField(sessionMod, "request", inTable)
|
||||
L.SetField(sessionMod, "response", outTable)
|
||||
|
||||
@@ -431,12 +544,7 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
|
||||
L.SetField(sha265mod, "sum", L.NewFunction(func(l *lua.LState) int {
|
||||
data := ConvertLuaTypesToGolang(L.Get(1))
|
||||
dataStr, ok := data.(string)
|
||||
if !ok {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("error: data must be a string"))
|
||||
return 2
|
||||
}
|
||||
var dataStr = fmt.Sprint(data)
|
||||
|
||||
hash := sha256.Sum256([]byte(dataStr))
|
||||
|
||||
@@ -467,7 +575,8 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
}
|
||||
}
|
||||
llog.Debug("executing script", slog.String("script", path))
|
||||
if err := L.DoFile(path); err != nil {
|
||||
err := L.DoFile(path)
|
||||
if err != nil && __exit != 0 && __exit != 1 {
|
||||
llog.Error("script error", slog.String("script", path), slog.String("error", err.Error()))
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
}
|
||||
@@ -505,34 +614,28 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
}
|
||||
|
||||
if errVal := outTbl.RawGetString("error"); errVal != lua.LNil {
|
||||
if scriptDataTable, ok := outTbl.RawGetString("__script_data").(*lua.LTable); ok {
|
||||
switch __exit {
|
||||
case 1:
|
||||
if errTbl, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok {
|
||||
llog.Debug("catch error table", slog.String("script", path))
|
||||
if errTbl, ok := errVal.(*lua.LTable); ok {
|
||||
code := rpc.ErrInternalError
|
||||
message := rpc.ErrInternalErrorS
|
||||
data := make(map[string]any)
|
||||
if c := errTbl.RawGetString("code"); c.Type() == lua.LTNumber {
|
||||
code = int(c.(lua.LNumber))
|
||||
}
|
||||
if msg := errTbl.RawGetString("message"); msg.Type() == lua.LTString {
|
||||
message = msg.String()
|
||||
}
|
||||
rawData := errTbl.RawGetString("data")
|
||||
|
||||
if tbl, ok := rawData.(*lua.LTable); ok {
|
||||
tbl.ForEach(func(k, v lua.LValue) { data[k.String()] = ConvertLuaTypesToGolang(v) })
|
||||
} else {
|
||||
llog.Error("the script terminated with an error", slog.String("code", strconv.Itoa(code)), slog.String("message", message))
|
||||
return rpc.NewError(code, message, ConvertLuaTypesToGolang(rawData), req.ID)
|
||||
}
|
||||
llog.Error("the script terminated with an error", slog.String("code", strconv.Itoa(code)), slog.String("message", message))
|
||||
data := ConvertLuaTypesToGolang(errTbl.RawGetString("data"))
|
||||
llog.Error("the script terminated with an error", slog.Int("code", code), slog.String("message", message), slog.Any("data", data))
|
||||
return rpc.NewError(code, message, data, req.ID)
|
||||
}
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
case 0:
|
||||
resVal := ConvertLuaTypesToGolang(scriptDataTable.RawGetString("result"))
|
||||
return rpc.NewResponse(resVal, req.ID)
|
||||
}
|
||||
|
||||
if resultVal := outTbl.RawGetString("result"); resultVal != lua.LNil {
|
||||
return rpc.NewResponse(ConvertLuaTypesToGolang(resultVal), req.ID)
|
||||
}
|
||||
return rpc.NewResponse(nil, req.ID)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package sv1
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
@@ -18,19 +19,56 @@ func ConvertLuaTypesToGolang(value lua.LValue) any {
|
||||
case lua.LTTable:
|
||||
tbl := value.(*lua.LTable)
|
||||
|
||||
var arr []any
|
||||
maxIdx := 0
|
||||
isArray := true
|
||||
tbl.ForEach(func(key, val lua.LValue) {
|
||||
if key.Type() != lua.LTNumber {
|
||||
|
||||
var isNumeric = false
|
||||
tbl.ForEach(func(key, _ lua.LValue) {
|
||||
var numKey lua.LValue
|
||||
var ok bool
|
||||
switch key.Type() {
|
||||
case lua.LTString:
|
||||
numKey, ok = key.(lua.LString)
|
||||
if !ok {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
case lua.LTNumber:
|
||||
numKey, ok = key.(lua.LNumber)
|
||||
if !ok {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
isNumeric = true
|
||||
}
|
||||
|
||||
num, err := strconv.Atoi(numKey.String())
|
||||
if err != nil {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
if num < 1 {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
if num > maxIdx {
|
||||
maxIdx = num
|
||||
}
|
||||
arr = append(arr, ConvertLuaTypesToGolang(val))
|
||||
})
|
||||
|
||||
if isArray {
|
||||
arr := make([]any, maxIdx)
|
||||
if isNumeric {
|
||||
for i := 1; i <= maxIdx; i++ {
|
||||
arr[i-1] = ConvertLuaTypesToGolang(tbl.RawGetInt(i))
|
||||
}
|
||||
} else {
|
||||
for i := 1; i <= maxIdx; i++ {
|
||||
arr[i-1] = ConvertLuaTypesToGolang(tbl.RawGetString(strconv.Itoa(i)))
|
||||
}
|
||||
}
|
||||
return arr
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
tbl.ForEach(func(key, val lua.LValue) {
|
||||
result[key.String()] = ConvertLuaTypesToGolang(val)
|
||||
|
||||
Reference in New Issue
Block a user