Compare commits
31 Commits
a9da570877
...
beba3cfb4b
| Author | SHA1 | Date | |
|---|---|---|---|
| beba3cfb4b | |||
| 0f966fa17e | |||
| 7546d1bece | |||
| 45f4c76ff5 | |||
| 73343fd57b | |||
| 6c9f8bcec0 | |||
| f65150cec3 | |||
| 99fd0f5776 | |||
| 524749b329 | |||
| c80f7932b4 | |||
| e2b92f8ba1 | |||
| a1f6c1ffa9 | |||
| 7e581d99f5 | |||
| ad980ee600 | |||
| 438bed8f13 | |||
| e9b7f8ca17 | |||
| ae1e5600ae | |||
| 44d39db701 | |||
| adf61a4d1d | |||
| 97253ee9c7 | |||
| 4ae85c73bb | |||
| 16b6b292c6 | |||
| 6f4657caff | |||
| 53761db1e0 | |||
| 603f007c63 | |||
| 597000f222 | |||
| 3b74f5c43d | |||
| 8de6a9212a | |||
| 64dad6619e | |||
| cdde811e72 | |||
| 8836ea2673 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,3 +2,5 @@ bin/
|
|||||||
config.yaml
|
config.yaml
|
||||||
*.sqlite3
|
*.sqlite3
|
||||||
panic.log
|
panic.log
|
||||||
|
testdata/
|
||||||
|
secret/
|
||||||
@@ -3,35 +3,217 @@
|
|||||||
package api_auth
|
package api_auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/auth"
|
||||||
"git.oblat.lv/alex/triggerssmith/internal/config"
|
"git.oblat.lv/alex/triggerssmith/internal/config"
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/server"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func setRefreshCookie(w http.ResponseWriter, token string, ttl time.Duration, secure bool) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "refresh_token",
|
||||||
|
Value: token,
|
||||||
|
Path: "/api/auth/refresh",
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
MaxAge: int(ttl.Seconds()),
|
||||||
|
Secure: secure,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
type authHandler struct {
|
type authHandler struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
a *auth.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustRoute(config *config.Config) func(chi.Router) {
|
func MustRoute(config *config.Config, authService *auth.Service) func(chi.Router) {
|
||||||
if config == nil {
|
if config == nil {
|
||||||
panic("config is nil")
|
panic("config is nil")
|
||||||
}
|
}
|
||||||
|
if authService == nil {
|
||||||
|
panic("authService is nil")
|
||||||
|
}
|
||||||
h := &authHandler{
|
h := &authHandler{
|
||||||
cfg: config,
|
cfg: config,
|
||||||
|
a: authService,
|
||||||
}
|
}
|
||||||
return func(r chi.Router) {
|
return func(r chi.Router) {
|
||||||
r.Get("/login", h.handleLogin)
|
r.Get("/getUserData", h.handleGetUserData) // legacy support
|
||||||
r.Get("/logout", h.handleLogout)
|
|
||||||
r.Get("/me", h.handleMe)
|
r.Post("/register", h.handleRegister)
|
||||||
r.Get("/revoke", h.handleRevoke)
|
r.Post("/login", h.handleLogin)
|
||||||
|
r.Post("/logout", h.handleLogout) // !requires authentication
|
||||||
|
r.Post("/refresh", h.handleRefresh) // !requires authentication
|
||||||
|
|
||||||
|
r.Get("/me", h.handleMe) // !requires authentication
|
||||||
|
r.Get("/get-user-data", h.handleGetUserData)
|
||||||
|
|
||||||
|
r.Post("/revoke", h.handleRevoke) // not implemented
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *authHandler) handleLogin(w http.ResponseWriter, r *http.Request) {}
|
type registerRequest struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
func (h *authHandler) handleLogout(w http.ResponseWriter, r *http.Request) {}
|
type registerResponse struct {
|
||||||
|
UserID int64 `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
}
|
||||||
|
|
||||||
func (h *authHandler) handleMe(w http.ResponseWriter, r *http.Request) {}
|
func (h *authHandler) handleRegister(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req registerRequest
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid request payload", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (h *authHandler) handleRevoke(w http.ResponseWriter, r *http.Request) {}
|
user, err := h.a.Register(req.Username, req.Email, req.Password)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Registration failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w).Encode(registerResponse{
|
||||||
|
UserID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type loginRequest struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type loginResponse struct {
|
||||||
|
Token string `json:"accessToken"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *authHandler) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req loginRequest
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid request payload", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := h.a.Login(req.Username, req.Password)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Authentication failed", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
setRefreshCookie(w, tokens.Refresh, h.cfg.Auth.RefreshTokenTTL, false)
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w).Encode(loginResponse{Token: tokens.Access})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *authHandler) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims, err := h.a.AuthenticateRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rjti := claims.(jwt.MapClaims)["rjti"].(string)
|
||||||
|
err = h.a.Logout(rjti)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to logout, taking cookie anyways", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "refresh_token",
|
||||||
|
Value: "",
|
||||||
|
MaxAge: -1,
|
||||||
|
Path: "/api/users/refresh",
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type meResponse struct {
|
||||||
|
UserID int64 `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *authHandler) handleMe(w http.ResponseWriter, r *http.Request) {
|
||||||
|
refresh_token_cookie, err := r.Cookie("refresh_token")
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID, err := h.a.ValidateRefreshToken(refresh_token_cookie.Value)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user, err := h.a.Get("id", fmt.Sprint(userID))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to get user", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w).Encode(meResponse{
|
||||||
|
UserID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
Email: user.Email,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetUserDataResponse meResponse
|
||||||
|
|
||||||
|
func (h *authHandler) handleGetUserData(w http.ResponseWriter, r *http.Request) {
|
||||||
|
by := r.URL.Query().Get("by")
|
||||||
|
value := r.URL.Query().Get("value")
|
||||||
|
if value == "" {
|
||||||
|
value = r.URL.Query().Get(by)
|
||||||
|
}
|
||||||
|
user, err := h.a.Get(by, value)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to get user", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w).Encode(meResponse{
|
||||||
|
UserID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
Email: user.Email,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *authHandler) handleRevoke(w http.ResponseWriter, r *http.Request) {
|
||||||
|
server.NotImplemented(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *authHandler) handleRefresh(w http.ResponseWriter, r *http.Request) {
|
||||||
|
server.NotImplemented(w)
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,8 +8,9 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.oblat.lv/alex/triggerssmith/api/auth"
|
api_auth "git.oblat.lv/alex/triggerssmith/api/auth"
|
||||||
"git.oblat.lv/alex/triggerssmith/api/block"
|
api_block "git.oblat.lv/alex/triggerssmith/api/block"
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/auth"
|
||||||
"git.oblat.lv/alex/triggerssmith/internal/config"
|
"git.oblat.lv/alex/triggerssmith/internal/config"
|
||||||
"git.oblat.lv/alex/triggerssmith/internal/vars"
|
"git.oblat.lv/alex/triggerssmith/internal/vars"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
@@ -20,13 +21,27 @@ type Router struct {
|
|||||||
r chi.Router
|
r chi.Router
|
||||||
|
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
|
||||||
|
authService *auth.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRouter(cfg *config.Config) *Router {
|
type RouterDependencies struct {
|
||||||
|
AuthService *auth.Service
|
||||||
|
Configuration *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRouter(deps RouterDependencies) *Router {
|
||||||
|
if deps.AuthService == nil {
|
||||||
|
panic("AuthService is required")
|
||||||
|
}
|
||||||
|
if deps.Configuration == nil {
|
||||||
|
panic("Configuration is required")
|
||||||
|
}
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
return &Router{
|
return &Router{
|
||||||
r: r,
|
r: r,
|
||||||
cfg: cfg,
|
cfg: deps.Configuration,
|
||||||
|
authService: deps.AuthService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,7 +74,7 @@ func (r *Router) MustRoute() chi.Router {
|
|||||||
|
|
||||||
r.r.Route("/api", func(api chi.Router) {
|
r.r.Route("/api", func(api chi.Router) {
|
||||||
api.Route("/block", api_block.MustRoute(r.cfg))
|
api.Route("/block", api_block.MustRoute(r.cfg))
|
||||||
api.Route("/auth", api_auth.MustRoute(r.cfg))
|
api.Route("/users", api_auth.MustRoute(r.cfg, r.authService))
|
||||||
})
|
})
|
||||||
|
|
||||||
r.r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
|
r.r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
109
cmd/serve.go
109
cmd/serve.go
@@ -13,16 +13,23 @@ import (
|
|||||||
|
|
||||||
"git.oblat.lv/alex/triggerssmith/api"
|
"git.oblat.lv/alex/triggerssmith/api"
|
||||||
application "git.oblat.lv/alex/triggerssmith/internal/app"
|
application "git.oblat.lv/alex/triggerssmith/internal/app"
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/auth"
|
||||||
"git.oblat.lv/alex/triggerssmith/internal/config"
|
"git.oblat.lv/alex/triggerssmith/internal/config"
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/jwt"
|
||||||
"git.oblat.lv/alex/triggerssmith/internal/server"
|
"git.oblat.lv/alex/triggerssmith/internal/server"
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/token"
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/user"
|
||||||
"git.oblat.lv/alex/triggerssmith/internal/vars"
|
"git.oblat.lv/alex/triggerssmith/internal/vars"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var optsServeCmd = struct {
|
var optsServeCmd = struct {
|
||||||
ConfigPath *string
|
ConfigPath *string
|
||||||
Debug *bool
|
Debug *bool
|
||||||
HideGreetings *bool
|
HideGreetings *bool
|
||||||
|
NoPIDFile *bool
|
||||||
}{}
|
}{}
|
||||||
|
|
||||||
// // simple middleware for request logging
|
// // simple middleware for request logging
|
||||||
@@ -99,10 +106,8 @@ var serveCmd = &cobra.Command{
|
|||||||
fmt.Fprintf(f, "Panic: %v\n", r)
|
fmt.Fprintf(f, "Panic: %v\n", r)
|
||||||
f.Write(stack)
|
f.Write(stack)
|
||||||
f.WriteString("\n\n")
|
f.WriteString("\n\n")
|
||||||
|
slog.Error("Application panicked: the stack is flushed to disk", slog.Any("error", r))
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Error("Application panicked: the stack is flushed to disk", slog.Any("error", r))
|
|
||||||
|
|
||||||
os.Exit(-1)
|
os.Exit(-1)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -114,13 +119,17 @@ var serveCmd = &cobra.Command{
|
|||||||
slog.SetDefault(slog.New(slog.NewTextHandler(cmd.OutOrStdout(), &slog.HandlerOptions{Level: slog.LevelInfo})))
|
slog.SetDefault(slog.New(slog.NewTextHandler(cmd.OutOrStdout(), &slog.HandlerOptions{Level: slog.LevelInfo})))
|
||||||
}
|
}
|
||||||
|
|
||||||
pid := os.Getpid()
|
if !*optsServeCmd.NoPIDFile {
|
||||||
slog.Debug("Starting server", slog.Int("pid", pid))
|
pid := os.Getpid()
|
||||||
if err := writePID(vars.PID_PATH); err != nil {
|
slog.Debug("Starting server", slog.Int("pid", pid))
|
||||||
panic(err)
|
if err := writePID(vars.PID_PATH); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
slog.Debug("created pid file", slog.String("path", vars.PID_PATH))
|
||||||
|
defer os.Remove(vars.PID_PATH)
|
||||||
|
} else {
|
||||||
|
slog.Warn("Starting server without PID file as requested by --no-pidfile flag: this may complicate process management")
|
||||||
}
|
}
|
||||||
slog.Debug("created pid file", slog.String("path", vars.PID_PATH))
|
|
||||||
defer os.Remove(vars.PID_PATH)
|
|
||||||
|
|
||||||
// load config
|
// load config
|
||||||
slog.Debug("Reading configuration", slog.String("path", *optsServeCmd.ConfigPath))
|
slog.Debug("Reading configuration", slog.String("path", *optsServeCmd.ConfigPath))
|
||||||
@@ -140,16 +149,81 @@ var serveCmd = &cobra.Command{
|
|||||||
app.LoadConfiguration(cfg)
|
app.LoadConfiguration(cfg)
|
||||||
|
|
||||||
srv := app.Server()
|
srv := app.Server()
|
||||||
//mux := http.NewServeMux()
|
|
||||||
|
|
||||||
// static files
|
// Services initialization
|
||||||
// staticPath := cfg.Server.StaticFilesPath
|
var jwtSigner jwt.Signer
|
||||||
// slog.Debug("Setting up static file server", slog.String("path", staticPath))
|
// TODO: support more signing algorithms
|
||||||
// fs := http.FileServer(http.Dir(staticPath))
|
// : support hot config reload for signing alg and secret
|
||||||
// mux.Handle("/static/", http.StripPrefix("/static/", fs))
|
switch cfg.Auth.SignAlg {
|
||||||
// handler := loggingMiddleware(mux)
|
case "HS256":
|
||||||
|
secretBytes, err := os.ReadFile(cfg.Auth.HMACSecretPath)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to read HMAC secret file", slog.String("path", cfg.Auth.HMACSecretPath), slog.String("error", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jwtSigner = jwt.NewHMACSigner(secretBytes)
|
||||||
|
default:
|
||||||
|
slog.Error("Unsupported JWT signing algorithm", slog.String("alg", cfg.Auth.SignAlg))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jwtService := jwt.NewService(jwtSigner)
|
||||||
|
|
||||||
router := api.NewRouter(cfg)
|
err = os.MkdirAll(cfg.Data.DataPath, 0755)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to create data directory", slog.String("path", cfg.Data.DataPath), slog.String("error", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenDb, err := gorm.Open(sqlite.Open(filepath.Join(cfg.Data.DataPath, "tokens.sqlite3")), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to open token database", slog.String("error", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = tokenDb.AutoMigrate(&token.Token{})
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to migrate token database", slog.String("error", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokenStore, err := token.NewSQLiteTokenStore(tokenDb)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to create token store", slog.String("error", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokenService, err := token.NewTokenService(&cfg.Auth, tokenStore)
|
||||||
|
|
||||||
|
userDb, err := gorm.Open(sqlite.Open(filepath.Join(cfg.Data.DataPath, "users.sqlite3")), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to open user database", slog.String("error", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = userDb.AutoMigrate(&user.User{})
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to migrate user database", slog.String("error", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userStore, err := user.NewGormUserStore(userDb)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to create user store", slog.String("error", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userService, err := user.NewService(userStore)
|
||||||
|
|
||||||
|
authService, err := auth.NewAuthService(auth.AuthServiceDependencies{
|
||||||
|
Configuration: cfg,
|
||||||
|
|
||||||
|
JWTService: jwtService,
|
||||||
|
UserService: userService,
|
||||||
|
TokenService: tokenService,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to create auth service", slog.String("error", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
router := api.NewRouter(api.RouterDependencies{
|
||||||
|
AuthService: authService,
|
||||||
|
Configuration: cfg,
|
||||||
|
})
|
||||||
|
|
||||||
srv.SetHandler(router.MustRoute())
|
srv.SetHandler(router.MustRoute())
|
||||||
srv.Init()
|
srv.Init()
|
||||||
@@ -211,5 +285,6 @@ func init() {
|
|||||||
optsServeCmd.Debug = serveCmd.Flags().BoolP("debug", "d", false, "Enable debug logs")
|
optsServeCmd.Debug = serveCmd.Flags().BoolP("debug", "d", false, "Enable debug logs")
|
||||||
optsServeCmd.ConfigPath = serveCmd.Flags().StringP("config", "c", "config.yaml", "Path to configuration file")
|
optsServeCmd.ConfigPath = serveCmd.Flags().StringP("config", "c", "config.yaml", "Path to configuration file")
|
||||||
optsServeCmd.HideGreetings = serveCmd.Flags().BoolP("hide-greetings", "g", false, "Hide the welcome message and version when starting the server")
|
optsServeCmd.HideGreetings = serveCmd.Flags().BoolP("hide-greetings", "g", false, "Hide the welcome message and version when starting the server")
|
||||||
|
optsServeCmd.NoPIDFile = serveCmd.Flags().BoolP("no-pidfile", "p", false, "Do not write a PID file")
|
||||||
rootCmd.AddCommand(serveCmd)
|
rootCmd.AddCommand(serveCmd)
|
||||||
}
|
}
|
||||||
|
|||||||
8
go.mod
8
go.mod
@@ -7,10 +7,14 @@ require (
|
|||||||
github.com/spf13/cobra v1.10.1
|
github.com/spf13/cobra v1.10.1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
require golang.org/x/crypto v0.46.0 // indirect
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||||
github.com/go-chi/chi/v5 v5.2.3 // indirect
|
github.com/go-chi/chi/v5 v5.2.3 // indirect
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
@@ -24,8 +28,8 @@ require (
|
|||||||
github.com/spf13/viper v1.21.0 // indirect
|
github.com/spf13/viper v1.21.0 // indirect
|
||||||
github.com/subosito/gotenv v1.6.0 // indirect
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||||
golang.org/x/sys v0.29.0 // indirect
|
golang.org/x/sys v0.39.0 // indirect
|
||||||
golang.org/x/text v0.28.0 // indirect
|
golang.org/x/text v0.32.0 // indirect
|
||||||
gorm.io/driver/sqlite v1.6.0 // indirect
|
gorm.io/driver/sqlite v1.6.0 // indirect
|
||||||
gorm.io/gorm v1.31.1 // indirect
|
gorm.io/gorm v1.31.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
10
go.sum
10
go.sum
@@ -11,8 +11,12 @@ github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
|
|||||||
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
@@ -53,10 +57,16 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8
|
|||||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
|
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||||
|
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||||
|
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||||
|
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||||
|
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
|||||||
246
internal/auth/service.go
Normal file
246
internal/auth/service.go
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/config"
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/jwt"
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/token"
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/user"
|
||||||
|
ejwt "github.com/golang-jwt/jwt/v5"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tokens struct {
|
||||||
|
Access string
|
||||||
|
Refresh string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
cfg *config.Config
|
||||||
|
|
||||||
|
services struct {
|
||||||
|
jwt *jwt.Service
|
||||||
|
user *user.Service
|
||||||
|
token *token.Service
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthServiceDependencies struct {
|
||||||
|
Configuration *config.Config
|
||||||
|
|
||||||
|
JWTService *jwt.Service
|
||||||
|
UserService *user.Service
|
||||||
|
TokenService *token.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthService(deps AuthServiceDependencies) (*Service, error) {
|
||||||
|
if deps.Configuration == nil {
|
||||||
|
return nil, fmt.Errorf("config is nil")
|
||||||
|
}
|
||||||
|
if deps.JWTService == nil {
|
||||||
|
return nil, fmt.Errorf("jwt service is nil")
|
||||||
|
}
|
||||||
|
if deps.UserService == nil {
|
||||||
|
return nil, fmt.Errorf("user service is nil")
|
||||||
|
}
|
||||||
|
if deps.TokenService == nil {
|
||||||
|
return nil, fmt.Errorf("token service is nil")
|
||||||
|
}
|
||||||
|
return &Service{
|
||||||
|
cfg: deps.Configuration,
|
||||||
|
services: struct {
|
||||||
|
jwt *jwt.Service
|
||||||
|
user *user.Service
|
||||||
|
token *token.Service
|
||||||
|
}{
|
||||||
|
jwt: deps.JWTService,
|
||||||
|
user: deps.UserService,
|
||||||
|
token: deps.TokenService,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Users
|
||||||
|
|
||||||
|
func (s *Service) Get(by, value string) (*user.User, error) {
|
||||||
|
return s.services.user.GetBy(by, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register creates a new user with the given username, email, and password.
|
||||||
|
// Password is hashed before storing.
|
||||||
|
// Returns the created user or an error.
|
||||||
|
func (s *Service) Register(username, email, password string) (*user.User, error) {
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &user.User{
|
||||||
|
Username: username,
|
||||||
|
Email: email,
|
||||||
|
Password: string(hashedPassword),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.services.user.Create(user)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login authenticates a user with the given username and password.
|
||||||
|
// Returns access and refresh tokens if successful.
|
||||||
|
func (s *Service) Login(username, password string) (*Tokens, error) {
|
||||||
|
user, err := s.services.user.GetBy("username", username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get user by username: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid password: %w", err)
|
||||||
|
}
|
||||||
|
refreshToken, rjti, err := s.services.jwt.Generate(s.cfg.Auth.RefreshTokenTTL, ejwt.MapClaims{
|
||||||
|
"sub": user.ID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||||
|
}
|
||||||
|
accessToken, _, err := s.services.jwt.Generate(s.cfg.Auth.AccessTokenTTL, ejwt.MapClaims{
|
||||||
|
"sub": user.ID,
|
||||||
|
"rjti": rjti,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||||
|
}
|
||||||
|
return &Tokens{Access: accessToken, Refresh: refreshToken}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logout revokes the refresh token identified by the given rjti.
|
||||||
|
func (s *Service) Logout(rjti string) error {
|
||||||
|
return s.services.token.RevokeByRefreshDefault(rjti)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Access tokens
|
||||||
|
|
||||||
|
// ValidateAccessToken validates the given access token string.
|
||||||
|
// Returns the user ID (sub claim) if valid, or an error.
|
||||||
|
func (s *Service) ValidateAccessToken(tokenStr string) (int64, error) {
|
||||||
|
claims, _, err := s.services.jwt.Validate(tokenStr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to validate access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
isRevoked, err := s.services.token.IsRevoked(claims["rjti"].(string))
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to check if token is revoked: %w", err)
|
||||||
|
}
|
||||||
|
if isRevoked {
|
||||||
|
return 0, fmt.Errorf("token is revoked")
|
||||||
|
}
|
||||||
|
|
||||||
|
sub := claims["sub"].(float64)
|
||||||
|
return int64(sub), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh tokens
|
||||||
|
|
||||||
|
// RefreshTokens validates the given refresh token and issues new access and refresh tokens.
|
||||||
|
// Returns the new access and refresh tokens or an error.
|
||||||
|
func (s *Service) RefreshTokens(refreshTokenStr string) (*Tokens, error) {
|
||||||
|
claims, rjti, err := s.services.jwt.Validate(refreshTokenStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to validate refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
isRevoked, err := s.services.token.IsRevoked(rjti)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check if token is revoked: %w", err)
|
||||||
|
}
|
||||||
|
if isRevoked {
|
||||||
|
return nil, fmt.Errorf("refresh token is revoked")
|
||||||
|
}
|
||||||
|
|
||||||
|
sub := claims["sub"].(float64)
|
||||||
|
|
||||||
|
newRefreshToken, newRjti, err := s.services.jwt.Generate(s.cfg.Auth.RefreshTokenTTL, ejwt.MapClaims{
|
||||||
|
"sub": sub,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate new refresh token: %w", err)
|
||||||
|
}
|
||||||
|
newAccessToken, _, err := s.services.jwt.Generate(s.cfg.Auth.AccessTokenTTL, ejwt.MapClaims{
|
||||||
|
"sub": sub,
|
||||||
|
"rjti": newRjti,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate new access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke the old refresh token
|
||||||
|
if err := s.services.token.RevokeByRefreshDefault(rjti); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to revoke old refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Tokens{Access: newAccessToken, Refresh: newRefreshToken}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRefreshToken validates the given refresh token string.
|
||||||
|
// Returns user id and error.
|
||||||
|
func (s *Service) ValidateRefreshToken(tokenStr string) (int64, error) {
|
||||||
|
claims, _, err := s.services.jwt.Validate(tokenStr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to validate refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
isRevoked, err := s.services.token.IsRevoked(claims["jti"].(string))
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to check if token is revoked: %w", err)
|
||||||
|
}
|
||||||
|
if isRevoked {
|
||||||
|
return 0, fmt.Errorf("refresh token is revoked")
|
||||||
|
}
|
||||||
|
|
||||||
|
sub := claims["sub"].(float64)
|
||||||
|
return int64(sub), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeRefresh revokes the refresh token identified by the given token string.
|
||||||
|
func (s *Service) RevokeRefresh(token string) error {
|
||||||
|
_, rjti, err := s.services.jwt.Validate(token)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to validate refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.services.token.RevokeByRefreshDefault(rjti)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRefreshRevoked checks if the refresh token identified by the given token string is revoked.
|
||||||
|
func (s *Service) IsRefreshRevoked(token string) (bool, error) {
|
||||||
|
_, rjti, err := s.services.jwt.Validate(token)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to validate refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.services.token.IsRevoked(rjti)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) AuthenticateRequest(r *http.Request) (ejwt.Claims, error) {
|
||||||
|
header := r.Header.Get("Authorization")
|
||||||
|
if header == "" {
|
||||||
|
return nil, fmt.Errorf("token is missing")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(header, "Bearer ") {
|
||||||
|
return nil, fmt.Errorf("token is missing")
|
||||||
|
}
|
||||||
|
tokenString := strings.TrimPrefix(header, "Bearer ")
|
||||||
|
tokenClaims, _, err := s.services.jwt.Validate(tokenString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tokenClaims, nil
|
||||||
|
}
|
||||||
@@ -31,9 +31,22 @@ type FuncConfig struct {
|
|||||||
FunctionDir string `mapstructure:"func_dir"`
|
FunctionDir string `mapstructure:"func_dir"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Auth struct {
|
||||||
|
SignAlg string `mapstructure:"sign_alg"`
|
||||||
|
HMACSecretPath string `mapstructure:"hmac_secret_path"`
|
||||||
|
RefreshTokenTTL time.Duration `mapstructure:"refresh_token_ttl"`
|
||||||
|
AccessTokenTTL time.Duration `mapstructure:"access_token_ttl"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Data struct {
|
||||||
|
DataPath string `mapstructure:"data_dir"`
|
||||||
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
Functions FuncConfig `mapstructure:"functions"`
|
Functions FuncConfig `mapstructure:"functions"`
|
||||||
|
Auth Auth `mapstructure:"auth"`
|
||||||
|
Data Data `mapstructure:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var configPath atomic.Value // string
|
var configPath atomic.Value // string
|
||||||
@@ -48,7 +61,14 @@ var defaults = map[string]any{
|
|||||||
"server.block.enabled": true,
|
"server.block.enabled": true,
|
||||||
"server.block.block_dir": "./blocks",
|
"server.block.block_dir": "./blocks",
|
||||||
|
|
||||||
|
"data.data_dir": "./data",
|
||||||
|
|
||||||
"functions.func_dir": "./functions",
|
"functions.func_dir": "./functions",
|
||||||
|
|
||||||
|
"auth.refresh_token_ttl": 24 * time.Hour,
|
||||||
|
"auth.access_token_ttl": 15 * time.Minute,
|
||||||
|
"auth.sign_alg": "HS256",
|
||||||
|
"auth.hmac_secret_path": "./secret/hmac_secret",
|
||||||
}
|
}
|
||||||
|
|
||||||
func read(cfg *Config) error {
|
func read(cfg *Config) error {
|
||||||
|
|||||||
28
internal/jwt/parse.go
Normal file
28
internal/jwt/parse.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Parse(
|
||||||
|
tokenStr string,
|
||||||
|
method jwt.SigningMethod,
|
||||||
|
key any,
|
||||||
|
) (jwt.Claims, error) {
|
||||||
|
t, err := jwt.Parse(tokenStr, func(tok *jwt.Token) (any, error) {
|
||||||
|
if tok.Method.Alg() != method.Alg() {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method")
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// check validity twice: invalid token may return nil error
|
||||||
|
if !t.Valid {
|
||||||
|
return nil, fmt.Errorf("invalid token")
|
||||||
|
}
|
||||||
|
return t.Claims, nil
|
||||||
|
}
|
||||||
48
internal/jwt/service.go
Normal file
48
internal/jwt/service.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"maps"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
signer Signer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(signer Signer) *Service {
|
||||||
|
return &Service{
|
||||||
|
signer: signer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate creates a new JWT token for a given user ID and
|
||||||
|
// returns the token string along with its JTI(JWT IDentifier).
|
||||||
|
func (s *Service) Generate(ttl time.Duration, extraClaims jwt.MapClaims) (string, string, error) {
|
||||||
|
jti := uuid.NewString()
|
||||||
|
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"jti": jti,
|
||||||
|
"exp": time.Now().Add(ttl).Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
}
|
||||||
|
maps.Copy(claims, extraClaims)
|
||||||
|
|
||||||
|
token, err := s.signer.Sign(claims)
|
||||||
|
return token, jti, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate verifies the JWT token and extracts the claims and JTI(JWT IDentifier).
|
||||||
|
// Returns claims, jti, and error if any.
|
||||||
|
func (s *Service) Validate(token string) (jwt.MapClaims, string, error) {
|
||||||
|
claims, err := s.signer.Verify(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
jti := claims.(jwt.MapClaims)["jti"].(string)
|
||||||
|
|
||||||
|
return claims.(jwt.MapClaims), jti, nil
|
||||||
|
}
|
||||||
8
internal/jwt/signer.go
Normal file
8
internal/jwt/signer.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import "github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
|
type Signer interface {
|
||||||
|
Sign(claims jwt.Claims) (string, error)
|
||||||
|
Verify(token string) (jwt.Claims, error)
|
||||||
|
}
|
||||||
20
internal/jwt/signer_HS256.go
Normal file
20
internal/jwt/signer_HS256.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import "github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
|
type HMACSigner struct {
|
||||||
|
secret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHMACSigner(secret []byte) *HMACSigner {
|
||||||
|
return &HMACSigner{secret: secret}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HMACSigner) Sign(claims jwt.Claims) (string, error) {
|
||||||
|
t := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
return t.SignedString(s.secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HMACSigner) Verify(tokenStr string) (jwt.Claims, error) {
|
||||||
|
return Parse(tokenStr, jwt.SigningMethodHS256, s.secret)
|
||||||
|
}
|
||||||
7
internal/server/notimpl.go
Normal file
7
internal/server/notimpl.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
func NotImplemented(w http.ResponseWriter) {
|
||||||
|
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||||
|
}
|
||||||
41
internal/token/service.go
Normal file
41
internal/token/service.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package token
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenStore interface {
|
||||||
|
revoke(tokenID string, expiresAt time.Time) error
|
||||||
|
isRevoked(tokenID string) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
cfg *config.Auth
|
||||||
|
store TokenStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTokenService(cfg *config.Auth, store TokenStore) (*Service, error) {
|
||||||
|
if store == nil {
|
||||||
|
return nil, fmt.Errorf("store is nil")
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("config is nil")
|
||||||
|
}
|
||||||
|
return &Service{cfg: cfg, store: store}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Revoke(jti string, exp time.Time) error {
|
||||||
|
return s.store.revoke(jti, exp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) RevokeByRefreshDefault(jti string) error {
|
||||||
|
expiryTime := time.Now().Add(-time.Duration(s.cfg.RefreshTokenTTL))
|
||||||
|
return s.store.revoke(jti, expiryTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) IsRevoked(jti string) (bool, error) {
|
||||||
|
return s.store.isRevoked(jti)
|
||||||
|
}
|
||||||
45
internal/token/store_sqlite.go
Normal file
45
internal/token/store_sqlite.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package token
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SQLiteTokenStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
type Token struct {
|
||||||
|
TokenID string `gorm:"primaryKey"`
|
||||||
|
UserID int64 `gorm:"index"`
|
||||||
|
Expiration time.Time `gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSQLiteTokenStore creates a new SQLiteTokenStore with the given GORM DB instance.
|
||||||
|
// Actually can be used for any GORM-supported database.
|
||||||
|
func NewSQLiteTokenStore(db *gorm.DB) (*SQLiteTokenStore, error) {
|
||||||
|
if db == nil {
|
||||||
|
return nil, fmt.Errorf("db is nil")
|
||||||
|
}
|
||||||
|
return &SQLiteTokenStore{
|
||||||
|
db: db,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteTokenStore) revoke(tokenID string, expiresAt time.Time) error {
|
||||||
|
return s.db.Create(&Token{
|
||||||
|
TokenID: tokenID,
|
||||||
|
Expiration: expiresAt,
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteTokenStore) isRevoked(tokenID string) (bool, error) {
|
||||||
|
var count int64
|
||||||
|
err := s.db.Model(&Token{}).Where("token_id = ?", tokenID).Count(&count).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return count > 0, nil
|
||||||
|
}
|
||||||
71
internal/token/store_sqlite_test.go
Normal file
71
internal/token/store_sqlite_test.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package token
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.oblat.lv/alex/triggerssmith/internal/config"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dbPath := filepath.Join("testdata", "tokens.db")
|
||||||
|
|
||||||
|
_ = os.Remove(dbPath)
|
||||||
|
|
||||||
|
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open db: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.AutoMigrate(&Token{}); err != nil {
|
||||||
|
t.Fatalf("failed to migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteTokenStore_RevokeAndCheck(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
|
||||||
|
store, err := NewSQLiteTokenStore(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create store: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Auth{
|
||||||
|
RefreshTokenTTL: 24 * time.Hour,
|
||||||
|
}
|
||||||
|
service, err := NewTokenService(cfg, store)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jti := "test-token-123"
|
||||||
|
exp := time.Now().Add(time.Hour)
|
||||||
|
|
||||||
|
revoked, err := service.IsRevoked(jti)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("isRevoked failed: %v", err)
|
||||||
|
}
|
||||||
|
if revoked {
|
||||||
|
t.Fatalf("token should NOT be revoked initially")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := service.Revoke(jti, exp); err != nil {
|
||||||
|
t.Fatalf("revoke failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
revoked, err = service.IsRevoked(jti)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("isRevoked failed: %v", err)
|
||||||
|
}
|
||||||
|
if !revoked {
|
||||||
|
t.Fatalf("token should be revoked")
|
||||||
|
}
|
||||||
|
}
|
||||||
45
internal/user/gorm_store.go
Normal file
45
internal/user/gorm_store.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GormUserStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGormUserStore(db *gorm.DB) (*GormUserStore, error) {
|
||||||
|
if db == nil {
|
||||||
|
return nil, fmt.Errorf("db is nil")
|
||||||
|
}
|
||||||
|
return &GormUserStore{
|
||||||
|
db: db,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GormUserStore) Create(user *User) error {
|
||||||
|
return s.db.Create(user).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search returns a user by username or id or email
|
||||||
|
func (s *GormUserStore) GetBy(by, value string) (*User, error) {
|
||||||
|
if by != "username" && by != "id" && by != "email" {
|
||||||
|
return nil, fmt.Errorf("unsuppored field %s", by)
|
||||||
|
}
|
||||||
|
var user User
|
||||||
|
err := s.db.Where(fmt.Sprintf("%s = ?", by), value).First(&user).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GormUserStore) Update(user *User) error {
|
||||||
|
return s.db.Save(user).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GormUserStore) Delete(id int64) error {
|
||||||
|
return s.db.Delete(&User{}, id).Error
|
||||||
|
}
|
||||||
11
internal/user/model.go
Normal file
11
internal/user/model.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import "gorm.io/gorm"
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID int64 `gorm:"primaryKey"`
|
||||||
|
Username string `gorm:"uniqueIndex;not null"`
|
||||||
|
Email string `gorm:"uniqueIndex;not null"`
|
||||||
|
Password string `gorm:"not null"`
|
||||||
|
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||||
|
}
|
||||||
32
internal/user/service.go
Normal file
32
internal/user/service.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
store UserCRUD
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(store UserCRUD) (*Service, error) {
|
||||||
|
if store == nil {
|
||||||
|
return nil, fmt.Errorf("store is nil")
|
||||||
|
}
|
||||||
|
return &Service{
|
||||||
|
store: store,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Create(user *User) error {
|
||||||
|
return s.store.Create(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) GetBy(by, value string) (*User, error) {
|
||||||
|
return s.store.GetBy(by, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Update(user *User) error {
|
||||||
|
return s.store.Update(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Delete(id int64) error {
|
||||||
|
return s.store.Delete(id)
|
||||||
|
}
|
||||||
8
internal/user/store.go
Normal file
8
internal/user/store.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
type UserCRUD interface {
|
||||||
|
Create(user *User) error
|
||||||
|
GetBy(by, value string) (*User, error)
|
||||||
|
Update(user *User) error
|
||||||
|
Delete(id int64) error
|
||||||
|
}
|
||||||
84
internal/user/user_test.go
Normal file
84
internal/user/user_test.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dbPath := filepath.Join("testdata", "users.db")
|
||||||
|
|
||||||
|
_ = os.Remove(dbPath)
|
||||||
|
|
||||||
|
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open db: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.AutoMigrate(&User{}); err != nil {
|
||||||
|
t.Fatalf("failed to migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsersCRUD(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
|
||||||
|
store, err := NewGormUserStore(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create store: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := NewService(store)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &User{
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Password: "password123",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := service.Create(user); err != nil {
|
||||||
|
t.Fatalf("failed to create user: %v", err)
|
||||||
|
}
|
||||||
|
// retrieved, err := service.GetByID(user.ID)
|
||||||
|
// if err != nil {
|
||||||
|
// t.Fatalf("failed to get user by ID: %v", err)
|
||||||
|
// }
|
||||||
|
// if retrieved.Username != user.Username {
|
||||||
|
// t.Fatalf("expected username %s, got %s", user.Username, retrieved.Username)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// retrievedByUsername, err := service.GetByUsername(user.Username)
|
||||||
|
// if err != nil {
|
||||||
|
// t.Fatalf("failed to get user by username: %v", err)
|
||||||
|
// }
|
||||||
|
// if retrievedByUsername.Email != user.Email {
|
||||||
|
// t.Fatalf("expected email %s, got %s", user.Email, retrievedByUsername.Email)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// user.Email = "newemail@example.com"
|
||||||
|
// if err := service.Update(user); err != nil {
|
||||||
|
// t.Fatalf("failed to update user: %v", err)
|
||||||
|
// }
|
||||||
|
// retrieved, err = service.GetByID(user.ID)
|
||||||
|
// if err != nil {
|
||||||
|
// t.Fatalf("failed to get user by ID: %v", err)
|
||||||
|
// }
|
||||||
|
// if retrieved.Email != user.Email {
|
||||||
|
// t.Fatalf("expected email %s, got %s", user.Email, retrieved.Email)
|
||||||
|
// }
|
||||||
|
err = service.Delete(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to delete user: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user