diff --git a/internal/server/sv1/lua_handler.go b/internal/server/sv1/lua_handler.go index 7963f57..7ac18c2 100644 --- a/internal/server/sv1/lua_handler.go +++ b/internal/server/sv1/lua_handler.go @@ -3,7 +3,6 @@ package sv1 // TODO: make a lua state pool using sync.Pool import ( - "database/sql" "fmt" "io" "log/slog" @@ -13,7 +12,8 @@ import ( "path/filepath" "strconv" "strings" - "sync" + + "golang.org/x/crypto/bcrypt" "github.com/akyaiy/GoSally-mvp/internal/colors" "github.com/akyaiy/GoSally-mvp/internal/server/rpc" @@ -21,314 +21,6 @@ import ( _ "modernc.org/sqlite" ) -type DBConnection struct { - dbPath string - log bool - logger *slog.Logger - writeChan chan *dbWriteRequest - closeChan chan struct{} -} - -type dbWriteRequest struct { - query string - args []interface{} - resCh chan *dbWriteResult -} - -type dbWriteResult struct { - rowsAffected int64 - err error -} - -var dbMutexMap = make(map[string]*sync.RWMutex) -var dbGlobalMutex sync.Mutex - -func getDBMutex(dbPath string) *sync.RWMutex { - dbGlobalMutex.Lock() - defer dbGlobalMutex.Unlock() - - if mtx, ok := dbMutexMap[dbPath]; ok { - return mtx - } - - mtx := &sync.RWMutex{} - dbMutexMap[dbPath] = mtx - return mtx -} - -func loadDBMod(llog *slog.Logger) func(*lua.LState) int { - llog.Debug("import module db-sqlite") - return func(L *lua.LState) int { - dbMod := L.NewTable() - - L.SetField(dbMod, "connect", L.NewFunction(func(L *lua.LState) int { - dbPath := L.CheckString(1) - - logQueries := false - if L.GetTop() >= 2 { - opts := L.CheckTable(2) - if val := opts.RawGetString("log"); val != lua.LNil { - logQueries = lua.LVAsBool(val) - } - } - - conn := &DBConnection{ - dbPath: dbPath, - log: logQueries, - logger: llog, - writeChan: make(chan *dbWriteRequest, 100), - closeChan: make(chan struct{}), - } - - go conn.processWrites() - - ud := L.NewUserData() - ud.Value = conn - L.SetMetatable(ud, L.GetTypeMetatable("gosally_db")) - - L.Push(ud) - return 1 - })) - - mt := L.NewTypeMetatable("gosally_db") - L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{ - "exec": dbExec, - "query": dbQuery, - "close": dbClose, - })) - - L.SetField(dbMod, "__gosally_internal", lua.LString("0")) - L.Push(dbMod) - return 1 - } -} - -func (conn *DBConnection) processWrites() { - for { - select { - case req := <-conn.writeChan: - mtx := getDBMutex(conn.dbPath) - mtx.Lock() - - db, err := sql.Open("sqlite", conn.dbPath+"?_busy_timeout=5000&_journal_mode=WAL&_sync=NORMAL&_cache_size=-10000") - if err == nil { - _, err = db.Exec("PRAGMA journal_mode=WAL;") - if err == nil { - res, execErr := db.Exec(req.query, req.args...) - if execErr == nil { - rows, _ := res.RowsAffected() - req.resCh <- &dbWriteResult{rowsAffected: rows} - } else { - req.resCh <- &dbWriteResult{err: execErr} - } - } - db.Close() - } - - if err != nil { - req.resCh <- &dbWriteResult{err: err} - } - - mtx.Unlock() - case <-conn.closeChan: - return - } - } -} - -func dbExec(L *lua.LState) int { - ud := L.CheckUserData(1) - conn, ok := ud.Value.(*DBConnection) - if !ok { - L.Push(lua.LNil) - L.Push(lua.LString("invalid database connection")) - return 2 - } - - query := L.CheckString(2) - - var args []any - if L.GetTop() >= 3 { - params := L.CheckTable(3) - params.ForEach(func(k lua.LValue, v lua.LValue) { - args = append(args, ConvertLuaTypesToGolang(v)) - }) - } - - if conn.log { - conn.logger.Info("DB Exec", - slog.String("query", query), - slog.Any("params", args)) - } - - resCh := make(chan *dbWriteResult, 1) - conn.writeChan <- &dbWriteRequest{ - query: query, - args: args, - resCh: resCh, - } - - ctx := L.NewTable() - L.SetField(ctx, "done", lua.LBool(false)) - - var result lua.LValue = lua.LNil - var errorMsg lua.LValue = lua.LNil - - L.SetField(ctx, "wait", L.NewFunction(func(lL *lua.LState) int { - res := <-resCh - L.SetField(ctx, "done", lua.LBool(true)) - - if res.err != nil { - errorMsg = lua.LString(res.err.Error()) - result = lua.LNil - } else { - result = lua.LNumber(res.rowsAffected) - errorMsg = lua.LNil - } - - if res.err != nil { - lL.Push(lua.LNil) - lL.Push(lua.LString(res.err.Error())) - return 2 - } - lL.Push(lua.LNumber(res.rowsAffected)) - lL.Push(lua.LNil) - return 2 - })) - - L.SetField(ctx, "check", L.NewFunction(func(lL *lua.LState) int { - select { - case res := <-resCh: - lL.SetField(ctx, "done", lua.LBool(true)) - if res.err != nil { - errorMsg = lua.LString(res.err.Error()) - result = lua.LNil - lL.Push(lua.LNil) - lL.Push(lua.LString(res.err.Error())) - return 2 - } else { - result = lua.LNumber(res.rowsAffected) - errorMsg = lua.LNil - lL.Push(lua.LNumber(res.rowsAffected)) - lL.Push(lua.LNil) - return 2 - } - default: - lL.Push(result) - lL.Push(errorMsg) - return 2 - } - })) - - L.Push(ctx) - L.Push(lua.LNil) - return 2 -} - -func dbQuery(L *lua.LState) int { - ud := L.CheckUserData(1) - conn, ok := ud.Value.(*DBConnection) - if !ok { - L.Push(lua.LNil) - L.Push(lua.LString("invalid database connection")) - return 2 - } - - query := L.CheckString(2) - - var args []any - if L.GetTop() >= 3 { - params := L.CheckTable(3) - params.ForEach(func(k lua.LValue, v lua.LValue) { - args = append(args, ConvertLuaTypesToGolang(v)) - }) - } - - if conn.log { - conn.logger.Info("DB Query", - slog.String("query", query), - slog.Any("params", args)) - } - - mtx := getDBMutex(conn.dbPath) - mtx.RLock() - defer mtx.RUnlock() - - db, err := sql.Open("sqlite", conn.dbPath+"?_busy_timeout=5000&_journal_mode=WAL&_sync=NORMAL&_cache_size=-10000") - if err != nil { - L.Push(lua.LNil) - L.Push(lua.LString(err.Error())) - return 2 - } - defer db.Close() - - rows, err := db.Query(query, args...) - if err != nil { - L.Push(lua.LNil) - L.Push(lua.LString(fmt.Sprintf("query failed: %v", err))) - return 2 - } - defer rows.Close() - - columns, err := rows.Columns() - if err != nil { - L.Push(lua.LNil) - L.Push(lua.LString(fmt.Sprintf("get columns failed: %v", err))) - return 2 - } - - result := L.NewTable() - colCount := len(columns) - values := make([]any, colCount) - valuePtrs := make([]any, colCount) - - for rows.Next() { - for i := range columns { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - L.Push(lua.LNil) - L.Push(lua.LString(fmt.Sprintf("scan failed: %v", err))) - return 2 - } - - rowTable := L.NewTable() - for i, col := range columns { - val := values[i] - if val == nil { - L.SetField(rowTable, col, lua.LNil) - } else { - L.SetField(rowTable, col, ConvertGolangTypesToLua(L, val)) - } - } - result.Append(rowTable) - } - - if err := rows.Err(); err != nil { - L.Push(lua.LNil) - L.Push(lua.LString(fmt.Sprintf("rows iteration failed: %v", err))) - return 2 - } - - L.Push(result) - return 1 -} - -func dbClose(L *lua.LState) int { - ud := L.CheckUserData(1) - conn, ok := ud.Value.(*DBConnection) - if !ok { - L.Push(lua.LFalse) - L.Push(lua.LString("invalid database connection")) - return 2 - } - - close(conn.closeChan) - L.Push(lua.LTrue) - return 1 -} - func addInitiatorHeaders(sid string, req *http.Request, headers http.Header) { clientIP := req.RemoteAddr if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" { @@ -559,10 +251,83 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, return 1 } + loadCryptbcryptMod := func(lL *lua.LState) int { + llog.Debug("import module crypt.bcrypt", slog.String("script", path)) + bcryptMod := lL.NewTable() + + lL.SetField(bcryptMod, "MinCost", lua.LNumber(bcrypt.MinCost)) + lL.SetField(bcryptMod, "MaxCost", lua.LNumber(bcrypt.MaxCost)) + lL.SetField(bcryptMod, "DefaultCost", lua.LNumber(bcrypt.DefaultCost)) + + lL.SetField(bcryptMod, "generate", lL.NewFunction(func(l *lua.LState) int { + password := ConvertLuaTypesToGolang(lL.Get(1)) + passwordStr, ok := password.(string) + if !ok { + lL.Push(lua.LNil) + lL.Push(lua.LString("error: password must be a string")) + return 2 + } + + cost := ConvertLuaTypesToGolang(lL.Get(2)) + costInt := bcrypt.DefaultCost + switch v := cost.(type) { + case int: + costInt = v + case float64: + costInt = int(v) + case nil: + // ok, use DefaultCost + default: + lL.Push(lua.LNil) + lL.Push(lua.LString("error: cost must be an integer")) + return 2 + } + + hashBytes, err := bcrypt.GenerateFromPassword([]byte(passwordStr), costInt) + if err != nil { + lL.Push(lua.LNil) + lL.Push(lua.LString("error: " + err.Error())) + return 2 + } + + lL.Push(lua.LString(string(hashBytes))) + lL.Push(lua.LNil) + return 2 + })) + + lL.SetField(bcryptMod, "compare", lL.NewFunction(func(l *lua.LState) int { + hash := ConvertLuaTypesToGolang(lL.Get(1)) + hashStr, ok := hash.(string) + if !ok { + lL.Push(lua.LString("error: hash must be a string")) + return 1 + } + password := ConvertLuaTypesToGolang(lL.Get(2)) + passwordStr, ok := password.(string) + if !ok { + lL.Push(lua.LString("error: password must be a string")) + return 1 + } + + err := bcrypt.CompareHashAndPassword([]byte(hashStr), []byte(passwordStr)) + if err != nil { + lL.Push(lua.LFalse) + return 1 + } + lL.Push(lua.LTrue) + return 1 + })) + + lL.SetField(bcryptMod, "__gosally_internal", lua.LString(fmt.Sprint(seed))) + lL.Push(bcryptMod) + return 1 + } + L.PreloadModule("internal.session", loadSessionMod) L.PreloadModule("internal.log", loadLogMod) L.PreloadModule("internal.net", loadNetMod) L.PreloadModule("internal.database-sqlite", loadDBMod(llog)) + L.PreloadModule("internal.crypt.bcrypt", loadCryptbcryptMod) llog.Debug("preparing environment") prep := filepath.Join(*h.x.Config.Conf.Node.ComDir, "_prepare.lua")