diff --git a/internal/auth/service.go b/internal/auth/service.go index 93dc043..aa8285d 100644 --- a/internal/auth/service.go +++ b/internal/auth/service.go @@ -1,6 +1,7 @@ package auth import ( + "errors" "fmt" "net/http" "strings" @@ -8,7 +9,7 @@ import ( "git.oblat.lv/alex/triggerssmith/internal/config" "git.oblat.lv/alex/triggerssmith/internal/jwt" "git.oblat.lv/alex/triggerssmith/internal/token" - "git.oblat.lv/alex/triggerssmith/internal/user" + user_p "git.oblat.lv/alex/triggerssmith/internal/user" ejwt "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" ) @@ -23,7 +24,7 @@ type Service struct { services struct { jwt *jwt.Service - user *user.Service + user *user_p.Service token *token.Service } } @@ -32,7 +33,7 @@ type AuthServiceDependencies struct { Configuration *config.Config JWTService *jwt.Service - UserService *user.Service + UserService *user_p.Service TokenService *token.Service } @@ -53,7 +54,7 @@ func NewAuthService(deps AuthServiceDependencies) (*Service, error) { cfg: deps.Configuration, services: struct { jwt *jwt.Service - user *user.Service + user *user_p.Service token *token.Service }{ jwt: deps.JWTService, @@ -65,20 +66,20 @@ func NewAuthService(deps AuthServiceDependencies) (*Service, error) { // Users -func (s *Service) Get(by, value string) (*user.User, error) { +func (s *Service) Get(by, value string) (*user_p.User, error) { return s.services.user.GetBy(by, value) } // Register creates a new user with the given username, email, and password. // Password is hashed before storing. // Returns the created user or an error. -func (s *Service) Register(username, email, password string) (*user.User, error) { +func (s *Service) Register(username, email, password string) (*user_p.User, error) { hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return nil, fmt.Errorf("failed to hash password: %w", err) } - user := &user.User{ + user := &user_p.User{ Username: username, Email: email, Password: string(hashedPassword), @@ -97,12 +98,15 @@ func (s *Service) Register(username, email, password string) (*user.User, error) func (s *Service) Login(username, password string) (*Tokens, error) { user, err := s.services.user.GetBy("username", username) if err != nil { + if err == user_p.ErrUserNotFound { + return nil, ErrInvalidUsername + } return nil, fmt.Errorf("failed to get user by username: %w", err) } err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) if err != nil { - return nil, fmt.Errorf("invalid password: %w", err) + return nil, ErrInvalidPassword } refreshToken, rjti, err := s.services.jwt.Generate(s.cfg.Auth.RefreshTokenTTL, ejwt.MapClaims{ "sub": user.ID, @@ -122,39 +126,46 @@ func (s *Service) Login(username, password string) (*Tokens, error) { // Logout revokes the refresh token identified by the given rjti. func (s *Service) Logout(rjti string) error { - return s.services.token.RevokeByRefreshDefault(rjti) + err := s.services.token.RevokeByRefreshDefault(rjti) + if err != nil { + if errors.Is(err, token.ErrTokenIsRevoked) { + return ErrInvalidToken + } + return fmt.Errorf("failed to revoke token: %w", err) + } + return nil } // Access tokens // ValidateAccessToken validates the given access token string. -// Returns the user ID (sub claim) if valid, or an error. -func (s *Service) ValidateAccessToken(tokenStr string) (int64, error) { +// Returns claims if valid, or an error. +func (s *Service) ValidateAccessToken(tokenStr string) (ejwt.Claims, error) { claims, _, err := s.services.jwt.Validate(tokenStr) if err != nil { - return 0, fmt.Errorf("failed to validate access token: %w", err) + return nil, fmt.Errorf("failed to validate access token: %w", err) } isRevoked, err := s.services.token.IsRevoked(claims["rjti"].(string)) if err != nil { - return 0, fmt.Errorf("failed to check if token is revoked: %w", err) + return nil, fmt.Errorf("failed to check if token is revoked: %w", err) } if isRevoked { - return 0, fmt.Errorf("token is revoked") + return nil, fmt.Errorf("token is revoked") } - sub := claims["sub"].(float64) - return int64(sub), nil + return claims, nil } // Refresh tokens // RefreshTokens validates the given refresh token and issues new access and refresh tokens. // Returns the new access and refresh tokens or an error. +// May return [ErrInvalidToken] if the refresh token is invalid or revoked. func (s *Service) RefreshTokens(refreshTokenStr string) (*Tokens, error) { claims, rjti, err := s.services.jwt.Validate(refreshTokenStr) if err != nil { - return nil, fmt.Errorf("failed to validate refresh token: %w", err) + return nil, errors.Join(ErrInvalidToken, err) } isRevoked, err := s.services.token.IsRevoked(rjti) @@ -162,7 +173,7 @@ func (s *Service) RefreshTokens(refreshTokenStr string) (*Tokens, error) { return nil, fmt.Errorf("failed to check if token is revoked: %w", err) } if isRevoked { - return nil, fmt.Errorf("refresh token is revoked") + return nil, ErrInvalidToken } sub := claims["sub"].(float64) @@ -190,23 +201,22 @@ func (s *Service) RefreshTokens(refreshTokenStr string) (*Tokens, error) { } // ValidateRefreshToken validates the given refresh token string. -// Returns user id and error. -func (s *Service) ValidateRefreshToken(tokenStr string) (int64, error) { +// Returns claims and error. +func (s *Service) ValidateRefreshToken(tokenStr string) (ejwt.Claims, error) { claims, _, err := s.services.jwt.Validate(tokenStr) if err != nil { - return 0, fmt.Errorf("failed to validate refresh token: %w", err) + return nil, fmt.Errorf("failed to validate refresh token: %w", err) } isRevoked, err := s.services.token.IsRevoked(claims["jti"].(string)) if err != nil { - return 0, fmt.Errorf("failed to check if token is revoked: %w", err) + return nil, fmt.Errorf("failed to check if token is revoked: %w", err) } if isRevoked { - return 0, fmt.Errorf("refresh token is revoked") + return nil, fmt.Errorf("refresh token is revoked") } - sub := claims["sub"].(float64) - return int64(sub), nil + return claims, nil } // RevokeRefresh revokes the refresh token identified by the given token string. @@ -232,10 +242,10 @@ func (s *Service) IsRefreshRevoked(token string) (bool, error) { func (s *Service) AuthenticateRequest(r *http.Request) (ejwt.Claims, error) { header := r.Header.Get("Authorization") if header == "" { - return nil, fmt.Errorf("token is missing") + return nil, ErrTokenIsMissing } if !strings.HasPrefix(header, "Bearer ") { - return nil, fmt.Errorf("token is missing") + return nil, ErrTokenIsMissing } tokenString := strings.TrimPrefix(header, "Bearer ") tokenClaims, _, err := s.services.jwt.Validate(tokenString)