From fdc1921f5dcb5ac6beef8d1c9b1b7c53f514cce5 Mon Sep 17 00:00:00 2001 From: Elias Schneider Date: Mon, 19 Aug 2024 18:48:18 +0200 Subject: [PATCH] feat: add user info endpoint to support more oidc clients --- .../internal/bootstrap/router_bootstrap.go | 2 +- .../internal/controller/oidc_controller.go | 66 +++++++++++---- .../controller/well_known_controller.go | 1 + backend/internal/model/oidc.go | 1 + backend/internal/service/jwt_service.go | 81 +++++++++++-------- backend/internal/service/oidc_service.go | 70 ++++++++++++---- 6 files changed, 155 insertions(+), 66 deletions(-) diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index 008cdc60..ac54bd68 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -47,7 +47,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { // Set up API routes apiGroup := r.Group("/api") controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService) - controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService) + controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService) controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService) controller.NewApplicationConfigurationController(apiGroup, jwtAuthMiddleware, appConfigService) diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index c0ea358a..6896f1c6 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -10,14 +10,16 @@ import ( "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" "strconv" + "strings" ) -func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService) { - oc := &OidcController{OidcService: oidcService} +func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService, jwtService *service.JwtService) { + oc := &OidcController{oidcService: oidcService, jwtService: jwtService} group.POST("/oidc/authorize", jwtAuthMiddleware.Add(false), oc.authorizeHandler) group.POST("/oidc/authorize/new-client", jwtAuthMiddleware.Add(false), oc.authorizeNewClientHandler) group.POST("/oidc/token", oc.createIDTokenHandler) + group.GET("/oidc/userinfo", oc.userInfoHandler) group.GET("/oidc/clients", jwtAuthMiddleware.Add(true), oc.listClientsHandler) group.POST("/oidc/clients", jwtAuthMiddleware.Add(true), oc.createClientHandler) @@ -33,7 +35,8 @@ func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt } type OidcController struct { - OidcService *service.OidcService + oidcService *service.OidcService + jwtService *service.JwtService } func (oc *OidcController) authorizeHandler(c *gin.Context) { @@ -43,7 +46,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) { return } - code, err := oc.OidcService.Authorize(parsedBody, c.GetString("userID")) + code, err := oc.oidcService.Authorize(parsedBody, c.GetString("userID")) if err != nil { if errors.Is(err, common.ErrOidcMissingAuthorization) { utils.HandlerError(c, http.StatusForbidden, err.Error()) @@ -63,7 +66,7 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) { return } - code, err := oc.OidcService.AuthorizeNewClient(parsedBody, c.GetString("userID")) + code, err := oc.oidcService.AuthorizeNewClient(parsedBody, c.GetString("userID")) if err != nil { utils.UnknownHandlerError(c, err) return @@ -80,7 +83,20 @@ func (oc *OidcController) createIDTokenHandler(c *gin.Context) { return } - idToken, err := oc.OidcService.CreateIDToken(body) + clientID := body.ClientID + clientSecret := body.ClientSecret + + // Client id and secret can also be passed over the Authorization header + if clientID == "" || clientSecret == "" { + var ok bool + clientID, clientSecret, ok = c.Request.BasicAuth() + if !ok { + utils.HandlerError(c, http.StatusBadRequest, "Client id and secret not provided") + return + } + } + + idToken, accessToken, err := oc.oidcService.CreateTokens(body.Code, body.GrantType, clientID, clientSecret) if err != nil { if errors.Is(err, common.ErrOidcGrantTypeNotSupported) || errors.Is(err, common.ErrOidcMissingClientCredentials) || @@ -93,12 +109,30 @@ func (oc *OidcController) createIDTokenHandler(c *gin.Context) { return } - c.JSON(http.StatusOK, gin.H{"id_token": idToken}) + c.JSON(http.StatusOK, gin.H{"id_token": idToken, "access_token": accessToken, "token_type": "Bearer"}) +} + +func (oc *OidcController) userInfoHandler(c *gin.Context) { + token := strings.Split(c.GetHeader("Authorization"), " ")[1] + jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token) + if err != nil { + utils.HandlerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error()) + return + } + userID := jwtClaims.Subject + clientId := jwtClaims.Audience[0] + claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusOK, claims) } func (oc *OidcController) getClientHandler(c *gin.Context) { clientId := c.Param("id") - client, err := oc.OidcService.GetClient(clientId) + client, err := oc.oidcService.GetClient(clientId) if err != nil { utils.UnknownHandlerError(c, err) return @@ -112,7 +146,7 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) { pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10")) searchTerm := c.Query("search") - clients, pagination, err := oc.OidcService.ListClients(searchTerm, page, pageSize) + clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize) if err != nil { utils.UnknownHandlerError(c, err) return @@ -131,7 +165,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) { return } - client, err := oc.OidcService.CreateClient(input, c.GetString("userID")) + client, err := oc.oidcService.CreateClient(input, c.GetString("userID")) if err != nil { utils.UnknownHandlerError(c, err) return @@ -141,7 +175,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) { } func (oc *OidcController) deleteClientHandler(c *gin.Context) { - err := oc.OidcService.DeleteClient(c.Param("id")) + err := oc.oidcService.DeleteClient(c.Param("id")) if err != nil { utils.HandlerError(c, http.StatusNotFound, "OIDC client not found") return @@ -157,7 +191,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) { return } - client, err := oc.OidcService.UpdateClient(c.Param("id"), input) + client, err := oc.oidcService.UpdateClient(c.Param("id"), input) if err != nil { utils.UnknownHandlerError(c, err) return @@ -167,7 +201,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) { } func (oc *OidcController) createClientSecretHandler(c *gin.Context) { - secret, err := oc.OidcService.CreateClientSecret(c.Param("id")) + secret, err := oc.oidcService.CreateClientSecret(c.Param("id")) if err != nil { utils.UnknownHandlerError(c, err) return @@ -177,7 +211,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) { } func (oc *OidcController) getClientLogoHandler(c *gin.Context) { - imagePath, mimeType, err := oc.OidcService.GetClientLogo(c.Param("id")) + imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id")) if err != nil { utils.UnknownHandlerError(c, err) return @@ -194,7 +228,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { return } - err = oc.OidcService.UpdateClientLogo(c.Param("id"), file) + err = oc.oidcService.UpdateClientLogo(c.Param("id"), file) if err != nil { if errors.Is(err, common.ErrFileTypeNotSupported) { utils.HandlerError(c, http.StatusBadRequest, err.Error()) @@ -208,7 +242,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { } func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) { - err := oc.OidcService.DeleteClientLogo(c.Param("id")) + err := oc.oidcService.DeleteClientLogo(c.Param("id")) if err != nil { utils.UnknownHandlerError(c, err) return diff --git a/backend/internal/controller/well_known_controller.go b/backend/internal/controller/well_known_controller.go index 17fb0391..efe1d8f8 100644 --- a/backend/internal/controller/well_known_controller.go +++ b/backend/internal/controller/well_known_controller.go @@ -34,6 +34,7 @@ func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) { "issuer": appUrl, "authorization_endpoint": appUrl + "/authorize", "token_endpoint": appUrl + "/api/oidc/token", + "userinfo_endpoint": appUrl + "/api/oidc/userinfo", "jwks_uri": appUrl + "/.well-known/jwks.json", "scopes_supported": []string{"openid", "profile", "email"}, "claims_supported": []string{"sub", "given_name", "family_name", "email", "preferred_username"}, diff --git a/backend/internal/model/oidc.go b/backend/internal/model/oidc.go index 2b5e45ee..96f29bf5 100644 --- a/backend/internal/model/oidc.go +++ b/backend/internal/model/oidc.go @@ -8,6 +8,7 @@ import ( type UserAuthorizedOidcClient struct { Scope string UserID string `json:"userId" gorm:"primary_key;"` + User User ClientID string `json:"clientId" gorm:"primary_key;"` Client OidcClient diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go index c47508dd..8abff14a 100644 --- a/backend/internal/service/jwt_service.go +++ b/backend/internal/service/jwt_service.go @@ -18,7 +18,6 @@ import ( "path/filepath" "slices" "strconv" - "strings" "time" ) @@ -54,7 +53,6 @@ type AccessTokenJWTClaims struct { type JWK struct { Kty string `json:"kty"` Use string `json:"use"` - Kid string `json:"kid"` Alg string `json:"alg"` N string `json:"n"` E string `json:"e"` @@ -89,37 +87,6 @@ func (s *JwtService) loadOrGenerateKeys() error { return nil } -func (s *JwtService) GenerateIDToken(user model.User, clientID string, scope string, nonce string) (string, error) { - profileClaims := map[string]interface{}{ - "given_name": user.FirstName, - "family_name": user.LastName, - "email": user.Email, - "preferred_username": user.Username, - } - - claims := jwt.MapClaims{ - "sub": user.ID, - "aud": clientID, - "exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), - "iat": jwt.NewNumericDate(time.Now()), - } - - if nonce != "" { - claims["nonce"] = nonce - } - if strings.Contains(scope, "profile") { - for k, v := range profileClaims { - claims[k] = v - } - } - if strings.Contains(scope, "email") { - claims["email"] = user.Email - } - - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - return token.SignedString(s.privateKey) -} - func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { sessionDurationInMinutes, _ := strconv.Atoi(s.appConfigService.DbConfig.SessionDuration.Value) claim := AccessTokenJWTClaims{ @@ -154,6 +121,53 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaim return claims, nil } +func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID string, nonce string) (string, error) { + claims := jwt.MapClaims{ + "aud": clientID, + "exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + "iat": jwt.NewNumericDate(time.Now()), + "iss": common.EnvConfig.AppURL, + } + + for k, v := range userClaims { + claims[k] = v + } + + if nonce != "" { + claims["nonce"] = nonce + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + return token.SignedString(s.privateKey) +} +func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) { + claim := jwt.RegisteredClaims{ + Subject: user.ID, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Audience: jwt.ClaimStrings{clientID}, + Issuer: common.EnvConfig.AppURL, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim) + return token.SignedString(s.privateKey) +} + +func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { + return s.publicKey, nil + }) + if err != nil || !token.Valid { + return nil, errors.New("couldn't handle this token") + } + + claims, isValid := token.Claims.(*jwt.RegisteredClaims) + if !isValid { + return nil, errors.New("can't parse claims") + } + + return claims, nil +} + // GetJWK returns the JSON Web Key (JWK) for the public key. func (s *JwtService) GetJWK() (JWK, error) { if s.publicKey == nil { @@ -163,7 +177,6 @@ func (s *JwtService) GetJWK() (JWK, error) { jwk := JWK{ Kty: "RSA", Use: "sig", - Kid: "1", Alg: "RS256", N: base64.RawURLEncoding.EncodeToString(s.publicKey.N.Bytes()), E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.publicKey.E)).Bytes()), diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index ba7c66bc..7a02ce3c 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -10,6 +10,7 @@ import ( "gorm.io/gorm" "mime/multipart" "os" + "strings" "time" ) @@ -54,46 +55,50 @@ func (s *OidcService) AuthorizeNewClient(req model.AuthorizeNewClientDto, userID return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce) } -func (s *OidcService) CreateIDToken(req model.OidcIdTokenDto) (string, error) { - if req.GrantType != "authorization_code" { - return "", common.ErrOidcGrantTypeNotSupported +func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) { + if grantType != "authorization_code" { + return "", "", common.ErrOidcGrantTypeNotSupported } - clientID := req.ClientID - clientSecret := req.ClientSecret - if clientID == "" || clientSecret == "" { - return "", common.ErrOidcMissingClientCredentials + return "", "", common.ErrOidcMissingClientCredentials } var client model.OidcClient if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { - return "", err + return "", "", err } err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) if err != nil { - return "", common.ErrOidcClientSecretInvalid + return "", "", common.ErrOidcClientSecretInvalid } var authorizationCodeMetaData model.OidcAuthorizationCode - err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", req.Code).Error + err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error if err != nil { - return "", common.ErrOidcInvalidAuthorizationCode + return "", "", common.ErrOidcInvalidAuthorizationCode } if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.Before(time.Now()) { - return "", common.ErrOidcInvalidAuthorizationCode + return "", "", common.ErrOidcInvalidAuthorizationCode } - idToken, err := s.jwtService.GenerateIDToken(authorizationCodeMetaData.User, clientID, authorizationCodeMetaData.Scope, authorizationCodeMetaData.Nonce) + userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID) if err != nil { - return "", err + return "", "", err } + idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce) + if err != nil { + return "", "", err + } + + accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID) + s.db.Delete(&authorizationCodeMetaData) - return idToken, nil + return idToken, accessToken, nil } func (s *OidcService) GetClient(clientID string) (*model.OidcClient, error) { @@ -259,6 +264,41 @@ func (s *OidcService) DeleteClientLogo(clientID string) error { return nil } +func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (map[string]interface{}, error) { + var authorizedOidcClient model.UserAuthorizedOidcClient + if err := s.db.Preload("User").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil { + return nil, err + } + + user := authorizedOidcClient.User + scope := authorizedOidcClient.Scope + + claims := map[string]interface{}{ + "sub": user.ID, + } + + if strings.Contains(scope, "email") { + claims["email"] = user.Email + } + + profileClaims := map[string]interface{}{ + "given_name": user.FirstName, + "family_name": user.LastName, + "preferred_username": user.Username, + } + + if strings.Contains(scope, "profile") { + for k, v := range profileClaims { + claims[k] = v + } + } + if strings.Contains(scope, "email") { + claims["email"] = user.Email + } + + return claims, nil +} + func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) { randomString, err := utils.GenerateRandomAlphanumericString(32) if err != nil {