mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-09 23:02:59 +03:00
feat: add user info endpoint to support more oidc clients
This commit is contained in:
@@ -47,7 +47,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
|
||||
// Set up API routes
|
||||
apiGroup := r.Group("/api")
|
||||
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService)
|
||||
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService)
|
||||
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService)
|
||||
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService)
|
||||
controller.NewApplicationConfigurationController(apiGroup, jwtAuthMiddleware, appConfigService)
|
||||
|
||||
|
||||
@@ -10,14 +10,16 @@ import (
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService) {
|
||||
oc := &OidcController{OidcService: oidcService}
|
||||
func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService, jwtService *service.JwtService) {
|
||||
oc := &OidcController{oidcService: oidcService, jwtService: jwtService}
|
||||
|
||||
group.POST("/oidc/authorize", jwtAuthMiddleware.Add(false), oc.authorizeHandler)
|
||||
group.POST("/oidc/authorize/new-client", jwtAuthMiddleware.Add(false), oc.authorizeNewClientHandler)
|
||||
group.POST("/oidc/token", oc.createIDTokenHandler)
|
||||
group.GET("/oidc/userinfo", oc.userInfoHandler)
|
||||
|
||||
group.GET("/oidc/clients", jwtAuthMiddleware.Add(true), oc.listClientsHandler)
|
||||
group.POST("/oidc/clients", jwtAuthMiddleware.Add(true), oc.createClientHandler)
|
||||
@@ -33,7 +35,8 @@ func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt
|
||||
}
|
||||
|
||||
type OidcController struct {
|
||||
OidcService *service.OidcService
|
||||
oidcService *service.OidcService
|
||||
jwtService *service.JwtService
|
||||
}
|
||||
|
||||
func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
@@ -43,7 +46,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
code, err := oc.OidcService.Authorize(parsedBody, c.GetString("userID"))
|
||||
code, err := oc.oidcService.Authorize(parsedBody, c.GetString("userID"))
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrOidcMissingAuthorization) {
|
||||
utils.HandlerError(c, http.StatusForbidden, err.Error())
|
||||
@@ -63,7 +66,7 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
code, err := oc.OidcService.AuthorizeNewClient(parsedBody, c.GetString("userID"))
|
||||
code, err := oc.oidcService.AuthorizeNewClient(parsedBody, c.GetString("userID"))
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
@@ -80,7 +83,20 @@ func (oc *OidcController) createIDTokenHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := oc.OidcService.CreateIDToken(body)
|
||||
clientID := body.ClientID
|
||||
clientSecret := body.ClientSecret
|
||||
|
||||
// Client id and secret can also be passed over the Authorization header
|
||||
if clientID == "" || clientSecret == "" {
|
||||
var ok bool
|
||||
clientID, clientSecret, ok = c.Request.BasicAuth()
|
||||
if !ok {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "Client id and secret not provided")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
idToken, accessToken, err := oc.oidcService.CreateTokens(body.Code, body.GrantType, clientID, clientSecret)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrOidcGrantTypeNotSupported) ||
|
||||
errors.Is(err, common.ErrOidcMissingClientCredentials) ||
|
||||
@@ -93,12 +109,30 @@ func (oc *OidcController) createIDTokenHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"id_token": idToken})
|
||||
c.JSON(http.StatusOK, gin.H{"id_token": idToken, "access_token": accessToken, "token_type": "Bearer"})
|
||||
}
|
||||
|
||||
func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
||||
token := strings.Split(c.GetHeader("Authorization"), " ")[1]
|
||||
jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token)
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error())
|
||||
return
|
||||
}
|
||||
userID := jwtClaims.Subject
|
||||
clientId := jwtClaims.Audience[0]
|
||||
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, claims)
|
||||
}
|
||||
|
||||
func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||
clientId := c.Param("id")
|
||||
client, err := oc.OidcService.GetClient(clientId)
|
||||
client, err := oc.oidcService.GetClient(clientId)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
@@ -112,7 +146,7 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
|
||||
searchTerm := c.Query("search")
|
||||
|
||||
clients, pagination, err := oc.OidcService.ListClients(searchTerm, page, pageSize)
|
||||
clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
@@ -131,7 +165,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
client, err := oc.OidcService.CreateClient(input, c.GetString("userID"))
|
||||
client, err := oc.oidcService.CreateClient(input, c.GetString("userID"))
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
@@ -141,7 +175,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
|
||||
err := oc.OidcService.DeleteClient(c.Param("id"))
|
||||
err := oc.oidcService.DeleteClient(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
||||
return
|
||||
@@ -157,7 +191,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
client, err := oc.OidcService.UpdateClient(c.Param("id"), input)
|
||||
client, err := oc.oidcService.UpdateClient(c.Param("id"), input)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
@@ -167,7 +201,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
||||
secret, err := oc.OidcService.CreateClientSecret(c.Param("id"))
|
||||
secret, err := oc.oidcService.CreateClientSecret(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
@@ -177,7 +211,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
||||
imagePath, mimeType, err := oc.OidcService.GetClientLogo(c.Param("id"))
|
||||
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
@@ -194,7 +228,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err = oc.OidcService.UpdateClientLogo(c.Param("id"), file)
|
||||
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrFileTypeNotSupported) {
|
||||
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
||||
@@ -208,7 +242,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
||||
err := oc.OidcService.DeleteClientLogo(c.Param("id"))
|
||||
err := oc.oidcService.DeleteClientLogo(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
|
||||
@@ -34,6 +34,7 @@ func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) {
|
||||
"issuer": appUrl,
|
||||
"authorization_endpoint": appUrl + "/authorize",
|
||||
"token_endpoint": appUrl + "/api/oidc/token",
|
||||
"userinfo_endpoint": appUrl + "/api/oidc/userinfo",
|
||||
"jwks_uri": appUrl + "/.well-known/jwks.json",
|
||||
"scopes_supported": []string{"openid", "profile", "email"},
|
||||
"claims_supported": []string{"sub", "given_name", "family_name", "email", "preferred_username"},
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
type UserAuthorizedOidcClient struct {
|
||||
Scope string
|
||||
UserID string `json:"userId" gorm:"primary_key;"`
|
||||
User User
|
||||
|
||||
ClientID string `json:"clientId" gorm:"primary_key;"`
|
||||
Client OidcClient
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -54,7 +53,6 @@ type AccessTokenJWTClaims struct {
|
||||
type JWK struct {
|
||||
Kty string `json:"kty"`
|
||||
Use string `json:"use"`
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
@@ -89,37 +87,6 @@ func (s *JwtService) loadOrGenerateKeys() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateIDToken(user model.User, clientID string, scope string, nonce string) (string, error) {
|
||||
profileClaims := map[string]interface{}{
|
||||
"given_name": user.FirstName,
|
||||
"family_name": user.LastName,
|
||||
"email": user.Email,
|
||||
"preferred_username": user.Username,
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"sub": user.ID,
|
||||
"aud": clientID,
|
||||
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
"iat": jwt.NewNumericDate(time.Now()),
|
||||
}
|
||||
|
||||
if nonce != "" {
|
||||
claims["nonce"] = nonce
|
||||
}
|
||||
if strings.Contains(scope, "profile") {
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
}
|
||||
if strings.Contains(scope, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
return token.SignedString(s.privateKey)
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||
sessionDurationInMinutes, _ := strconv.Atoi(s.appConfigService.DbConfig.SessionDuration.Value)
|
||||
claim := AccessTokenJWTClaims{
|
||||
@@ -154,6 +121,53 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaim
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID string, nonce string) (string, error) {
|
||||
claims := jwt.MapClaims{
|
||||
"aud": clientID,
|
||||
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
"iat": jwt.NewNumericDate(time.Now()),
|
||||
"iss": common.EnvConfig.AppURL,
|
||||
}
|
||||
|
||||
for k, v := range userClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
if nonce != "" {
|
||||
claims["nonce"] = nonce
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
return token.SignedString(s.privateKey)
|
||||
}
|
||||
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
||||
claim := jwt.RegisteredClaims{
|
||||
Subject: user.ID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Audience: jwt.ClaimStrings{clientID},
|
||||
Issuer: common.EnvConfig.AppURL,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||
return token.SignedString(s.privateKey)
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return s.publicKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
}
|
||||
|
||||
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !isValid {
|
||||
return nil, errors.New("can't parse claims")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// GetJWK returns the JSON Web Key (JWK) for the public key.
|
||||
func (s *JwtService) GetJWK() (JWK, error) {
|
||||
if s.publicKey == nil {
|
||||
@@ -163,7 +177,6 @@ func (s *JwtService) GetJWK() (JWK, error) {
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Kid: "1",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.publicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.publicKey.E)).Bytes()),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -54,46 +55,50 @@ func (s *OidcService) AuthorizeNewClient(req model.AuthorizeNewClientDto, userID
|
||||
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateIDToken(req model.OidcIdTokenDto) (string, error) {
|
||||
if req.GrantType != "authorization_code" {
|
||||
return "", common.ErrOidcGrantTypeNotSupported
|
||||
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) {
|
||||
if grantType != "authorization_code" {
|
||||
return "", "", common.ErrOidcGrantTypeNotSupported
|
||||
}
|
||||
|
||||
clientID := req.ClientID
|
||||
clientSecret := req.ClientSecret
|
||||
|
||||
if clientID == "" || clientSecret == "" {
|
||||
return "", common.ErrOidcMissingClientCredentials
|
||||
return "", "", common.ErrOidcMissingClientCredentials
|
||||
}
|
||||
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return "", err
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
if err != nil {
|
||||
return "", common.ErrOidcClientSecretInvalid
|
||||
return "", "", common.ErrOidcClientSecretInvalid
|
||||
}
|
||||
|
||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", req.Code).Error
|
||||
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
|
||||
if err != nil {
|
||||
return "", common.ErrOidcInvalidAuthorizationCode
|
||||
return "", "", common.ErrOidcInvalidAuthorizationCode
|
||||
}
|
||||
|
||||
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.Before(time.Now()) {
|
||||
return "", common.ErrOidcInvalidAuthorizationCode
|
||||
return "", "", common.ErrOidcInvalidAuthorizationCode
|
||||
}
|
||||
|
||||
idToken, err := s.jwtService.GenerateIDToken(authorizationCodeMetaData.User, clientID, authorizationCodeMetaData.Scope, authorizationCodeMetaData.Nonce)
|
||||
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
|
||||
|
||||
s.db.Delete(&authorizationCodeMetaData)
|
||||
|
||||
return idToken, nil
|
||||
return idToken, accessToken, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClient(clientID string) (*model.OidcClient, error) {
|
||||
@@ -259,6 +264,41 @@ func (s *OidcService) DeleteClientLogo(clientID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (map[string]interface{}, error) {
|
||||
var authorizedOidcClient model.UserAuthorizedOidcClient
|
||||
if err := s.db.Preload("User").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := authorizedOidcClient.User
|
||||
scope := authorizedOidcClient.Scope
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": user.ID,
|
||||
}
|
||||
|
||||
if strings.Contains(scope, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
profileClaims := map[string]interface{}{
|
||||
"given_name": user.FirstName,
|
||||
"family_name": user.LastName,
|
||||
"preferred_username": user.Username,
|
||||
}
|
||||
|
||||
if strings.Contains(scope, "profile") {
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
}
|
||||
if strings.Contains(scope, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(32)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user