diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 510b073d..023e8fcb 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -136,13 +136,13 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { } // Validate that code is provided for authorization_code grant type - if input.GrantType == "authorization_code" && input.Code == "" { + if input.GrantType == service.GrantTypeAuthorizationCode && input.Code == "" { _ = c.Error(&common.OidcMissingAuthorizationCodeError{}) return } // Validate that refresh_token is provided for refresh_token grant type - if input.GrantType == "refresh_token" && input.RefreshToken == "" { + if input.GrantType == service.GrantTypeRefreshToken && input.RefreshToken == "" { _ = c.Error(&common.OidcMissingRefreshTokenError{}) return } @@ -152,8 +152,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth() } - idToken, accessToken, refreshToken, expiresIn, err := - oc.oidcService.CreateTokens(c.Request.Context(), input) + tokens, err := oc.oidcService.CreateTokens(c.Request.Context(), input) switch { case errors.Is(err, &common.OidcAuthorizationPendingError{}): @@ -171,23 +170,13 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { return } - response := dto.OidcTokenResponseDto{ - AccessToken: accessToken, - TokenType: "Bearer", - ExpiresIn: expiresIn, - } - - // Include ID token only for authorization_code grant - if idToken != "" { - response.IdToken = idToken - } - - // Include refresh token if generated - if refreshToken != "" { - response.RefreshToken = refreshToken - } - - c.JSON(http.StatusOK, response) + c.JSON(http.StatusOK, dto.OidcTokenResponseDto{ + AccessToken: tokens.AccessToken, + TokenType: "Bearer", + ExpiresIn: int(tokens.ExpiresIn.Seconds()), + IdToken: tokens.IdToken, // May be empty + RefreshToken: tokens.RefreshToken, // May be empty + }) } // userInfoHandler godoc diff --git a/backend/internal/controller/well_known_controller.go b/backend/internal/controller/well_known_controller.go index 8c75d02b..6c45a8b4 100644 --- a/backend/internal/controller/well_known_controller.go +++ b/backend/internal/controller/well_known_controller.go @@ -77,7 +77,7 @@ func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) { "introspection_endpoint": appUrl + "/api/oidc/introspect", "device_authorization_endpoint": appUrl + "/api/oidc/device/authorize", "jwks_uri": appUrl + "/.well-known/jwks.json", - "grant_types_supported": []string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code"}, + "grant_types_supported": []string{service.GrantTypeAuthorizationCode, service.GrantTypeRefreshToken, service.GrantTypeDeviceCode}, "scopes_supported": []string{"openid", "profile", "email", "groups"}, "claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"}, "response_types_supported": []string{"code", "id_token"}, diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 2bd11f6d..1dea7785 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -15,17 +15,22 @@ import ( "strings" "time" - "gorm.io/gorm/clause" - "github.com/lestrrat-go/jwx/v3/jwt" + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" + "gorm.io/gorm/clause" "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/model" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" "github.com/pocket-id/pocket-id/backend/internal/utils" - "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" +) + +const ( + GrantTypeAuthorizationCode = "authorization_code" + GrantTypeRefreshToken = "refresh_token" + GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code" ) type OidcService struct { @@ -167,139 +172,158 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode return isAllowedToAuthorize } -func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (idToken string, accessToken string, newRefreshToken string, exp int, err error) { +type CreatedTokens struct { + IdToken string + AccessToken string + RefreshToken string + ExpiresIn time.Duration +} + +func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) { switch input.GrantType { - case "authorization_code": - return s.createTokenFromAuthorizationCode(ctx, input.Code, input.ClientID, input.ClientSecret, input.CodeVerifier) - case "refresh_token": - accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, input.RefreshToken, input.ClientID, input.ClientSecret) - return "", accessToken, newRefreshToken, exp, err - case "urn:ietf:params:oauth:grant-type:device_code": - return s.createTokenFromDeviceCode(ctx, input.DeviceCode, input.ClientID, input.ClientSecret) + case GrantTypeAuthorizationCode: + return s.createTokenFromAuthorizationCode(ctx, input) + case GrantTypeRefreshToken: + return s.createTokenFromRefreshToken(ctx, input) + case GrantTypeDeviceCode: + return s.createTokenFromDeviceCode(ctx, input) default: - return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{} + return CreatedTokens{}, &common.OidcGrantTypeNotSupportedError{} } } -func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, deviceCode, clientID string, clientSecret string) (idToken string, accessToken string, refreshToken string, exp int, err error) { +func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) { tx := s.db.Begin() defer func() { tx.Rollback() }() - _, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx) + _, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx) if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } // Get the device authorization from database with explicit query conditions var deviceAuth model.OidcDeviceCode - if err := tx.WithContext(ctx).Preload("User").Where("device_code = ? AND client_id = ?", deviceCode, clientID).First(&deviceAuth).Error; err != nil { + err = tx. + WithContext(ctx). + Preload("User"). + Where("device_code = ? AND client_id = ?", input.DeviceCode, input.ClientID). + First(&deviceAuth). + Error + if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return "", "", "", 0, &common.OidcInvalidDeviceCodeError{} + return CreatedTokens{}, &common.OidcInvalidDeviceCodeError{} } - return "", "", "", 0, err + return CreatedTokens{}, err } // Check if device code has expired if time.Now().After(deviceAuth.ExpiresAt.ToTime()) { - return "", "", "", 0, &common.OidcDeviceCodeExpiredError{} + return CreatedTokens{}, &common.OidcDeviceCodeExpiredError{} } // Check if device code has been authorized if !deviceAuth.IsAuthorized || deviceAuth.UserID == nil { - return "", "", "", 0, &common.OidcAuthorizationPendingError{} + return CreatedTokens{}, &common.OidcAuthorizationPendingError{} } // Get user claims for the ID token - ensure UserID is not nil if deviceAuth.UserID == nil { - return "", "", "", 0, &common.OidcAuthorizationPendingError{} + return CreatedTokens{}, &common.OidcAuthorizationPendingError{} } - userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, clientID, tx) + userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, input.ClientID, tx) if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } // Explicitly use the input clientID for the audience claim to ensure consistency - idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, "") + idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "") if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } - refreshToken, err = s.createRefreshToken(ctx, clientID, *deviceAuth.UserID, deviceAuth.Scope, tx) + refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, tx) if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } - accessToken, err = s.jwtService.GenerateOauthAccessToken(deviceAuth.User, clientID) + accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID) if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } // Delete the used device code - if err := tx.WithContext(ctx).Delete(&deviceAuth).Error; err != nil { - return "", "", "", 0, err + err = tx.WithContext(ctx).Delete(&deviceAuth).Error + if err != nil { + return CreatedTokens{}, err } - if err := tx.Commit().Error; err != nil { - return "", "", "", 0, err + err = tx.Commit().Error + if err != nil { + return CreatedTokens{}, err } - return idToken, accessToken, refreshToken, 3600, nil + return CreatedTokens{ + IdToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + ExpiresIn: time.Hour, + }, nil } -func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) { +func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) { tx := s.db.Begin() defer func() { tx.Rollback() }() - client, err := s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx) + client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx) if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } var authorizationCodeMetaData model.OidcAuthorizationCode err = tx. WithContext(ctx). Preload("User"). - First(&authorizationCodeMetaData, "code = ?", code). + First(&authorizationCodeMetaData, "code = ?", input.Code). Error if err != nil { - return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{} + return CreatedTokens{}, &common.OidcInvalidAuthorizationCodeError{} } // If the client is public or PKCE is enabled, the code verifier must match the code challenge if client.IsPublic || client.PkceEnabled { - if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) { - return "", "", "", 0, &common.OidcInvalidCodeVerifierError{} + if !s.validateCodeVerifier(input.CodeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) { + return CreatedTokens{}, &common.OidcInvalidCodeVerifierError{} } } - if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) { - return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{} + if authorizationCodeMetaData.ClientID != input.ClientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) { + return CreatedTokens{}, &common.OidcInvalidAuthorizationCodeError{} } - userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, clientID, tx) + userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, input.ClientID, tx) if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } - idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce) + idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce) if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } // Generate a refresh token - refreshToken, err = s.createRefreshToken(ctx, clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx) + refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx) if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } - accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID) + accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID) if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } err = tx. @@ -307,20 +331,25 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code Delete(&authorizationCodeMetaData). Error if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } err = tx.Commit().Error if err != nil { - return "", "", "", 0, err + return CreatedTokens{}, err } - return idToken, accessToken, refreshToken, 3600, nil + return CreatedTokens{ + IdToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + ExpiresIn: time.Hour, + }, nil } -func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) { - if refreshToken == "" { - return "", "", 0, &common.OidcMissingRefreshTokenError{} +func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) { + if input.RefreshToken == "" { + return CreatedTokens{}, &common.OidcMissingRefreshTokenError{} } tx := s.db.Begin() @@ -328,9 +357,9 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo tx.Rollback() }() - _, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx) + _, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx) if err != nil { - return "", "", 0, err + return CreatedTokens{}, err } // Verify refresh token @@ -338,31 +367,31 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo err = tx. WithContext(ctx). Preload("User"). - Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())). + Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())). First(&storedRefreshToken). Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return "", "", 0, &common.OidcInvalidRefreshTokenError{} + return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{} } - return "", "", 0, err + return CreatedTokens{}, err } // Verify that the refresh token belongs to the provided client - if storedRefreshToken.ClientID != clientID { - return "", "", 0, &common.OidcInvalidRefreshTokenError{} + if storedRefreshToken.ClientID != input.ClientID { + return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{} } // Generate a new access token - accessToken, err = s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID) + accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID) if err != nil { - return "", "", 0, err + return CreatedTokens{}, err } // Generate a new refresh token and invalidate the old one - newRefreshToken, err = s.createRefreshToken(ctx, clientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx) + newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx) if err != nil { - return "", "", 0, err + return CreatedTokens{}, err } // Delete the used refresh token @@ -371,15 +400,19 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo Delete(&storedRefreshToken). Error if err != nil { - return "", "", 0, err + return CreatedTokens{}, err } err = tx.Commit().Error if err != nil { - return "", "", 0, err + return CreatedTokens{}, err } - return accessToken, newRefreshToken, 3600, nil + return CreatedTokens{ + AccessToken: accessToken, + RefreshToken: newRefreshToken, + ExpiresIn: time.Hour, + }, nil } func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { @@ -1181,9 +1214,12 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID } func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) { + // First, ensure we have a valid client ID if clientID == "" { return model.OidcClient{}, &common.OidcMissingClientCredentialsError{} } + + // Load the OIDC client's configuration var client model.OidcClient err := tx. WithContext(ctx). @@ -1193,10 +1229,16 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clien return model.OidcClient{}, err } - if !client.IsPublic { - if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil { + // If we have a client secret, we validate it + // Otherwise, we require the client to be public + if clientSecret != "" { + err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) + if err != nil { return model.OidcClient{}, &common.OidcClientSecretInvalidError{} } + return client, nil + } else if !client.IsPublic { + return model.OidcClient{}, &common.OidcMissingClientCredentialsError{} } return client, nil