mirror of
https://github.com/akyaiy/GoSally-mvp.git
synced 2026-01-03 19:52:25 +00:00
Compare commits
44 Commits
5b32698ec5
...
auth-serve
| Author | SHA1 | Date | |
|---|---|---|---|
| 625e5daf71 | |||
| cc27843bb3 | |||
| 20fec82159 | |||
| 055b299ecb | |||
| 17bf207087 | |||
| 7ae8e12dc8 | |||
| 6e36db428a | |||
| 06103a3264 | |||
| c6da55ad65 | |||
| 20a1e3e7bb | |||
| e594d519a7 | |||
| 2ceb236a53 | |||
| 811403a0a2 | |||
| b451f2d3fc | |||
| 5c01eaad6f | |||
| 2b38e179db | |||
| 2889092821 | |||
| 3df3a7b4b5 | |||
| c63f1bd123 | |||
| 095b8559f4 | |||
| 39532f22ea | |||
| 35cebee819 | |||
| 84dfdd6b35 | |||
| e693efe8e7 | |||
| c3dcf24e50 | |||
| 9e7d99e854 | |||
| 7f2783b39a | |||
| c08135309f | |||
| cd9e3ab6c4 | |||
| adaedf195f | |||
| 87694f6654 | |||
| fe628e0f7f | |||
| 3898e2833b | |||
| e4db8505a0 | |||
| 0c25d00171 | |||
| b5a6de0b62 | |||
| 1d3d74846e | |||
| 0141427bfe | |||
| 866946646b | |||
| 251e580e8a | |||
| c734779b69 | |||
| 0923f32b46 | |||
| 1c2c4c1356 | |||
| d3eb483461 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,6 +6,7 @@ tmp/
|
||||
db/
|
||||
|
||||
com/test.lua
|
||||
com/_config.lua
|
||||
|
||||
.vscode
|
||||
Taskfile.yml
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// The cmd package is the main package where all the main hooks and methods are called.
|
||||
// GoSally uses spf13/cobra to organize all the calls.
|
||||
package cmd
|
||||
|
||||
import (
|
||||
@@ -22,6 +24,8 @@ scripts in a given directory. For more information, visit: https://gosally.oblat
|
||||
},
|
||||
}
|
||||
|
||||
// Execute prepares global log, loads cmdline args
|
||||
// and executes rootCmd.Execute()
|
||||
func Execute() {
|
||||
log.SetOutput(os.Stdout)
|
||||
log.SetPrefix(colors.SetBrightBlack(fmt.Sprintf("(%s) ", corestate.StageNotReady)))
|
||||
|
||||
@@ -11,6 +11,7 @@ var runCmd = &cobra.Command{
|
||||
Short: "Run node normally",
|
||||
Long: `
|
||||
"run" starts the node with settings depending on the configuration file`,
|
||||
// hooks.Run essentially the heart of the program
|
||||
Run: hooks.Run,
|
||||
}
|
||||
|
||||
|
||||
77
com/Access/GetMasterAccess.lua
Normal file
77
com/Access/GetMasterAccess.lua
Normal file
@@ -0,0 +1,77 @@
|
||||
local session = require("internal.session")
|
||||
local log = require("internal.log")
|
||||
local jwt = require("internal.crypt.jwt")
|
||||
local bc = require("internal.crypt.bcrypt")
|
||||
local db = require("internal.database.sqlite").connect("db/root.db", {log = true})
|
||||
local sha256 = require("internal.crypt.sha256")
|
||||
|
||||
log.info("Someone at "..session.request.address.." trying to get master access")
|
||||
|
||||
local function close_db()
|
||||
if db then
|
||||
db:close()
|
||||
db = nil
|
||||
end
|
||||
end
|
||||
|
||||
local params = session.request.params.get()
|
||||
|
||||
local function check_missing(arr, p)
|
||||
local is_missing = {}
|
||||
local ok = true
|
||||
for _, key in ipairs(arr) do
|
||||
if p[key] == nil then
|
||||
table.insert(is_missing, key)
|
||||
ok = false
|
||||
end
|
||||
end
|
||||
return ok, is_missing
|
||||
end
|
||||
|
||||
local ok, mp = check_missing({"master_secret", "master_name", "my_key"}, params)
|
||||
if not ok then
|
||||
close_db()
|
||||
session.response.send_error(-32602, "Missing params", mp)
|
||||
end
|
||||
|
||||
if type(params.master_secret) ~= "string" then
|
||||
close_db()
|
||||
session.response.send_error(-32050, "Access denied")
|
||||
end
|
||||
|
||||
if type(params.master_name) ~= "string" then
|
||||
close_db()
|
||||
session.response.send_error(-32050, "Access denied")
|
||||
end
|
||||
|
||||
local master, err = db:query_row("SELECT * FROM master_units WHERE master_name = ?", {params.master_name})
|
||||
|
||||
if not master then
|
||||
log.event("DB query failed:", err)
|
||||
close_db()
|
||||
session.response.send_error(-32050, "Access denied")
|
||||
end
|
||||
|
||||
local ok = bc.compare(master.master_secret, params.master_secret)
|
||||
if not ok then
|
||||
log.warn("Login failed: wrong password")
|
||||
close_db()
|
||||
session.response.send_error(-32050, "Access denied")
|
||||
end
|
||||
|
||||
local token = jwt.encode({
|
||||
secret = require("_config").token(),
|
||||
payload = {
|
||||
session_uuid = session.id,
|
||||
master_id = master.id,
|
||||
key = sha256.sum(params.my_key)
|
||||
},
|
||||
expires_in = 3600
|
||||
})
|
||||
|
||||
close_db()
|
||||
session.response.send({
|
||||
token = token
|
||||
})
|
||||
|
||||
-- G7HgOgl72o7t7u7r
|
||||
@@ -1,77 +0,0 @@
|
||||
---@diagnostic disable: redefined-local
|
||||
local db = require("internal.database-sqlite").connect("db/user-database.db", {log = true})
|
||||
local log = require("internal.log")
|
||||
local session = require("internal.session")
|
||||
local crypt = require("internal.crypt.bcrypt")
|
||||
|
||||
if not session.request.params then
|
||||
session.response.error = {
|
||||
message = "no params provided"
|
||||
}
|
||||
return
|
||||
end
|
||||
|
||||
local params = session.request.params
|
||||
|
||||
if not (params.username and params.email and params.password) then
|
||||
session.response.error = {
|
||||
message = "no username/email/password provided"
|
||||
}
|
||||
return
|
||||
end
|
||||
|
||||
local hashPass = crypt.generate(params.password, crypt.DefaultCost)
|
||||
|
||||
local existing, err = db:query("SELECT 1 FROM users WHERE email = ? OR username = ? LIMIT 1", {
|
||||
params.email,
|
||||
params.username
|
||||
})
|
||||
if err ~= nil then
|
||||
session.response.error = {
|
||||
message = "Database check failed: "..tostring(err)
|
||||
}
|
||||
log.error("Email check failed: "..tostring(err))
|
||||
return
|
||||
end
|
||||
|
||||
if existing and #existing > 0 then
|
||||
session.response.error = {
|
||||
code = -32604,
|
||||
message = "Unit already exists"
|
||||
}
|
||||
return
|
||||
end
|
||||
|
||||
local ctx, err = db:exec(
|
||||
"INSERT INTO users (username, email, password) VALUES (?, ?, ?)",
|
||||
{
|
||||
params.username,
|
||||
params.email,
|
||||
hashPass
|
||||
}
|
||||
)
|
||||
if err ~= nil then
|
||||
session.response.error = {
|
||||
code = -32605,
|
||||
message = "Insert failed: "..tostring(err)
|
||||
}
|
||||
log.error("Insert failed: "..tostring(err))
|
||||
return
|
||||
end
|
||||
|
||||
local res, err = ctx:wait()
|
||||
if err ~= nil then
|
||||
session.response.error = {
|
||||
code = -32606,
|
||||
message = "Insert confirmation failed: "..tostring(err)
|
||||
}
|
||||
log.error("Insert confirmation failed: "..tostring(err))
|
||||
return
|
||||
end
|
||||
|
||||
session.response.result = {
|
||||
rows_affected = res,
|
||||
message = "Unit created successfully"
|
||||
}
|
||||
|
||||
db:close()
|
||||
11
com/Echo.lua
Normal file
11
com/Echo.lua
Normal file
@@ -0,0 +1,11 @@
|
||||
local s = require("internal.session")
|
||||
|
||||
if not s.request.params.__fetched.data then
|
||||
s.response.error = {
|
||||
code = 123,
|
||||
message = "params.data is missing"
|
||||
}
|
||||
return
|
||||
end
|
||||
|
||||
s.response.send(s.request.params.__fetched)
|
||||
10
com/List.lua
10
com/List.lua
@@ -2,7 +2,9 @@
|
||||
|
||||
local session = require("internal.session")
|
||||
|
||||
if session.request.params.about then
|
||||
local params = session.request.params.get()
|
||||
|
||||
if params.about then
|
||||
session.response.result = {
|
||||
description = "Returns a list of available methods",
|
||||
params = {
|
||||
@@ -48,8 +50,8 @@ local function scanDirectory(basePath, targetPath)
|
||||
end
|
||||
|
||||
local basePath = "com"
|
||||
local layer = session.request and session.request.params.layer and session.request.params.layer:gsub(">", "/") or nil
|
||||
local layer = params.layer and params.layer:gsub(">", "/") or nil
|
||||
|
||||
session.response.result = {
|
||||
session.response.send({
|
||||
answer = layer and scanDirectory(basePath, layer) or scanDirectory(basePath, "")
|
||||
}
|
||||
})
|
||||
69
com/Zones/GetZoneInfo.lua
Normal file
69
com/Zones/GetZoneInfo.lua
Normal file
@@ -0,0 +1,69 @@
|
||||
local session = require("internal.session")
|
||||
local log = require("internal.log")
|
||||
local jwt = require("internal.crypt.jwt")
|
||||
local bc = require("internal.crypt.bcrypt")
|
||||
local sha256 = require("internal.crypt.sha256")
|
||||
local dbdriver = require("internal.database.sqlite")
|
||||
|
||||
local db_root = dbdriver.connect("db/root.db", {log = true})
|
||||
local db_zone = nil
|
||||
|
||||
local function close_db()
|
||||
if db_root then
|
||||
db_root:close()
|
||||
db_root = nil
|
||||
end
|
||||
if db_zone then
|
||||
db_zone:close()
|
||||
db_zone = nil
|
||||
end
|
||||
end
|
||||
|
||||
local token = session.request.headers.get("authorization")
|
||||
|
||||
if not token or type(token) ~= "string" then
|
||||
close_db()
|
||||
session.response.send_error(-32050, "Access denied")
|
||||
end
|
||||
|
||||
local prefix = "Bearer "
|
||||
if token:sub(1, #prefix) ~= prefix then
|
||||
close_db()
|
||||
session.response.send_error(-32052, "Invalid Authorization scheme")
|
||||
end
|
||||
|
||||
local access_token = token:sub(#prefix + 1)
|
||||
|
||||
local err, data = jwt.decode(access_token, { secret = require("_config").token() })
|
||||
|
||||
if err or not data then
|
||||
close_db()
|
||||
session.response.send_error(-32053, "Cannod parse JWT", {err})
|
||||
end
|
||||
|
||||
if data.master_id then
|
||||
|
||||
|
||||
end
|
||||
|
||||
local params = session.request.params.get()
|
||||
|
||||
local function check_missing(arr, p)
|
||||
local is_missing = {}
|
||||
local ok = true
|
||||
for _, key in ipairs(arr) do
|
||||
if p[key] == nil then
|
||||
table.insert(is_missing, key)
|
||||
ok = false
|
||||
end
|
||||
end
|
||||
return ok, is_missing
|
||||
end
|
||||
|
||||
local ok, mp = check_missing({"zone_name"}, params)
|
||||
if not ok then
|
||||
close_db()
|
||||
session.response.send_error(-32602, "Missing params", mp)
|
||||
end
|
||||
|
||||
close_db()
|
||||
119
com/_Auth/DeleteUnit.lua
Normal file
119
com/_Auth/DeleteUnit.lua
Normal file
@@ -0,0 +1,119 @@
|
||||
-- com/DeleteUnit.lua
|
||||
|
||||
---@diagnostic disable: redefined-local
|
||||
local db = require("internal.database.sqlite").connect("db/user-database.db", {log = true})
|
||||
local log = require("internal.log")
|
||||
local session = require("internal.session")
|
||||
local crypt = require("internal.crypt.bcrypt")
|
||||
local jwt = require("internal.crypt.jwt")
|
||||
local sha256 = require("internal.crypt.sha256")
|
||||
|
||||
local params = session.request.params.get()
|
||||
local token = session.request.headers.get("authorization")
|
||||
|
||||
local function close_db()
|
||||
if db then
|
||||
db:close()
|
||||
db = nil
|
||||
end
|
||||
end
|
||||
|
||||
local function error_response(message, code, data)
|
||||
session.response.error = {
|
||||
code = code or nil,
|
||||
message = message,
|
||||
data = data or nil
|
||||
}
|
||||
close_db()
|
||||
end
|
||||
|
||||
if not token or type(token) ~= "string" then
|
||||
return error_response("Access denied")
|
||||
end
|
||||
|
||||
local prefix = "Bearer "
|
||||
if token:sub(1, #prefix) ~= prefix then
|
||||
return error_response("Invalid Authorization scheme")
|
||||
end
|
||||
|
||||
local access_token = token:sub(#prefix + 1)
|
||||
|
||||
local err, data = jwt.decode(access_token, { secret = require("_config").token() })
|
||||
|
||||
if err or not data then
|
||||
session.response.error = {
|
||||
message = err
|
||||
}
|
||||
return
|
||||
end
|
||||
|
||||
-- if data.session_uuid ~= session.id then
|
||||
-- return error_response("Access denied")
|
||||
-- end
|
||||
|
||||
-- if data.key ~= sha256.sum(session.request.address .. session.id .. session.request.headers.get("user-agent", "noagent")) then
|
||||
-- return error_response("Access denied")
|
||||
-- end
|
||||
|
||||
if not params then
|
||||
return error_response("no params provided")
|
||||
end
|
||||
|
||||
if not (params.username and params.email and params.password) then
|
||||
return error_response("no username/email/password provided")
|
||||
end
|
||||
|
||||
local existing, err = db:query(
|
||||
"SELECT password FROM users WHERE email = ? AND username = ? AND deleted = 0 LIMIT 1",
|
||||
{
|
||||
params.email,
|
||||
params.username
|
||||
}
|
||||
)
|
||||
|
||||
if err ~= nil then
|
||||
log.error("Password fetch failed: " .. tostring(err))
|
||||
return error_response("Database query failed: " .. tostring(err))
|
||||
end
|
||||
|
||||
if not existing or #existing == 0 then
|
||||
return error_response("Unit not found")
|
||||
end
|
||||
|
||||
local hashed_password = existing[1].password
|
||||
|
||||
local ok = crypt.compare(hashed_password, params.password)
|
||||
if not ok then
|
||||
log.warn("Wrong password attempt for: " .. params.username)
|
||||
return error_response("Invalid password")
|
||||
end
|
||||
|
||||
local ctx, err = db:exec(
|
||||
[[
|
||||
UPDATE users
|
||||
SET deleted = 1,
|
||||
deleted_at = CURRENT_TIMESTAMP
|
||||
WHERE email = ? AND username = ? AND deleted = 0
|
||||
]],
|
||||
{ params.email, params.username }
|
||||
)
|
||||
|
||||
if err ~= nil then
|
||||
log.error("Soft delete failed: " .. tostring(err))
|
||||
return error_response("Soft delete failed: " .. tostring(err))
|
||||
end
|
||||
|
||||
local res, err = ctx:wait()
|
||||
if err ~= nil then
|
||||
log.error("Soft delete confirmation failed: " .. tostring(err))
|
||||
return error_response("Soft delete confirmation failed: " .. tostring(err))
|
||||
end
|
||||
|
||||
session.response.result = {
|
||||
rows_affected = res,
|
||||
message = "Unit soft-deleted successfully"
|
||||
}
|
||||
|
||||
log.info("user " .. params.username .. " soft-deleted successfully")
|
||||
|
||||
close_db()
|
||||
@@ -1,8 +1,15 @@
|
||||
-- com/GetAccess
|
||||
|
||||
---@diagnostic disable: redefined-local
|
||||
local db = require("internal.database-sqlite").connect("db/user-database.db", {log = true})
|
||||
local db = require("internal.database.sqlite").connect("db/user-database.db", {log = true})
|
||||
local log = require("internal.log")
|
||||
local session = require("internal.session")
|
||||
local crypt = require("internal.crypt.bcrypt")
|
||||
local jwt = require("internal.crypt.jwt")
|
||||
local sha256 = require("internal.crypt.sha256")
|
||||
|
||||
local params = session.request.params.get()
|
||||
local secret = require("_config").token()
|
||||
|
||||
local function close_db()
|
||||
if db then
|
||||
@@ -13,25 +20,27 @@ end
|
||||
|
||||
local function error_response(message, code, data)
|
||||
session.response.error = {
|
||||
code = code or -32600,
|
||||
code = code or nil,
|
||||
message = message,
|
||||
data = data or nil
|
||||
}
|
||||
close_db()
|
||||
end
|
||||
|
||||
local params = session.request.params
|
||||
if not params then
|
||||
return error_response("No params provided")
|
||||
end
|
||||
|
||||
if not (params.username and params.email and params.password) then
|
||||
return error_response("Missing username, email or password", -32602)
|
||||
return error_response("Missing username, email or password")
|
||||
end
|
||||
|
||||
local unit, err = db:query(
|
||||
"SELECT id, username, email, password, created_at FROM users WHERE email = ? AND username = ? LIMIT 1",
|
||||
{params.email, params.username}
|
||||
"SELECT id, username, email, password, created_at FROM users WHERE email = ? AND username = ? AND deleted = 0 LIMIT 1",
|
||||
{
|
||||
params.email,
|
||||
params.username
|
||||
}
|
||||
)
|
||||
|
||||
if err then
|
||||
@@ -40,7 +49,7 @@ if err then
|
||||
end
|
||||
|
||||
if not unit or #unit == 0 then
|
||||
return error_response("Unit not found", -32604)
|
||||
return error_response("Unit not found")
|
||||
end
|
||||
|
||||
unit = unit[1]
|
||||
@@ -48,16 +57,20 @@ unit = unit[1]
|
||||
local ok = crypt.compare(unit.password, params.password)
|
||||
if not ok then
|
||||
log.warn("Login failed: wrong password for " .. params.username)
|
||||
return error_response("Invalid password", -32605)
|
||||
return error_response("Invalid password")
|
||||
end
|
||||
|
||||
local token = jwt.encode({
|
||||
secret = secret,
|
||||
payload = { session_uuid = session.id,
|
||||
admin_user = params.username,
|
||||
key = sha256.sum(session.request.address .. session.id .. session.request.headers.get("user-agent", "noagent"))
|
||||
},
|
||||
expires_in = 3600
|
||||
})
|
||||
|
||||
session.response.result = {
|
||||
user = {
|
||||
id = unit.id,
|
||||
username = unit.username,
|
||||
email = unit.email,
|
||||
created_at = unit.created_at
|
||||
}
|
||||
access_token = token
|
||||
}
|
||||
|
||||
close_db()
|
||||
109
com/_Auth/PutNewUnit.lua
Normal file
109
com/_Auth/PutNewUnit.lua
Normal file
@@ -0,0 +1,109 @@
|
||||
-- com/PutNewUnit.lua
|
||||
|
||||
---@diagnostic disable: redefined-local
|
||||
local db = require("internal.database.sqlite").connect("db/user-database.db", {log = true})
|
||||
local log = require("internal.log")
|
||||
local session = require("internal.session")
|
||||
local crypt = require("internal.crypt.bcrypt")
|
||||
local jwt = require("internal.crypt.jwt")
|
||||
local sha256 = require("internal.crypt.sha256")
|
||||
|
||||
local params = session.request.params.get()
|
||||
local token = session.request.headers.get("authorization")
|
||||
|
||||
local function close_db()
|
||||
if db then
|
||||
db:close()
|
||||
db = nil
|
||||
end
|
||||
end
|
||||
|
||||
local function error_response(message, code, data)
|
||||
session.response.error = {
|
||||
code = code or nil,
|
||||
message = message,
|
||||
data = data or nil
|
||||
}
|
||||
close_db()
|
||||
end
|
||||
|
||||
if not token or type(token) ~= "string" then
|
||||
return error_response("Access denied")
|
||||
end
|
||||
|
||||
local prefix = "Bearer "
|
||||
if token:sub(1, #prefix) ~= prefix then
|
||||
return error_response("Invalid Authorization scheme")
|
||||
end
|
||||
|
||||
local access_token = token:sub(#prefix + 1)
|
||||
|
||||
local err, data = jwt.decode(access_token, { secret = require("_config").token() })
|
||||
|
||||
if err or not data then
|
||||
session.response.error = {
|
||||
message = err
|
||||
}
|
||||
return
|
||||
end
|
||||
|
||||
if data.session_uuid ~= session.id then
|
||||
return error_response("Access denied")
|
||||
end
|
||||
|
||||
if data.key ~= sha256.sum(session.request.address .. session.id .. session.request.headers.get("user-agent", "noagent")) then
|
||||
return error_response("Access denied")
|
||||
end
|
||||
|
||||
if not params then
|
||||
return error_response("no params provided")
|
||||
end
|
||||
|
||||
if not (params.username and params.email and params.password) then
|
||||
return error_response("no username/email/password provided")
|
||||
end
|
||||
|
||||
local hashPass = crypt.generate(params.password, crypt.DefaultCost)
|
||||
|
||||
local existing, err = db:query("SELECT 1 FROM users WHERE deleted = 0 AND (email = ? OR username = ?) LIMIT 1", {
|
||||
params.email,
|
||||
params.username
|
||||
})
|
||||
|
||||
if err ~= nil then
|
||||
log.error("Email check failed: "..tostring(err))
|
||||
return error_response("Database check failed: "..tostring(err))
|
||||
end
|
||||
|
||||
if existing and #existing > 0 then
|
||||
return error_response("Unit already exists")
|
||||
end
|
||||
|
||||
local ctx, err = db:exec(
|
||||
"INSERT INTO users (username, email, password, first_name, last_name, phone_number) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
{
|
||||
params.username,
|
||||
params.email,
|
||||
hashPass,
|
||||
params.first_name or "",
|
||||
params.last_name or "",
|
||||
params.phone_number or ""
|
||||
}
|
||||
)
|
||||
if err ~= nil then
|
||||
log.error("Insert failed: "..tostring(err))
|
||||
return error_response("Insert failed: "..tostring(err))
|
||||
end
|
||||
|
||||
local res, err = ctx:wait()
|
||||
if err ~= nil then
|
||||
log.error("Insert confirmation failed: "..tostring(err))
|
||||
return error_response("Insert confirmation failed: "..tostring(err))
|
||||
end
|
||||
|
||||
session.response.result = {
|
||||
rows_affected = res,
|
||||
message = "Unit created successfully"
|
||||
}
|
||||
|
||||
close_db()
|
||||
@@ -2,15 +2,3 @@
|
||||
package.path = package.path .. ";/usr/lib64/lua/5.1/?.lua;/usr/local/share/lua/5.1/?.lua" .. ";./com/?.lua;"
|
||||
package.cpath = package.cpath .. ";/usr/lib64/lua/5.1/?.so;/usr/local/lib/lua/5.1/?.so"
|
||||
|
||||
print = function() end
|
||||
io.write = function(...) end
|
||||
io.stdout = function() return nil end
|
||||
io.stderr = function() return nil end
|
||||
io.read = function(...) return nil end
|
||||
|
||||
---@type table<string, any>
|
||||
Status = {
|
||||
ok = "ok",
|
||||
error = "error",
|
||||
invalid = "invalid",
|
||||
}
|
||||
|
||||
1
go.mod
1
go.mod
@@ -19,6 +19,7 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -13,6 +13,10 @@ github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE=
|
||||
github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
// The config composer needs to be in the global scope
|
||||
var Compositor *config.Compositor = config.NewCompositor()
|
||||
|
||||
func InitGlobalLoggerHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
@@ -65,7 +66,8 @@ func InitConfigLoadHook(_ context.Context, cs *corestate.CoreState, x *app.AppX)
|
||||
}
|
||||
}
|
||||
|
||||
func InitUUUDHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
// The hook reads or prepares a persistent uuid for the node
|
||||
func InitUUIDHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
uuid32, err := corestate.GetNodeUUID(filepath.Join(cs.MetaDir, "uuid"))
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
if err := corestate.SetNodeUUID(filepath.Join(cs.NodePath, cs.MetaDir, cs.UUID32DirName)); err != nil {
|
||||
@@ -80,8 +82,11 @@ func InitUUUDHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
x.Log.Fatalf("uuid load error: %s", err)
|
||||
}
|
||||
cs.UUID32 = uuid32
|
||||
corestate.NODE_UUID = uuid32
|
||||
}
|
||||
|
||||
// The hook is responsible for checking the initialization stage
|
||||
// and restarting in some cases
|
||||
func InitRuntimeHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
if *x.Config.Env.ParentStagePID != os.Getpid() {
|
||||
// still pre-init stage
|
||||
@@ -133,6 +138,8 @@ func InitRuntimeHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
}
|
||||
|
||||
// post-init stage
|
||||
// The hook creates a run.lock file, which contains information
|
||||
// about the process and the node, in the runtime directory.
|
||||
func InitRunlockHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
NodeApp.Fallback(func(ctx context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
x.Log.Println("Cleaning up...")
|
||||
@@ -185,6 +192,8 @@ func InitRunlockHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
}
|
||||
}
|
||||
|
||||
// The hook reads the configuration and replaces special expressions
|
||||
// (%tmp% and so on) in string fields with the required data.
|
||||
func InitConfigReplHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
if !slices.Contains(*x.Config.Conf.DisableWarnings, "--WNonStdTmpDir") && os.TempDir() != "/tmp" {
|
||||
x.Log.Printf("%s: %s", colors.PrintWarn(), "Non-standard value specified for temporary directory")
|
||||
@@ -209,6 +218,8 @@ func InitConfigReplHook(_ context.Context, cs *corestate.CoreState, x *app.AppX)
|
||||
}
|
||||
}
|
||||
|
||||
// The hook is responsible for outputting the
|
||||
// final config and asking for confirmation.
|
||||
func InitConfigPrintHook(ctx context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
if *x.Config.Conf.Node.ShowConfig {
|
||||
fmt.Printf("Configuration from %s:\n", x.Config.CMDLine.Run.ConfigPath)
|
||||
@@ -239,6 +250,8 @@ func InitSLogHook(_ context.Context, cs *corestate.CoreState, x *app.AppX) {
|
||||
*x.SLog = *newSlog
|
||||
}
|
||||
|
||||
// The method goes through the entire config structure through
|
||||
// reflection and replaces string fields with the required ones.
|
||||
func processConfig(conf any, replacements map[string]any) error {
|
||||
val := reflect.ValueOf(conf)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
|
||||
@@ -33,7 +33,7 @@ var NodeApp = app.New()
|
||||
func Run(cmd *cobra.Command, args []string) {
|
||||
NodeApp.InitialHooks(
|
||||
InitGlobalLoggerHook, InitCorestateHook, InitConfigLoadHook,
|
||||
InitUUUDHook, InitRuntimeHook, InitRunlockHook,
|
||||
InitUUIDHook, InitRuntimeHook, InitRunlockHook,
|
||||
InitConfigReplHook, InitConfigPrintHook, InitSLogHook,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package corestate
|
||||
|
||||
var NODE_UUID string
|
||||
|
||||
type Stage string
|
||||
|
||||
const (
|
||||
|
||||
1
internal/engine/lua/handler.go
Normal file
1
internal/engine/lua/handler.go
Normal file
@@ -0,0 +1 @@
|
||||
package lua
|
||||
35
internal/engine/lua/pool.go
Normal file
35
internal/engine/lua/pool.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package lua
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
type LuaPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func NewLuaPool() *LuaPool {
|
||||
return &LuaPool{
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
L := lua.NewState()
|
||||
|
||||
return L
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (lp *LuaPool) Get() *lua.LState {
|
||||
return lp.pool.Get().(*lua.LState)
|
||||
}
|
||||
|
||||
func (lp *LuaPool) Put(L *lua.LState) {
|
||||
L.Close()
|
||||
|
||||
newL := lua.NewState()
|
||||
|
||||
lp.pool.Put(newL)
|
||||
}
|
||||
26
internal/engine/lua/types.go
Normal file
26
internal/engine/lua/types.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package lua
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/core/corestate"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/engine/app"
|
||||
"github.com/akyaiy/GoSally-mvp/internal/server/rpc"
|
||||
)
|
||||
|
||||
type LuaEngineDeps struct {
|
||||
HttpRequest *http.Request
|
||||
JSONRPCRequest *rpc.RPCRequest
|
||||
SessionUUID string
|
||||
ScriptPath string
|
||||
}
|
||||
|
||||
type LuaEngineContract interface {
|
||||
Handle(deps *LuaEngineDeps) *rpc.RPCResponse
|
||||
|
||||
}
|
||||
|
||||
type LuaEngine struct {
|
||||
x *app.AppX
|
||||
cs *corestate.CoreState
|
||||
}
|
||||
@@ -20,18 +20,14 @@ func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) {
|
||||
sessionUUID := r.Header.Get("X-Session-UUID")
|
||||
if sessionUUID == "" {
|
||||
sessionUUID = uuid.New().String()
|
||||
|
||||
}
|
||||
gs.x.SLog.Debug("new request", slog.String("session-uuid", sessionUUID), slog.Group("connection", slog.String("ip", r.RemoteAddr)))
|
||||
|
||||
w.Header().Set("X-Session-UUID", sessionUUID)
|
||||
if !gs.sm.Add(sessionUUID) {
|
||||
gs.x.SLog.Debug("session is busy", slog.String("session-uuid", sessionUUID))
|
||||
rpc.WriteError(w, &rpc.RPCResponse{
|
||||
Error: map[string]any{
|
||||
"code": rpc.ErrSessionIsBusy,
|
||||
"message": rpc.ErrSessionIsBusyS,
|
||||
},
|
||||
})
|
||||
rpc.WriteError(w, rpc.NewError(rpc.ErrSessionIsBusy, rpc.ErrSessionIsBusyS, nil, nil))
|
||||
return
|
||||
}
|
||||
defer gs.sm.Delete(sessionUUID)
|
||||
@@ -40,14 +36,7 @@ func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
gs.x.SLog.Debug("failed to read body", slog.String("err", err.Error()))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
rpc.WriteError(w, &rpc.RPCResponse{
|
||||
JSONRPC: rpc.JSONRPCVersion,
|
||||
ID: nil,
|
||||
Error: map[string]any{
|
||||
"code": rpc.ErrInternalError,
|
||||
"message": rpc.ErrInternalErrorS,
|
||||
},
|
||||
})
|
||||
rpc.WriteError(w, rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, nil))
|
||||
gs.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrInternalErrorS))
|
||||
return
|
||||
}
|
||||
@@ -60,14 +49,7 @@ func (gs *GatewayServer) Handle(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.Unmarshal(body, &single); err != nil {
|
||||
gs.x.SLog.Debug("failed to parse json", slog.String("err", err.Error()))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
rpc.WriteError(w, &rpc.RPCResponse{
|
||||
JSONRPC: rpc.JSONRPCVersion,
|
||||
ID: nil,
|
||||
Error: map[string]any{
|
||||
"code": rpc.ErrParseError,
|
||||
"message": rpc.ErrParseErrorS,
|
||||
},
|
||||
})
|
||||
rpc.WriteError(w, rpc.NewError(rpc.ErrParseError, rpc.ErrParseErrorS, nil, nil))
|
||||
gs.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrParseErrorS))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -13,8 +13,16 @@ type RPCRequest struct {
|
||||
type RPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID *json.RawMessage `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
Result any `json:"result,omitzero"`
|
||||
Error any `json:"error,omitzero"`
|
||||
Data *RPCData `json:"data,omitzero"`
|
||||
}
|
||||
|
||||
type RPCData struct {
|
||||
ResponsibleNode string `json:"responsible-node,omitempty"`
|
||||
Salt string `json:"salt,omitempty"`
|
||||
Checksum string `json:"checksum-md5,omitempty"`
|
||||
NewSessionUUID string `json:"new-session-uuid,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -1,26 +1,50 @@
|
||||
package rpc
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/akyaiy/GoSally-mvp/internal/core/corestate"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func generateChecksum(result any) string {
|
||||
if result == nil {
|
||||
return ""
|
||||
}
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%x", md5.Sum(data))
|
||||
}
|
||||
|
||||
func generateSalt() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
func GetData(data any) *RPCData {
|
||||
return &RPCData{
|
||||
Salt: generateSalt(),
|
||||
ResponsibleNode: corestate.NODE_UUID,
|
||||
Checksum: generateChecksum(data),
|
||||
}
|
||||
}
|
||||
|
||||
func NewError(code int, message string, data any, id *json.RawMessage) *RPCResponse {
|
||||
if data != nil {
|
||||
return &RPCResponse{
|
||||
JSONRPC: JSONRPCVersion,
|
||||
ID: id,
|
||||
Error: map[string]any{
|
||||
Error := make(map[string]any)
|
||||
Error = map[string]any{
|
||||
"code": code,
|
||||
"message": message,
|
||||
"data": data,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &RPCResponse{
|
||||
JSONRPC: JSONRPCVersion,
|
||||
ID: id,
|
||||
Error: map[string]any{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
Error: Error,
|
||||
Data: GetData(Error),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,5 +53,6 @@ func NewResponse(result any, id *json.RawMessage) *RPCResponse {
|
||||
JSONRPC: JSONRPCVersion,
|
||||
ID: id,
|
||||
Result: result,
|
||||
Data: GetData(result),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func getDBMutex(dbPath string) *sync.RWMutex {
|
||||
return mtx
|
||||
}
|
||||
|
||||
func loadDBMod(llog *slog.Logger) func(*lua.LState) int {
|
||||
func loadDBMod(llog *slog.Logger, sid string) func(*lua.LState) int {
|
||||
return func(L *lua.LState) int {
|
||||
llog.Debug("import module db-sqlite")
|
||||
dbMod := L.NewTable()
|
||||
@@ -82,10 +82,11 @@ func loadDBMod(llog *slog.Logger) func(*lua.LState) int {
|
||||
L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
|
||||
"exec": dbExec,
|
||||
"query": dbQuery,
|
||||
"query_row": dbQueryRow,
|
||||
"close": dbClose,
|
||||
}))
|
||||
|
||||
L.SetField(dbMod, "__gosally_internal", lua.LString("0"))
|
||||
L.SetField(dbMod, "__seed", lua.LString(sid))
|
||||
L.Push(dbMod)
|
||||
return 1
|
||||
}
|
||||
@@ -162,7 +163,7 @@ func dbExec(L *lua.LState) int {
|
||||
var result lua.LValue = lua.LNil
|
||||
var errorMsg lua.LValue = lua.LNil
|
||||
|
||||
L.SetField(ctx, "wait", L.NewFunction(func(lL *lua.LState) int {
|
||||
L.SetField(ctx, "wait", L.NewFunction(func(L *lua.LState) int {
|
||||
res := <-resCh
|
||||
L.SetField(ctx, "done", lua.LBool(true))
|
||||
|
||||
@@ -175,35 +176,35 @@ func dbExec(L *lua.LState) int {
|
||||
}
|
||||
|
||||
if res.err != nil {
|
||||
lL.Push(lua.LNil)
|
||||
lL.Push(lua.LString(res.err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(res.err.Error()))
|
||||
return 2
|
||||
}
|
||||
lL.Push(lua.LNumber(res.rowsAffected))
|
||||
lL.Push(lua.LNil)
|
||||
L.Push(lua.LNumber(res.rowsAffected))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}))
|
||||
|
||||
L.SetField(ctx, "check", L.NewFunction(func(lL *lua.LState) int {
|
||||
L.SetField(ctx, "check", L.NewFunction(func(L *lua.LState) int {
|
||||
select {
|
||||
case res := <-resCh:
|
||||
lL.SetField(ctx, "done", lua.LBool(true))
|
||||
L.SetField(ctx, "done", lua.LBool(true))
|
||||
if res.err != nil {
|
||||
errorMsg = lua.LString(res.err.Error())
|
||||
result = lua.LNil
|
||||
lL.Push(lua.LNil)
|
||||
lL.Push(lua.LString(res.err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(res.err.Error()))
|
||||
return 2
|
||||
} else {
|
||||
result = lua.LNumber(res.rowsAffected)
|
||||
errorMsg = lua.LNil
|
||||
lL.Push(lua.LNumber(res.rowsAffected))
|
||||
lL.Push(lua.LNil)
|
||||
L.Push(lua.LNumber(res.rowsAffected))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}
|
||||
default:
|
||||
lL.Push(result)
|
||||
lL.Push(errorMsg)
|
||||
L.Push(result)
|
||||
L.Push(errorMsg)
|
||||
return 2
|
||||
}
|
||||
}))
|
||||
@@ -213,6 +214,102 @@ func dbExec(L *lua.LState) int {
|
||||
return 2
|
||||
}
|
||||
|
||||
func dbQueryRow(L *lua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
conn, ok := ud.Value.(*DBConnection)
|
||||
if !ok {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("invalid database connection"))
|
||||
return 2
|
||||
}
|
||||
|
||||
query := L.CheckString(2)
|
||||
|
||||
var args []any
|
||||
if L.GetTop() >= 3 {
|
||||
params := L.CheckTable(3)
|
||||
params.ForEach(func(k lua.LValue, v lua.LValue) {
|
||||
args = append(args, ConvertLuaTypesToGolang(v))
|
||||
})
|
||||
}
|
||||
|
||||
if conn.log {
|
||||
conn.logger.Info("DB QueryRow",
|
||||
slog.String("query", query),
|
||||
slog.Any("params", args))
|
||||
}
|
||||
|
||||
mtx := getDBMutex(conn.dbPath)
|
||||
mtx.RLock()
|
||||
defer mtx.RUnlock()
|
||||
|
||||
db, err := sql.Open("sqlite", conn.dbPath+"?_busy_timeout=5000&_journal_mode=WAL&_sync=NORMAL&_cache_size=-10000")
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
row := db.QueryRow(query, args...)
|
||||
|
||||
columns := []string{}
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("prepare failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err := stmt.Query(args...)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("query failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
defer rows.Close()
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("get columns failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
for _, c := range cols {
|
||||
columns = append(columns, c)
|
||||
}
|
||||
|
||||
colCount := len(columns)
|
||||
values := make([]any, colCount)
|
||||
valuePtrs := make([]any, colCount)
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
err = row.Scan(valuePtrs...)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
L.Push(lua.LNil)
|
||||
return 1
|
||||
}
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(fmt.Sprintf("scan failed: %v", err)))
|
||||
return 2
|
||||
}
|
||||
|
||||
rowTable := L.NewTable()
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if val == nil {
|
||||
L.SetField(rowTable, col, lua.LNil)
|
||||
} else {
|
||||
L.SetField(rowTable, col, ConvertGolangTypesToLua(L, val))
|
||||
}
|
||||
}
|
||||
|
||||
L.Push(rowTable)
|
||||
return 1
|
||||
}
|
||||
|
||||
func dbQuery(L *lua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
conn, ok := ud.Value.(*DBConnection)
|
||||
|
||||
@@ -24,6 +24,16 @@ func (h *HandlerV1) Handle(_ context.Context, sid string, r *http.Request, req *
|
||||
return rpc.NewError(rpc.ErrMethodNotFound, rpc.ErrMethodNotFoundS, nil, req.ID)
|
||||
}
|
||||
}
|
||||
|
||||
switch req.Params.(type) {
|
||||
case map[string]any, []any, nil:
|
||||
return h.handleLUA(sid, r, req, method)
|
||||
default:
|
||||
// JSON-RPC 2.0 Specification:
|
||||
// https://www.jsonrpc.org/specification#parameter_structures
|
||||
//
|
||||
// "params" MUST be either an *array* or an *object* if included.
|
||||
// Any other type (e.g., a number, string, or boolean) is INVALID.
|
||||
h.x.SLog.Info("invalid request received", slog.String("issue", rpc.ErrInvalidParamsS))
|
||||
return rpc.NewError(rpc.ErrInvalidParams, rpc.ErrInvalidParamsS, nil, req.ID)
|
||||
}
|
||||
}
|
||||
|
||||
86
internal/server/sv1/jwt.go
Normal file
86
internal/server/sv1/jwt.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package sv1
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
func loadJWTMod(llog *slog.Logger, sid string) func(*lua.LState) int {
|
||||
return func(L *lua.LState) int {
|
||||
llog.Debug("import module jwt")
|
||||
jwtMod := L.NewTable()
|
||||
|
||||
L.SetField(jwtMod, "encode", L.NewFunction(jwtEncode))
|
||||
L.SetField(jwtMod, "decode", L.NewFunction(jwtDecode))
|
||||
|
||||
L.SetField(jwtMod, "__seed", lua.LString(sid))
|
||||
L.Push(jwtMod)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func jwtEncode(L *lua.LState) int {
|
||||
payloadTbl := L.CheckTable(1)
|
||||
secret := L.GetField(payloadTbl, "secret").String()
|
||||
payload := L.GetField(payloadTbl, "payload").(*lua.LTable)
|
||||
expiresIn := L.GetField(payloadTbl, "expires_in")
|
||||
expDuration := time.Hour
|
||||
|
||||
if expiresIn.Type() == lua.LTNumber {
|
||||
floatVal := ConvertLuaTypesToGolang(expiresIn).(float64)
|
||||
expDuration = time.Duration(floatVal) * time.Second
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{}
|
||||
payload.ForEach(func(key, value lua.LValue) {
|
||||
claims[key.String()] = ConvertLuaTypesToGolang(value)
|
||||
})
|
||||
claims["iat"] = time.Now().Unix()
|
||||
claims["exp"] = time.Now().Add(expDuration).Unix()
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
L.Push(lua.LString(signedToken))
|
||||
return 1
|
||||
}
|
||||
|
||||
func jwtDecode(L *lua.LState) int {
|
||||
tokenString := L.CheckString(1)
|
||||
optsTbl := L.OptTable(2, L.NewTable())
|
||||
secret := L.GetField(optsTbl, "secret").String()
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
L.Push(lua.LString("Invalid token: " + err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
L.Push(lua.LString("Invalid claims"))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}
|
||||
|
||||
luaTable := L.NewTable()
|
||||
for k, v := range claims {
|
||||
luaTable.RawSetString(k, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
|
||||
L.Push(lua.LNil)
|
||||
L.Push(luaTable)
|
||||
return 2
|
||||
}
|
||||
@@ -3,6 +3,8 @@ package sv1
|
||||
// TODO: make a lua state pool using sync.Pool
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -41,43 +43,256 @@ func addInitiatorHeaders(sid string, req *http.Request, headers http.Header) {
|
||||
// I will be only glad.
|
||||
// TODO: make this huge function more harmonious by dividing responsibilities
|
||||
func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest, path string) *rpc.RPCResponse {
|
||||
var __exit = -1
|
||||
|
||||
llog := h.x.SLog.With(slog.String("session-id", sid))
|
||||
llog.Debug("handling LUA")
|
||||
L := lua.NewState()
|
||||
defer L.Close()
|
||||
|
||||
osMod := L.GetGlobal("os").(*lua.LTable)
|
||||
L.SetField(osMod, "exit", lua.LNil)
|
||||
|
||||
ioMod := L.GetGlobal("io").(*lua.LTable)
|
||||
for _, k := range []string{"write", "output", "flush", "read", "input"} {
|
||||
ioMod.RawSetString(k, lua.LNil)
|
||||
}
|
||||
L.Env.RawSetString("print", lua.LNil)
|
||||
|
||||
for _, name := range []string{"stdout", "stderr", "stdin"} {
|
||||
stream := ioMod.RawGetString(name)
|
||||
if t, ok := stream.(*lua.LUserData); ok {
|
||||
t.Metatable = lua.LNil
|
||||
}
|
||||
}
|
||||
|
||||
seed := rand.Int()
|
||||
|
||||
loadSessionMod := func(lL *lua.LState) int {
|
||||
loadSessionMod := func(L *lua.LState) int {
|
||||
llog.Debug("import module session", slog.String("script", path))
|
||||
sessionMod := lL.NewTable()
|
||||
inTable := lL.NewTable()
|
||||
paramsTable := lL.NewTable()
|
||||
if fetchedParams, ok := req.Params.(map[string]any); ok {
|
||||
for k, v := range fetchedParams {
|
||||
lL.SetField(paramsTable, k, ConvertGolangTypesToLua(lL, v))
|
||||
sessionMod := L.NewTable()
|
||||
inTable := L.NewTable()
|
||||
paramsTable := L.NewTable()
|
||||
headersTable := L.NewTable()
|
||||
|
||||
fetchedHeadersTable := L.NewTable()
|
||||
for k, v := range r.Header {
|
||||
L.SetField(fetchedHeadersTable, k, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
|
||||
headersGetter := L.NewFunction(func(L *lua.LState) int {
|
||||
path := L.OptString(1, "")
|
||||
def := L.Get(2)
|
||||
|
||||
get := func(path string) lua.LValue {
|
||||
if path == "" {
|
||||
return fetchedHeadersTable
|
||||
}
|
||||
fetched := r.Header.Get(path)
|
||||
if fetched == "" {
|
||||
return lua.LNil
|
||||
}
|
||||
return lua.LString(fetched)
|
||||
}
|
||||
val := get(path)
|
||||
if val == lua.LNil && def != lua.LNil {
|
||||
L.Push(def)
|
||||
} else {
|
||||
L.Push(val)
|
||||
}
|
||||
return 1
|
||||
})
|
||||
|
||||
L.SetField(headersTable, "__fetched", fetchedHeadersTable)
|
||||
|
||||
L.SetField(headersTable, "get", headersGetter)
|
||||
L.SetField(inTable, "headers", headersTable)
|
||||
|
||||
fetchedParamsTable := L.NewTable()
|
||||
switch params := req.Params.(type) {
|
||||
case map[string]any:
|
||||
for k, v := range params {
|
||||
L.SetField(fetchedParamsTable, k, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
case []any:
|
||||
for i, v := range params {
|
||||
fetchedParamsTable.RawSetInt(i+1, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
}
|
||||
lL.SetField(inTable, "params", paramsTable)
|
||||
|
||||
outTable := lL.NewTable()
|
||||
resultTable := lL.NewTable()
|
||||
lL.SetField(outTable, "result", resultTable)
|
||||
paramsGetter := L.NewFunction(func(L *lua.LState) int {
|
||||
path := L.OptString(1, "")
|
||||
def := L.Get(2)
|
||||
|
||||
lL.SetField(inTable, "address", lua.LString(r.RemoteAddr))
|
||||
lL.SetField(sessionMod, "request", inTable)
|
||||
lL.SetField(sessionMod, "response", outTable)
|
||||
get := func(tbl *lua.LTable, path string) lua.LValue {
|
||||
if path == "" {
|
||||
return tbl
|
||||
}
|
||||
current := tbl
|
||||
parts := strings.Split(path, ".")
|
||||
size := len(parts)
|
||||
for index, key := range parts {
|
||||
val := current.RawGetString(key)
|
||||
if tblVal, ok := val.(*lua.LTable); ok {
|
||||
current = tblVal
|
||||
} else {
|
||||
if index == size-1 {
|
||||
return val
|
||||
}
|
||||
return lua.LNil
|
||||
}
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
lL.SetField(sessionMod, "id", lua.LString(sid))
|
||||
paramsTbl := L.GetField(paramsTable, "__fetched") //
|
||||
val := get(paramsTbl.(*lua.LTable), path) //
|
||||
if val == lua.LNil && def != lua.LNil {
|
||||
L.Push(def)
|
||||
} else {
|
||||
L.Push(val)
|
||||
}
|
||||
return 1
|
||||
})
|
||||
L.SetField(paramsTable, "__fetched", fetchedParamsTable)
|
||||
|
||||
lL.SetField(sessionMod, "__gosally_internal", lua.LString(fmt.Sprint(seed)))
|
||||
lL.Push(sessionMod)
|
||||
L.SetField(paramsTable, "get", paramsGetter)
|
||||
L.SetField(inTable, "params", paramsTable)
|
||||
|
||||
outTable := L.NewTable()
|
||||
scriptDataTable := L.NewTable()
|
||||
L.SetField(outTable, "__script_data", scriptDataTable)
|
||||
|
||||
L.SetField(inTable, "address", lua.LString(r.RemoteAddr))
|
||||
|
||||
L.SetField(sessionMod, "throw_error", L.NewFunction(func(L *lua.LState) int {
|
||||
arg := L.Get(1)
|
||||
var msg string
|
||||
switch arg.Type() {
|
||||
case lua.LTString:
|
||||
msg = arg.String()
|
||||
case lua.LTNumber:
|
||||
msg = strconv.FormatFloat(float64(arg.(lua.LNumber)), 'f', -1, 64)
|
||||
default:
|
||||
L.ArgError(1, "expected string or number")
|
||||
return 0
|
||||
}
|
||||
|
||||
L.RaiseError("%s", msg)
|
||||
return 0
|
||||
}))
|
||||
|
||||
resTable := L.NewTable()
|
||||
L.SetField(scriptDataTable, "result", resTable)
|
||||
L.SetField(outTable, "send", L.NewFunction(func(L *lua.LState) int {
|
||||
res := L.Get(1)
|
||||
if res == lua.LNil {
|
||||
__exit = 0
|
||||
L.RaiseError("__successfull")
|
||||
return 0
|
||||
}
|
||||
|
||||
resFTable := scriptDataTable.RawGetString("result")
|
||||
if resPTable, ok := res.(*lua.LTable); ok {
|
||||
resPTable.ForEach(func(key, value lua.LValue) {
|
||||
L.SetField(resFTable, key.String(), value)
|
||||
})
|
||||
} else {
|
||||
L.SetField(scriptDataTable, "result", res)
|
||||
}
|
||||
|
||||
__exit = 0
|
||||
L.RaiseError("__successfull")
|
||||
return 0
|
||||
}))
|
||||
|
||||
L.SetField(outTable, "set", L.NewFunction(func(L *lua.LState) int {
|
||||
res := L.Get(1)
|
||||
if res == lua.LNil {
|
||||
return 0
|
||||
}
|
||||
|
||||
resFTable := scriptDataTable.RawGetString("result")
|
||||
if resPTable, ok := res.(*lua.LTable); ok {
|
||||
resPTable.ForEach(func(key, value lua.LValue) {
|
||||
L.SetField(resFTable, key.String(), value)
|
||||
})
|
||||
} else {
|
||||
L.SetField(scriptDataTable, "result", res)
|
||||
}
|
||||
return 0
|
||||
}))
|
||||
|
||||
errTable := L.NewTable()
|
||||
L.SetField(scriptDataTable, "error", errTable)
|
||||
L.SetField(outTable, "send_error", L.NewFunction(func(L *lua.LState) int {
|
||||
var params [3]lua.LValue
|
||||
for i := range 3 {
|
||||
params[i] = L.Get(i + 1)
|
||||
}
|
||||
if errTable, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok {
|
||||
for _, v := range params {
|
||||
switch v.Type() {
|
||||
case lua.LTNumber:
|
||||
if n, ok := v.(lua.LNumber); ok {
|
||||
L.SetField(errTable, "code", n)
|
||||
}
|
||||
case lua.LTString:
|
||||
if s, ok := v.(lua.LString); ok {
|
||||
L.SetField(errTable, "message", s)
|
||||
}
|
||||
case lua.LTTable:
|
||||
if tbl, ok := v.(*lua.LTable); ok {
|
||||
L.SetField(errTable, "data", tbl)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__exit = 1
|
||||
L.RaiseError("__unsuccessfull")
|
||||
return 0
|
||||
}))
|
||||
|
||||
L.SetField(outTable, "set_error", L.NewFunction(func(L *lua.LState) int {
|
||||
var params [3]lua.LValue
|
||||
for i := range 3 {
|
||||
params[i] = L.Get(i + 1)
|
||||
}
|
||||
if errTable, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok {
|
||||
for _, v := range params {
|
||||
switch v.Type() {
|
||||
case lua.LTNumber:
|
||||
if n, ok := v.(lua.LNumber); ok {
|
||||
L.SetField(errTable, "code", n)
|
||||
}
|
||||
case lua.LTString:
|
||||
if s, ok := v.(lua.LString); ok {
|
||||
L.SetField(errTable, "message", s)
|
||||
}
|
||||
case lua.LTTable:
|
||||
if tbl, ok := v.(*lua.LTable); ok {
|
||||
L.SetField(errTable, "data", tbl)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}))
|
||||
|
||||
L.SetField(sessionMod, "request", inTable)
|
||||
L.SetField(sessionMod, "response", outTable)
|
||||
|
||||
L.SetField(sessionMod, "id", lua.LString(sid))
|
||||
|
||||
L.SetField(sessionMod, "__seed", lua.LString(fmt.Sprint(seed)))
|
||||
L.Push(sessionMod)
|
||||
return 1
|
||||
}
|
||||
|
||||
loadLogMod := func(lL *lua.LState) int {
|
||||
loadLogMod := func(L *lua.LState) int {
|
||||
llog.Debug("import module log", slog.String("script", path))
|
||||
logMod := lL.NewTable()
|
||||
logMod := L.NewTable()
|
||||
|
||||
logFuncs := map[string]func(string, ...any){
|
||||
"info": llog.Info,
|
||||
@@ -88,8 +303,8 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
|
||||
for name, logFunc := range logFuncs {
|
||||
fun := logFunc
|
||||
lL.SetField(logMod, name, lL.NewFunction(func(lL *lua.LState) int {
|
||||
msg := lL.Get(1)
|
||||
L.SetField(logMod, name, L.NewFunction(func(L *lua.LState) int {
|
||||
msg := L.Get(1)
|
||||
converted := ConvertLuaTypesToGolang(msg)
|
||||
fun(fmt.Sprintf("the script says: %s", converted), slog.String("script", path))
|
||||
return 0
|
||||
@@ -105,8 +320,8 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
{"event_error", h.x.Log.Printf, colors.PrintError},
|
||||
{"event_warn", h.x.Log.Printf, colors.PrintWarn},
|
||||
} {
|
||||
lL.SetField(logMod, fn.field, lL.NewFunction(func(lL *lua.LState) int {
|
||||
msg := lL.Get(1)
|
||||
L.SetField(logMod, fn.field, L.NewFunction(func(L *lua.LState) int {
|
||||
msg := L.Get(1)
|
||||
converted := ConvertLuaTypesToGolang(msg)
|
||||
if fn.color != nil {
|
||||
h.x.Log.Printf("%s: %s: %s", fn.color(), path, converted)
|
||||
@@ -117,24 +332,24 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
}))
|
||||
}
|
||||
|
||||
lL.SetField(logMod, "__gosally_internal", lua.LString(fmt.Sprint(seed)))
|
||||
lL.Push(logMod)
|
||||
L.SetField(logMod, "__seed", lua.LString(fmt.Sprint(seed)))
|
||||
L.Push(logMod)
|
||||
return 1
|
||||
}
|
||||
|
||||
loadNetMod := func(lL *lua.LState) int {
|
||||
loadNetMod := func(L *lua.LState) int {
|
||||
llog.Debug("import module net", slog.String("script", path))
|
||||
netMod := lL.NewTable()
|
||||
netModhttp := lL.NewTable()
|
||||
netMod := L.NewTable()
|
||||
netModhttp := L.NewTable()
|
||||
|
||||
lL.SetField(netModhttp, "get_request", lL.NewFunction(func(lL *lua.LState) int {
|
||||
logRequest := lL.ToBool(1)
|
||||
url := lL.ToString(2)
|
||||
L.SetField(netModhttp, "get_request", L.NewFunction(func(L *lua.LState) int {
|
||||
logRequest := L.ToBool(1)
|
||||
url := L.ToString(2)
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
lL.Push(lua.LNil)
|
||||
lL.Push(lua.LString(err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
@@ -143,16 +358,16 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
lL.Push(lua.LNil)
|
||||
lL.Push(lua.LString(err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
lL.Push(lua.LNil)
|
||||
lL.Push(lua.LString(err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
@@ -166,34 +381,34 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
)
|
||||
}
|
||||
|
||||
result := lL.NewTable()
|
||||
lL.SetField(result, "status", lua.LNumber(resp.StatusCode))
|
||||
lL.SetField(result, "status_text", lua.LString(resp.Status))
|
||||
lL.SetField(result, "body", lua.LString(body))
|
||||
lL.SetField(result, "content_length", lua.LNumber(resp.ContentLength))
|
||||
result := L.NewTable()
|
||||
L.SetField(result, "status", lua.LNumber(resp.StatusCode))
|
||||
L.SetField(result, "status_text", lua.LString(resp.Status))
|
||||
L.SetField(result, "body", lua.LString(body))
|
||||
L.SetField(result, "content_length", lua.LNumber(resp.ContentLength))
|
||||
|
||||
headers := lL.NewTable()
|
||||
headers := L.NewTable()
|
||||
for k, v := range resp.Header {
|
||||
lL.SetField(headers, k, ConvertGolangTypesToLua(lL, v))
|
||||
L.SetField(headers, k, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
lL.SetField(result, "headers", headers)
|
||||
L.SetField(result, "headers", headers)
|
||||
|
||||
lL.Push(result)
|
||||
L.Push(result)
|
||||
return 1
|
||||
}))
|
||||
|
||||
lL.SetField(netModhttp, "post_request", lL.NewFunction(func(lL *lua.LState) int {
|
||||
logRequest := lL.ToBool(1)
|
||||
url := lL.ToString(2)
|
||||
contentType := lL.ToString(3)
|
||||
payload := lL.ToString(4)
|
||||
L.SetField(netModhttp, "post_request", L.NewFunction(func(L *lua.LState) int {
|
||||
logRequest := L.ToBool(1)
|
||||
url := L.ToString(2)
|
||||
contentType := L.ToString(3)
|
||||
payload := L.ToString(4)
|
||||
|
||||
body := strings.NewReader(payload)
|
||||
|
||||
req, err := http.NewRequest("POST", url, body)
|
||||
if err != nil {
|
||||
lL.Push(lua.LNil)
|
||||
lL.Push(lua.LString(err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
@@ -204,16 +419,16 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
lL.Push(lua.LNil)
|
||||
lL.Push(lua.LString(err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
lL.Push(lua.LNil)
|
||||
lL.Push(lua.LString(err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
@@ -228,47 +443,47 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
)
|
||||
}
|
||||
|
||||
result := lL.NewTable()
|
||||
lL.SetField(result, "status", lua.LNumber(resp.StatusCode))
|
||||
lL.SetField(result, "status_text", lua.LString(resp.Status))
|
||||
lL.SetField(result, "body", lua.LString(respBody))
|
||||
lL.SetField(result, "content_length", lua.LNumber(resp.ContentLength))
|
||||
result := L.NewTable()
|
||||
L.SetField(result, "status", lua.LNumber(resp.StatusCode))
|
||||
L.SetField(result, "status_text", lua.LString(resp.Status))
|
||||
L.SetField(result, "body", lua.LString(respBody))
|
||||
L.SetField(result, "content_length", lua.LNumber(resp.ContentLength))
|
||||
|
||||
headers := lL.NewTable()
|
||||
headers := L.NewTable()
|
||||
for k, v := range resp.Header {
|
||||
lL.SetField(headers, k, ConvertGolangTypesToLua(lL, v))
|
||||
L.SetField(headers, k, ConvertGolangTypesToLua(L, v))
|
||||
}
|
||||
lL.SetField(result, "headers", headers)
|
||||
L.SetField(result, "headers", headers)
|
||||
|
||||
lL.Push(result)
|
||||
L.Push(result)
|
||||
return 1
|
||||
}))
|
||||
|
||||
lL.SetField(netMod, "http", netModhttp)
|
||||
L.SetField(netMod, "http", netModhttp)
|
||||
|
||||
lL.SetField(netMod, "__gosally_internal", lua.LString(fmt.Sprint(seed)))
|
||||
lL.Push(netMod)
|
||||
L.SetField(netMod, "__seed", lua.LString(fmt.Sprint(seed)))
|
||||
L.Push(netMod)
|
||||
return 1
|
||||
}
|
||||
|
||||
loadCryptbcryptMod := func(lL *lua.LState) int {
|
||||
loadCryptbcryptMod := func(L *lua.LState) int {
|
||||
llog.Debug("import module crypt.bcrypt", slog.String("script", path))
|
||||
bcryptMod := lL.NewTable()
|
||||
bcryptMod := L.NewTable()
|
||||
|
||||
lL.SetField(bcryptMod, "MinCost", lua.LNumber(bcrypt.MinCost))
|
||||
lL.SetField(bcryptMod, "MaxCost", lua.LNumber(bcrypt.MaxCost))
|
||||
lL.SetField(bcryptMod, "DefaultCost", lua.LNumber(bcrypt.DefaultCost))
|
||||
L.SetField(bcryptMod, "MinCost", lua.LNumber(bcrypt.MinCost))
|
||||
L.SetField(bcryptMod, "MaxCost", lua.LNumber(bcrypt.MaxCost))
|
||||
L.SetField(bcryptMod, "DefaultCost", lua.LNumber(bcrypt.DefaultCost))
|
||||
|
||||
lL.SetField(bcryptMod, "generate", lL.NewFunction(func(l *lua.LState) int {
|
||||
password := ConvertLuaTypesToGolang(lL.Get(1))
|
||||
L.SetField(bcryptMod, "generate", L.NewFunction(func(l *lua.LState) int {
|
||||
password := ConvertLuaTypesToGolang(L.Get(1))
|
||||
passwordStr, ok := password.(string)
|
||||
if !ok {
|
||||
lL.Push(lua.LNil)
|
||||
lL.Push(lua.LString("error: password must be a string"))
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("error: password must be a string"))
|
||||
return 2
|
||||
}
|
||||
|
||||
cost := ConvertLuaTypesToGolang(lL.Get(2))
|
||||
cost := ConvertLuaTypesToGolang(L.Get(2))
|
||||
costInt := bcrypt.DefaultCost
|
||||
switch v := cost.(type) {
|
||||
case int:
|
||||
@@ -278,56 +493,78 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
case nil:
|
||||
// ok, use DefaultCost
|
||||
default:
|
||||
lL.Push(lua.LNil)
|
||||
lL.Push(lua.LString("error: cost must be an integer"))
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("error: cost must be an integer"))
|
||||
return 2
|
||||
}
|
||||
|
||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(passwordStr), costInt)
|
||||
if err != nil {
|
||||
lL.Push(lua.LNil)
|
||||
lL.Push(lua.LString("error: " + err.Error()))
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString("error: " + err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
lL.Push(lua.LString(string(hashBytes)))
|
||||
lL.Push(lua.LNil)
|
||||
L.Push(lua.LString(string(hashBytes)))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}))
|
||||
|
||||
lL.SetField(bcryptMod, "compare", lL.NewFunction(func(l *lua.LState) int {
|
||||
hash := ConvertLuaTypesToGolang(lL.Get(1))
|
||||
L.SetField(bcryptMod, "compare", L.NewFunction(func(l *lua.LState) int {
|
||||
hash := ConvertLuaTypesToGolang(L.Get(1))
|
||||
hashStr, ok := hash.(string)
|
||||
if !ok {
|
||||
lL.Push(lua.LString("error: hash must be a string"))
|
||||
L.Push(lua.LString("error: hash must be a string"))
|
||||
return 1
|
||||
}
|
||||
password := ConvertLuaTypesToGolang(lL.Get(2))
|
||||
password := ConvertLuaTypesToGolang(L.Get(2))
|
||||
passwordStr, ok := password.(string)
|
||||
if !ok {
|
||||
lL.Push(lua.LString("error: password must be a string"))
|
||||
L.Push(lua.LString("error: password must be a string"))
|
||||
return 1
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashStr), []byte(passwordStr))
|
||||
if err != nil {
|
||||
lL.Push(lua.LFalse)
|
||||
L.Push(lua.LFalse)
|
||||
return 1
|
||||
}
|
||||
lL.Push(lua.LTrue)
|
||||
L.Push(lua.LTrue)
|
||||
return 1
|
||||
}))
|
||||
|
||||
lL.SetField(bcryptMod, "__gosally_internal", lua.LString(fmt.Sprint(seed)))
|
||||
lL.Push(bcryptMod)
|
||||
L.SetField(bcryptMod, "__seed", lua.LString(fmt.Sprint(seed)))
|
||||
L.Push(bcryptMod)
|
||||
return 1
|
||||
}
|
||||
|
||||
loadCryptbsha256Mod := func(L *lua.LState) int {
|
||||
llog.Debug("import module crypt.sha256", slog.String("script", path))
|
||||
sha265mod := L.NewTable()
|
||||
|
||||
L.SetField(sha265mod, "sum", L.NewFunction(func(l *lua.LState) int {
|
||||
data := ConvertLuaTypesToGolang(L.Get(1))
|
||||
var dataStr = fmt.Sprint(data)
|
||||
|
||||
hash := sha256.Sum256([]byte(dataStr))
|
||||
|
||||
L.Push(lua.LString(hex.EncodeToString(hash[:])))
|
||||
L.Push(lua.LNil)
|
||||
return 2
|
||||
}))
|
||||
|
||||
L.SetField(sha265mod, "__seed", lua.LString(fmt.Sprint(seed)))
|
||||
L.Push(sha265mod)
|
||||
return 1
|
||||
}
|
||||
|
||||
L.PreloadModule("internal.session", loadSessionMod)
|
||||
L.PreloadModule("internal.log", loadLogMod)
|
||||
L.PreloadModule("internal.net", loadNetMod)
|
||||
L.PreloadModule("internal.database-sqlite", loadDBMod(llog))
|
||||
L.PreloadModule("internal.database.sqlite", loadDBMod(llog, fmt.Sprint(seed)))
|
||||
L.PreloadModule("internal.crypt.bcrypt", loadCryptbcryptMod)
|
||||
L.PreloadModule("internal.crypt.sha256", loadCryptbsha256Mod)
|
||||
L.PreloadModule("internal.crypt.jwt", loadJWTMod(llog, fmt.Sprint(seed)))
|
||||
|
||||
llog.Debug("preparing environment")
|
||||
prep := filepath.Join(*h.x.Config.Conf.Node.ComDir, "_prepare.lua")
|
||||
@@ -338,7 +575,8 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
}
|
||||
}
|
||||
llog.Debug("executing script", slog.String("script", path))
|
||||
if err := L.DoFile(path); err != nil {
|
||||
err := L.DoFile(path)
|
||||
if err != nil && __exit != 0 && __exit != 1 {
|
||||
llog.Error("script error", slog.String("script", path), slog.String("error", err.Error()))
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
}
|
||||
@@ -360,17 +598,13 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
sessionVal := loadedTbl.RawGetString("internal.session")
|
||||
sessionTbl, ok := sessionVal.(*lua.LTable)
|
||||
if !ok {
|
||||
return rpc.NewResponse(map[string]any{
|
||||
"responsible-node": h.cs.UUID32,
|
||||
}, req.ID)
|
||||
return rpc.NewResponse(nil, req.ID)
|
||||
}
|
||||
|
||||
tag := sessionTbl.RawGetString("__gosally_internal")
|
||||
tag := sessionTbl.RawGetString("__seed")
|
||||
if tag.Type() != lua.LTString || tag.String() != fmt.Sprint(seed) {
|
||||
llog.Debug("stock session module is not imported: wrong seed", slog.String("script", path))
|
||||
return rpc.NewResponse(map[string]any{
|
||||
"responsible-node": h.cs.UUID32,
|
||||
}, req.ID)
|
||||
return rpc.NewResponse(nil, req.ID)
|
||||
}
|
||||
|
||||
outVal := sessionTbl.RawGetString("response")
|
||||
@@ -380,39 +614,28 @@ func (h *HandlerV1) handleLUA(sid string, r *http.Request, req *rpc.RPCRequest,
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
}
|
||||
|
||||
if errVal := outTbl.RawGetString("error"); errVal != lua.LNil {
|
||||
if scriptDataTable, ok := outTbl.RawGetString("__script_data").(*lua.LTable); ok {
|
||||
switch __exit {
|
||||
case 1:
|
||||
if errTbl, ok := scriptDataTable.RawGetString("error").(*lua.LTable); ok {
|
||||
llog.Debug("catch error table", slog.String("script", path))
|
||||
if errTbl, ok := errVal.(*lua.LTable); ok {
|
||||
code := rpc.ErrInternalError
|
||||
message := rpc.ErrInternalErrorS
|
||||
data := make(map[string]any)
|
||||
if c := errTbl.RawGetString("code"); c.Type() == lua.LTNumber {
|
||||
code = int(c.(lua.LNumber))
|
||||
}
|
||||
if msg := errTbl.RawGetString("message"); msg.Type() == lua.LTString {
|
||||
message = msg.String()
|
||||
}
|
||||
rawData := errTbl.RawGetString("data")
|
||||
|
||||
if tbl, ok := rawData.(*lua.LTable); ok {
|
||||
tbl.ForEach(func(k, v lua.LValue) { data[k.String()] = ConvertLuaTypesToGolang(v) })
|
||||
} else {
|
||||
llog.Error("the script terminated with an error", slog.String("code", strconv.Itoa(code)), slog.String("message", message))
|
||||
return rpc.NewError(code, message, ConvertLuaTypesToGolang(rawData), req.ID)
|
||||
}
|
||||
llog.Error("the script terminated with an error", slog.String("code", strconv.Itoa(code)), slog.String("message", message))
|
||||
data := ConvertLuaTypesToGolang(errTbl.RawGetString("data"))
|
||||
llog.Error("the script terminated with an error", slog.Int("code", code), slog.String("message", message), slog.Any("data", data))
|
||||
return rpc.NewError(code, message, data, req.ID)
|
||||
}
|
||||
return rpc.NewError(rpc.ErrInternalError, rpc.ErrInternalErrorS, nil, req.ID)
|
||||
case 0:
|
||||
resVal := ConvertLuaTypesToGolang(scriptDataTable.RawGetString("result"))
|
||||
return rpc.NewResponse(resVal, req.ID)
|
||||
}
|
||||
|
||||
resultVal := outTbl.RawGetString("result")
|
||||
payload := make(map[string]any)
|
||||
if tbl, ok := resultVal.(*lua.LTable); ok {
|
||||
tbl.ForEach(func(k, v lua.LValue) { payload[k.String()] = ConvertLuaTypesToGolang(v) })
|
||||
} else {
|
||||
payload["message"] = ConvertLuaTypesToGolang(resultVal)
|
||||
}
|
||||
payload["responsible-node"] = h.cs.UUID32
|
||||
return rpc.NewResponse(payload, req.ID)
|
||||
return rpc.NewResponse(nil, req.ID)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package sv1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
@@ -17,19 +19,56 @@ func ConvertLuaTypesToGolang(value lua.LValue) any {
|
||||
case lua.LTTable:
|
||||
tbl := value.(*lua.LTable)
|
||||
|
||||
var arr []any
|
||||
maxIdx := 0
|
||||
isArray := true
|
||||
tbl.ForEach(func(key, val lua.LValue) {
|
||||
if key.Type() != lua.LTNumber {
|
||||
|
||||
var isNumeric = false
|
||||
tbl.ForEach(func(key, _ lua.LValue) {
|
||||
var numKey lua.LValue
|
||||
var ok bool
|
||||
switch key.Type() {
|
||||
case lua.LTString:
|
||||
numKey, ok = key.(lua.LString)
|
||||
if !ok {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
case lua.LTNumber:
|
||||
numKey, ok = key.(lua.LNumber)
|
||||
if !ok {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
isNumeric = true
|
||||
}
|
||||
|
||||
num, err := strconv.Atoi(numKey.String())
|
||||
if err != nil {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
if num < 1 {
|
||||
isArray = false
|
||||
return
|
||||
}
|
||||
if num > maxIdx {
|
||||
maxIdx = num
|
||||
}
|
||||
arr = append(arr, ConvertLuaTypesToGolang(val))
|
||||
})
|
||||
|
||||
if isArray {
|
||||
arr := make([]any, maxIdx)
|
||||
if isNumeric {
|
||||
for i := 1; i <= maxIdx; i++ {
|
||||
arr[i-1] = ConvertLuaTypesToGolang(tbl.RawGetInt(i))
|
||||
}
|
||||
} else {
|
||||
for i := 1; i <= maxIdx; i++ {
|
||||
arr[i-1] = ConvertLuaTypesToGolang(tbl.RawGetString(strconv.Itoa(i)))
|
||||
}
|
||||
}
|
||||
return arr
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
tbl.ForEach(func(key, val lua.LValue) {
|
||||
result[key.String()] = ConvertLuaTypesToGolang(val)
|
||||
@@ -44,34 +83,44 @@ func ConvertLuaTypesToGolang(value lua.LValue) any {
|
||||
}
|
||||
|
||||
func ConvertGolangTypesToLua(L *lua.LState, val any) lua.LValue {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
return lua.LString(v)
|
||||
case bool:
|
||||
return lua.LBool(v)
|
||||
case int:
|
||||
return lua.LNumber(float64(v))
|
||||
case int64:
|
||||
return lua.LNumber(float64(v))
|
||||
case float32:
|
||||
return lua.LNumber(float64(v))
|
||||
case float64:
|
||||
return lua.LNumber(v)
|
||||
case []any:
|
||||
tbl := L.NewTable()
|
||||
for i, item := range v {
|
||||
tbl.RawSetInt(i+1, ConvertGolangTypesToLua(L, item))
|
||||
}
|
||||
return tbl
|
||||
case map[string]any:
|
||||
tbl := L.NewTable()
|
||||
for key, value := range v {
|
||||
tbl.RawSetString(key, ConvertGolangTypesToLua(L, value))
|
||||
}
|
||||
return tbl
|
||||
case nil:
|
||||
if val == nil {
|
||||
return lua.LNil
|
||||
default:
|
||||
return lua.LString(fmt.Sprintf("%v", v))
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(val)
|
||||
rt := rv.Type()
|
||||
|
||||
switch rt.Kind() {
|
||||
case reflect.String:
|
||||
return lua.LString(rv.String())
|
||||
case reflect.Bool:
|
||||
return lua.LBool(rv.Bool())
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return lua.LNumber(rv.Int())
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return lua.LNumber(rv.Uint())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return lua.LNumber(rv.Float())
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
tbl := L.NewTable()
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
tbl.RawSetInt(i+1, ConvertGolangTypesToLua(L, rv.Index(i).Interface()))
|
||||
}
|
||||
return tbl
|
||||
|
||||
case reflect.Map:
|
||||
if rt.Key().Kind() == reflect.String {
|
||||
tbl := L.NewTable()
|
||||
for _, key := range rv.MapKeys() {
|
||||
val := rv.MapIndex(key)
|
||||
tbl.RawSetString(key.String(), ConvertGolangTypesToLua(L, val.Interface()))
|
||||
}
|
||||
return tbl
|
||||
}
|
||||
|
||||
default:
|
||||
return lua.LString(fmt.Sprintf("%v", val))
|
||||
}
|
||||
return lua.LString(fmt.Sprintf("%v", val))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user