feat: JWT bearer assertions for client authentication (#566)

Co-authored-by: Kyle Mendell <ksm@ofkm.us>
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Alessandro (Ale) Segala
2025-06-06 03:23:51 -07:00
committed by GitHub
parent 035b2c022b
commit 05bfe00924
38 changed files with 1464 additions and 293 deletions

1
.gitignore vendored
View File

@@ -10,6 +10,7 @@ node_modules
/frontend/build
/backend/bin
pocket-id
/tests/test-results/*.json
# OS
.DS_Store

View File

@@ -20,7 +20,8 @@ require (
github.com/google/uuid v1.6.0
github.com/hashicorp/go-uuid v1.0.3
github.com/joho/godotenv v1.5.1
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2
github.com/lestrrat-go/jwx/v3 v3.0.1
github.com/mileusna/useragent v1.3.5
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
github.com/stretchr/testify v1.10.0
@@ -32,7 +33,7 @@ require (
go.opentelemetry.io/otel/sdk v1.35.0
go.opentelemetry.io/otel/sdk/metric v1.35.0
go.opentelemetry.io/otel/trace v1.35.0
golang.org/x/crypto v0.36.0
golang.org/x/crypto v0.37.0
golang.org/x/image v0.24.0
golang.org/x/time v0.9.0
gorm.io/driver/postgres v1.5.11
@@ -77,9 +78,8 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/blackmagic v1.0.3 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
@@ -123,7 +123,7 @@ require (
golang.org/x/net v0.38.0 // indirect
golang.org/x/sync v0.14.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.23.0 // indirect
golang.org/x/text v0.24.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect
google.golang.org/grpc v1.71.0 // indirect

View File

@@ -164,14 +164,14 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtnSINtDWs=
github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms=
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1 h1:Iqjb8JvWjh34Jv8DeM2wQ1aG5fzFBzwQu7rlqwuJB0I=
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2 h1:SDxjGoH7qj0nBXVrcrxX8eD94wEnjR+EEuqqmeqQYlY=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2/go.mod h1:Nwo81sMxE0DcvTB+rJyynNhv/DUu2yZErV7sscw9pHE=
github.com/lestrrat-go/jwx/v3 v3.0.1 h1:fH3T748FCMbXoF9UXXNS9i0q6PpYyJZK/rKSbkt2guY=
github.com/lestrrat-go/jwx/v3 v3.0.1/go.mod h1:XP2WqxMOSzHSyf3pfibCcfsLqbomxakAnNqiuaH8nwo=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@@ -309,8 +309,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI=
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -377,8 +377,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -3,6 +3,8 @@
package bootstrap
import (
"log"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -14,7 +16,12 @@ import (
func init() {
registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services){
func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) {
testService := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
if err != nil {
log.Fatalf("failed to initialize test service: %v", err)
return
}
controller.NewTestController(apiGroup, testService)
},
}

View File

@@ -26,15 +26,14 @@ type services struct {
}
// Initializes all services
// The context should be used by services only for initialization, and not for running
func initServices(initCtx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
svc = &services{}
svc.appConfigService = service.NewAppConfigService(initCtx, db)
svc.appConfigService = service.NewAppConfigService(ctx, db)
svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
if err != nil {
return nil, fmt.Errorf("unable to create email service: %w", err)
return nil, fmt.Errorf("failed to create email service: %w", err)
}
svc.geoLiteService = service.NewGeoLiteService(httpClient)
@@ -42,7 +41,12 @@ func initServices(initCtx context.Context, db *gorm.DB, httpClient *http.Client)
svc.jwtService = service.NewJwtService(svc.appConfigService)
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService)
svc.customClaimService = service.NewCustomClaimService(db)
svc.oidcService = service.NewOidcService(db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService)
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService)
if err != nil {
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
}
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)

View File

@@ -65,6 +65,11 @@ type OidcClientSecretInvalidError struct{}
func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" }
func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 }
type OidcClientAssertionInvalidError struct{}
func (e *OidcClientAssertionInvalidError) Error() string { return "invalid client assertion" }
func (e *OidcClientAssertionInvalidError) HttpStatusCode() int { return 400 }
type OidcInvalidAuthorizationCodeError struct{}
func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" }

View File

@@ -14,6 +14,9 @@ func NewTestController(group *gin.RouterGroup, testService *service.TestService)
testController := &TestController{TestService: testService}
group.POST("/test/reset", testController.resetAndSeedHandler)
group.GET("/externalidp/jwks.json", testController.externalIdPJWKS)
group.POST("/externalidp/sign", testController.externalIdPSignToken)
}
type TestController struct {
@@ -21,6 +24,15 @@ type TestController struct {
}
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
var baseURL string
if c.Request.TLS != nil {
baseURL = "https://" + c.Request.Host
} else {
baseURL = "http://" + c.Request.Host
}
skipLdap := c.Query("skip-ldap") == "true"
if err := tc.TestService.ResetDatabase(); err != nil {
_ = c.Error(err)
return
@@ -31,7 +43,7 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
return
}
if err := tc.TestService.SeedDatabase(); err != nil {
if err := tc.TestService.SeedDatabase(baseURL); err != nil {
_ = c.Error(err)
return
}
@@ -41,17 +53,50 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
return
}
if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil {
_ = c.Error(err)
return
}
if !skipLdap {
if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil {
_ = c.Error(err)
return
}
if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil {
_ = c.Error(err)
return
if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil {
_ = c.Error(err)
return
}
}
tc.TestService.SetJWTKeys()
c.Status(http.StatusNoContent)
}
func (tc *TestController) externalIdPJWKS(c *gin.Context) {
jwks, err := tc.TestService.GetExternalIdPJWKS()
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, jwks)
}
func (tc *TestController) externalIdPSignToken(c *gin.Context) {
var input struct {
Aud string `json:"aud"`
Iss string `json:"iss"`
Sub string `json:"sub"`
}
err := c.ShouldBindJSON(&input)
if err != nil {
_ = c.Error(err)
return
}
token, err := tc.TestService.SignExternalIdPToken(input.Iss, input.Sub, input.Aud)
if err != nil {
_ = c.Error(err)
return
}
c.Writer.WriteString(token)
}

View File

@@ -7,14 +7,14 @@ import (
"net/url"
"strings"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
)
// NewOidcController creates a new controller for OIDC related endpoints
@@ -124,11 +124,13 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex
// @Tags OIDC
// @Produce json
// @Param client_id formData string false "Client ID (if not using Basic Auth)"
// @Param client_secret formData string false "Client secret (if not using Basic Auth)"
// @Param client_secret formData string false "Client secret (if not using Basic Auth or client assertions)"
// @Param code formData string false "Authorization code (required for 'authorization_code' grant)"
// @Param grant_type formData string true "Grant type ('authorization_code' or 'refresh_token')"
// @Param code_verifier formData string false "PKCE code verifier (for authorization_code with PKCE)"
// @Param refresh_token formData string false "Refresh token (required for 'refresh_token' grant)"
// @Param client_assertion formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)"
// @Param client_assertion_type formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)"
// @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token"
// @Router /api/oidc/token [post]
func (oc *OidcController) createTokensHandler(c *gin.Context) {
@@ -363,12 +365,12 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
clientDto := dto.OidcClientWithAllowedUserGroupsDto{}
err = dto.MapStruct(client, &clientDto)
if err == nil {
c.JSON(http.StatusOK, clientDto)
if err != nil {
_ = c.Error(err)
return
}
_ = c.Error(err)
c.JSON(http.StatusOK, clientDto)
}
// listClientsHandler godoc

View File

@@ -62,7 +62,60 @@ func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
return nil
}
//nolint:gocognit
func mapField(sourceField reflect.Value, destField reflect.Value) error {
// Handle pointer to struct in source
if sourceField.Kind() == reflect.Ptr && !sourceField.IsNil() {
switch {
case sourceField.Elem().Kind() == reflect.Struct:
switch {
case destField.Kind() == reflect.Struct:
// Map from pointer to struct -> struct
return mapStructInternal(sourceField.Elem(), destField)
case destField.Kind() == reflect.Ptr && destField.CanSet():
// Map from pointer to struct -> pointer to struct
if destField.IsNil() {
destField.Set(reflect.New(destField.Type().Elem()))
}
return mapStructInternal(sourceField.Elem(), destField.Elem())
}
case destField.Kind() == reflect.Ptr &&
destField.CanSet() &&
sourceField.Elem().Type().AssignableTo(destField.Type().Elem()):
// Handle primitive pointer types (e.g., *string to *string)
if destField.IsNil() {
destField.Set(reflect.New(destField.Type().Elem()))
}
destField.Elem().Set(sourceField.Elem())
return nil
case destField.Kind() != reflect.Ptr &&
destField.CanSet() &&
sourceField.Elem().Type().AssignableTo(destField.Type()):
// Handle *T to T conversion for primitive types
destField.Set(sourceField.Elem())
return nil
}
}
// Handle pointer to struct in destination
if destField.Kind() == reflect.Ptr && destField.CanSet() {
switch {
case sourceField.Kind() == reflect.Struct:
// Map from struct -> pointer to struct
if destField.IsNil() {
destField.Set(reflect.New(destField.Type().Elem()))
}
return mapStructInternal(sourceField, destField.Elem())
case !sourceField.IsZero() && sourceField.Type().AssignableTo(destField.Type().Elem()):
// Handle T to *T conversion for primitive types
if destField.IsNil() {
destField.Set(reflect.New(destField.Type().Elem()))
}
destField.Elem().Set(sourceField)
return nil
}
}
switch {
case sourceField.Type() == destField.Type():
destField.Set(sourceField)

View File

@@ -8,10 +8,11 @@ type OidcClientMetaDataDto struct {
type OidcClientDto struct {
OidcClientMetaDataDto
CallbackURLs []string `json:"callbackURLs"`
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
CallbackURLs []string `json:"callbackURLs"`
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
Credentials OidcClientCredentialsDto `json:"credentials"`
}
type OidcClientWithAllowedUserGroupsDto struct {
@@ -25,11 +26,23 @@ type OidcClientWithAllowedGroupsCountDto struct {
}
type OidcClientCreateDto struct {
Name string `json:"name" binding:"required,max=50"`
CallbackURLs []string `json:"callbackURLs"`
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
Name string `json:"name" binding:"required,max=50"`
CallbackURLs []string `json:"callbackURLs"`
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
Credentials OidcClientCredentialsDto `json:"credentials"`
}
type OidcClientCredentialsDto struct {
FederatedIdentities []OidcClientFederatedIdentityDto `json:"federatedIdentities,omitempty"`
}
type OidcClientFederatedIdentityDto struct {
Issuer string `json:"issuer"`
Subject string `json:"subject,omitempty"`
Audience string `json:"audience,omitempty"`
JWKS string `json:"jwks,omitempty"`
}
type AuthorizeOidcClientRequestDto struct {
@@ -52,13 +65,15 @@ type AuthorizationRequiredDto struct {
}
type OidcCreateTokensDto struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code"`
DeviceCode string `form:"device_code"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"`
RefreshToken string `form:"refresh_token"`
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code"`
DeviceCode string `form:"device_code"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"`
RefreshToken string `form:"refresh_token"`
ClientAssertion string `form:"client_assertion"`
ClientAssertionType string `form:"client_assertion_type"`
}
type OidcIntrospectDto struct {

View File

@@ -5,8 +5,9 @@ import (
"encoding/json"
"fmt"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"gorm.io/gorm"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
type UserAuthorizedOidcClient struct {
@@ -45,6 +46,7 @@ type OidcClient struct {
HasLogo bool `gorm:"-"`
IsPublic bool
PkceEnabled bool
Credentials OidcClientCredentials
AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
CreatedByID string
@@ -71,9 +73,49 @@ func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
return nil
}
type OidcClientCredentials struct { //nolint:recvcheck
FederatedIdentities []OidcClientFederatedIdentity `json:"federatedIdentities,omitempty"`
}
type OidcClientFederatedIdentity struct {
Issuer string `json:"issuer"`
Subject string `json:"subject,omitempty"`
Audience string `json:"audience,omitempty"`
JWKS string `json:"jwks,omitempty"` // URL of the JWKS
}
func (occ OidcClientCredentials) FederatedIdentityForIssuer(issuer string) (OidcClientFederatedIdentity, bool) {
if issuer == "" {
return OidcClientFederatedIdentity{}, false
}
for _, fi := range occ.FederatedIdentities {
if fi.Issuer == issuer {
return fi, true
}
}
return OidcClientFederatedIdentity{}, false
}
func (occ *OidcClientCredentials) Scan(value any) error {
switch v := value.(type) {
case []byte:
return json.Unmarshal(v, occ)
case string:
return json.Unmarshal([]byte(v), occ)
default:
return fmt.Errorf("unsupported type: %T", value)
}
}
func (occ OidcClientCredentials) Value() (driver.Value, error) {
return json.Marshal(occ)
}
type UrlList []string //nolint:recvcheck
func (cu *UrlList) Scan(value interface{}) error {
func (cu *UrlList) Scan(value any) error {
switch v := value.(type) {
case []byte:
return json.Unmarshal(v, cu)

View File

@@ -29,17 +29,17 @@ type AppConfigService struct {
db *gorm.DB
}
func NewAppConfigService(initCtx context.Context, db *gorm.DB) *AppConfigService {
func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(initCtx)
err := service.LoadDbConfig(ctx)
if err != nil {
log.Fatalf("Failed to initialize app config service: %v", err)
}
err = service.initInstanceID(initCtx)
err = service.initInstanceID(ctx)
if err != nil {
log.Fatalf("Failed to initialize instance ID: %v", err)
}

View File

@@ -3,16 +3,10 @@ package service
import (
"sync/atomic"
"testing"
"time"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/stretchr/testify/require"
)
@@ -28,7 +22,7 @@ func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
func TestLoadDbConfig(t *testing.T) {
t.Run("empty config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
@@ -42,7 +36,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("loads value from config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
// Populate the config table with some initial values
err := db.
@@ -72,7 +66,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("ignores unknown config keys", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
// Add an entry with a key that doesn't exist in the config struct
err := db.Create([]model.AppConfigVariable{
@@ -93,7 +87,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("loading config multiple times", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
// Initial state
err := db.Create([]model.AppConfigVariable{
@@ -135,7 +129,7 @@ func TestLoadDbConfig(t *testing.T) {
common.EnvConfig.UiConfigDisabled = true
// Create database with config that should be ignored
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
@@ -171,7 +165,7 @@ func TestLoadDbConfig(t *testing.T) {
common.EnvConfig.UiConfigDisabled = false
// Create database with config values that should take precedence
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
@@ -195,7 +189,7 @@ func TestLoadDbConfig(t *testing.T) {
func TestUpdateAppConfigValues(t *testing.T) {
t.Run("update single value", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -220,7 +214,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("update multiple values", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -264,7 +258,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("empty value resets to default", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -285,7 +279,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("error with odd number of arguments", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -301,7 +295,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("error with invalid key", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -319,7 +313,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
func TestUpdateAppConfig(t *testing.T) {
t.Run("updates configuration values from DTO", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -392,7 +386,7 @@ func TestUpdateAppConfig(t *testing.T) {
})
t.Run("empty values reset to defaults", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
// Create a service with default config and modify some values
service := &AppConfigService{
@@ -457,7 +451,7 @@ func TestUpdateAppConfig(t *testing.T) {
// Disable UI config
common.EnvConfig.UiConfigDisabled = true
db := newAppConfigTestDatabaseForTest(t)
db := newDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
@@ -475,49 +469,3 @@ func TestUpdateAppConfig(t *testing.T) {
require.ErrorAs(t, err, &uiConfigDisabledErr)
})
}
// Implements gorm's logger.Writer interface
type testLoggerAdapter struct {
t *testing.T
}
func (l testLoggerAdapter) Printf(format string, args ...any) {
l.t.Logf(format, args...)
}
func newAppConfigTestDatabaseForTest(t *testing.T) *gorm.DB {
t.Helper()
// Get a name for this in-memory database that is specific to the test
dbName := utils.CreateSha256Hash(t.Name())
// Connect to a new in-memory SQL database
db, err := gorm.Open(
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
&gorm.Config{
TranslateError: true,
Logger: logger.New(
testLoggerAdapter{t: t},
logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: false,
ParameterizedQueries: false,
Colorful: false,
},
),
})
require.NoError(t, err, "Failed to connect to test database")
// Create the app_config_variables table
err = db.Exec(`
CREATE TABLE app_config_variables
(
key VARCHAR(100) NOT NULL PRIMARY KEY,
value TEXT NOT NULL
)
`).Error
require.NoError(t, err, "Failed to create test config table")
return db
}

View File

@@ -5,6 +5,8 @@ package service
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/base64"
"fmt"
@@ -16,6 +18,7 @@ import (
"github.com/fxamacker/cbor/v2"
"github.com/go-webauthn/webauthn/protocol"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
@@ -30,14 +33,43 @@ type TestService struct {
jwtService *JwtService
appConfigService *AppConfigService
ldapService *LdapService
externalIdPKey jwk.Key
}
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) *TestService {
return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService, ldapService: ldapService}
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) (*TestService, error) {
s := &TestService{
db: db,
appConfigService: appConfigService,
jwtService: jwtService,
ldapService: ldapService,
}
err := s.initExternalIdP()
if err != nil {
return nil, fmt.Errorf("failed to initialize external IdP: %w", err)
}
return s, nil
}
// Initializes the "external IdP"
// This creates a new "issuing authority" containing a public JWKS
// It also stores the private key internally that will be used to issue JWTs
func (s *TestService) initExternalIdP() error {
// Generate a new ECDSA key
rawKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return fmt.Errorf("failed to generate private key: %w", err)
}
s.externalIdPKey, err = utils.ImportRawKey(rawKey)
if err != nil {
return fmt.Errorf("failed to import private key: %w", err)
}
return nil
}
//nolint:gocognit
func (s *TestService) SeedDatabase() error {
func (s *TestService) SeedDatabase(baseURL string) error {
err := s.db.Transaction(func(tx *gorm.DB) error {
users := []model.User{
{
@@ -138,6 +170,26 @@ func (s *TestService) SeedDatabase() error {
userGroups[1],
},
},
{
Base: model.Base{
ID: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
},
Name: "Federated",
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
CallbackURLs: model.UrlList{"http://federated/auth/callback"},
CreatedByID: users[1].ID,
AllowedUserGroups: []model.UserGroup{},
Credentials: model.OidcClientCredentials{
FederatedIdentities: []model.OidcClientFederatedIdentity{
{
Issuer: "https://external-idp.local",
Audience: "api://PocketID",
Subject: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
JWKS: baseURL + "/api/externalidp/jwks.json",
},
},
},
},
}
for _, client := range oidcClients {
if err := tx.Create(&client).Error; err != nil {
@@ -145,16 +197,28 @@ func (s *TestService) SeedDatabase() error {
}
}
authCode := model.OidcAuthorizationCode{
Code: "auth-code",
Scope: "openid profile",
Nonce: "nonce",
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
authCodes := []model.OidcAuthorizationCode{
{
Code: "auth-code",
Scope: "openid profile",
Nonce: "nonce",
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
},
{
Code: "federated",
Scope: "openid profile",
Nonce: "nonce",
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
UserID: users[1].ID,
ClientID: oidcClients[2].ID,
},
}
if err := tx.Create(&authCode).Error; err != nil {
return err
for _, authCode := range authCodes {
if err := tx.Create(&authCode).Error; err != nil {
return err
}
}
refreshToken := model.OidcRefreshToken{
@@ -177,13 +241,22 @@ func (s *TestService) SeedDatabase() error {
return err
}
userAuthorizedClient := model.UserAuthorizedOidcClient{
Scope: "openid profile email",
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
userAuthorizedClients := []model.UserAuthorizedOidcClient{
{
Scope: "openid profile email",
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
},
{
Scope: "openid profile email",
UserID: users[1].ID,
ClientID: oidcClients[2].ID,
},
}
if err := tx.Create(&userAuthorizedClient).Error; err != nil {
return err
for _, userAuthorizedClient := range userAuthorizedClients {
if err := tx.Create(&userAuthorizedClient).Error; err != nil {
return err
}
}
// To generate a new key pair, run the following command:
@@ -405,3 +478,41 @@ func (s *TestService) SetLdapTestConfig(ctx context.Context) error {
return nil
}
// GetExternalIdPJWKS returns the JWKS for the "external IdP".
func (s *TestService) GetExternalIdPJWKS() (jwk.Set, error) {
pubKey, err := s.externalIdPKey.PublicKey()
if err != nil {
return nil, fmt.Errorf("failed to get public key: %w", err)
}
set := jwk.NewSet()
err = set.AddKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to add public key to set: %w", err)
}
return set, nil
}
func (s *TestService) SignExternalIdPToken(iss, sub, aud string) (string, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Subject(sub).
Expiration(now.Add(time.Hour)).
IssuedAt(now).
Issuer(iss).
Audience([]string{aud}).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
}
alg, _ := s.externalIdPKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.externalIdPKey))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return string(signed), nil
}

View File

@@ -4,11 +4,9 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
@@ -372,7 +370,7 @@ func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
return nil, fmt.Errorf("failed to get public key: %w", err)
}
EnsureAlgInKey(pubKey)
utils.EnsureAlgInKey(pubKey)
return pubKey, nil
}
@@ -415,27 +413,6 @@ func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
return key, nil
}
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
func EnsureAlgInKey(key jwk.Key) {
_, ok := key.Algorithm()
if ok {
// Algorithm is already set
return
}
switch key.KeyType() {
case jwa.RSA():
// Default to RS256 for RSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
case jwa.EC():
// Default to ES256 for ECDSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
case jwa.OKP():
// Default to EdDSA for OKP keys
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
}
}
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
// We generate RSA keys only
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
@@ -444,27 +421,7 @@ func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
}
// Import the raw key
return importRawKey(rawKey)
}
func importRawKey(rawKey any) (jwk.Key, error) {
key, err := jwk.Import(rawKey)
if err != nil {
return nil, fmt.Errorf("failed to import generated private key: %w", err)
}
// Generate the key ID
kid, err := generateRandomKeyID()
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
_ = key.Set(jwk.KeyIDKey, kid)
// Set other required fields
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
EnsureAlgInKey(key)
return key, err
return utils.ImportRawKey(rawKey)
}
// SaveKeyJWK saves a JWK to a file
@@ -492,16 +449,6 @@ func SaveKeyJWK(key jwk.Key, path string) error {
return nil
}
// generateRandomKeyID generates a random key ID.
func generateRandomKeyID() (string, error) {
buf := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("failed to read random bytes: %w", err)
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
// GetIsAdmin returns the value of the "isAdmin" claim in the token
func GetIsAdmin(token jwt.Token) (bool, error) {
if !token.Has(IsAdminClaim) {

View File

@@ -21,6 +21,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
func TestJwtService_Init(t *testing.T) {
@@ -1218,7 +1219,7 @@ func TestTokenTypeValidator(t *testing.T) {
func importKey(t *testing.T, privateKeyRaw any, path string) string {
t.Helper()
privateKey, err := importRawKey(privateKeyRaw)
privateKey, err := utils.ImportRawKey(privateKeyRaw)
require.NoError(t, err, "Failed to import private key")
err = SaveKeyJWK(privateKey, filepath.Join(path, PrivateKeyFile))

View File

@@ -3,18 +3,25 @@ package service
import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"log/slog"
"mime/multipart"
"net/http"
"os"
"regexp"
"slices"
"strings"
"time"
"github.com/lestrrat-go/httprc/v3"
"github.com/lestrrat-go/httprc/v3/errsink"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jws"
"github.com/lestrrat-go/jwx/v3/jwt"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
@@ -31,6 +38,8 @@ const (
GrantTypeAuthorizationCode = "authorization_code"
GrantTypeRefreshToken = "refresh_token"
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
)
type OidcService struct {
@@ -39,16 +48,61 @@ type OidcService struct {
appConfigService *AppConfigService
auditLogService *AuditLogService
customClaimService *CustomClaimService
httpClient *http.Client
jwkCache *jwk.Cache
}
func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService, customClaimService *CustomClaimService) *OidcService {
return &OidcService{
func NewOidcService(
ctx context.Context,
db *gorm.DB,
jwtService *JwtService,
appConfigService *AppConfigService,
auditLogService *AuditLogService,
customClaimService *CustomClaimService,
) (s *OidcService, err error) {
s = &OidcService{
db: db,
jwtService: jwtService,
appConfigService: appConfigService,
auditLogService: auditLogService,
customClaimService: customClaimService,
}
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
s.jwkCache, err = s.getJWKCache(ctx)
if err != nil {
return nil, err
}
return s, nil
}
func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) {
// We need to create a custom HTTP client to set a timeout.
client := s.httpClient
if client == nil {
client = &http.Client{
Timeout: 20 * time.Second,
}
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
// Indicates a development-time error
panic("Default transport is not of type *http.Transport")
}
transport := defaultTransport.Clone()
transport.TLSClientConfig.MinVersion = tls.VersionTLS12
client.Transport = transport
}
// Create the JWKS cache
return jwk.NewCache(ctx,
httprc.NewClient(
httprc.WithErrorSink(errsink.NewSlog(slog.Default())),
httprc.WithHTTPClient(client),
),
)
}
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
@@ -198,7 +252,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
tx.Rollback()
}()
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
if err != nil {
return CreatedTokens{}, err
}
@@ -279,7 +333,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
tx.Rollback()
}()
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
client, err := s.verifyClientCredentialsInternal(ctx, tx, input)
if err != nil {
return CreatedTokens{}, err
}
@@ -357,7 +411,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
tx.Rollback()
}()
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
if err != nil {
return CreatedTokens{}, err
}
@@ -420,7 +474,10 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, s.db)
_, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
ClientID: clientID,
ClientSecret: clientSecret,
})
if err != nil {
return introspectDto, err
}
@@ -440,33 +497,35 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
introspectDto.Active = true
introspectDto.TokenType = "access_token"
if token.Has("scope") {
var asString string
var asStrings []string
var (
asString string
asStrings []string
)
if err := token.Get("scope", &asString); err == nil {
introspectDto.Scope = asString
} else if err := token.Get("scope", &asStrings); err == nil {
introspectDto.Scope = strings.Join(asStrings, " ")
}
}
if expiration, hasExpiration := token.Expiration(); hasExpiration {
if expiration, ok := token.Expiration(); ok {
introspectDto.Expiration = expiration.Unix()
}
if issuedAt, hasIssuedAt := token.IssuedAt(); hasIssuedAt {
if issuedAt, ok := token.IssuedAt(); ok {
introspectDto.IssuedAt = issuedAt.Unix()
}
if notBefore, hasNotBefore := token.NotBefore(); hasNotBefore {
if notBefore, ok := token.NotBefore(); ok {
introspectDto.NotBefore = notBefore.Unix()
}
if subject, hasSubject := token.Subject(); hasSubject {
if subject, ok := token.Subject(); ok {
introspectDto.Subject = subject
}
if audience, hasAudience := token.Audience(); hasAudience {
if audience, ok := token.Audience(); ok {
introspectDto.Audience = audience
}
if issuer, hasIssuer := token.Issuer(); hasIssuer {
if issuer, ok := token.Issuer(); ok {
introspectDto.Issuer = issuer
}
if identifier, hasIdentifier := token.JwtID(); hasIdentifier {
if identifier, ok := token.JwtID(); ok {
introspectDto.Identifier = identifier
}
@@ -542,13 +601,9 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
client := model.OidcClient{
Name: input.Name,
CallbackURLs: input.CallbackURLs,
LogoutCallbackURLs: input.LogoutCallbackURLs,
CreatedByID: userID,
IsPublic: input.IsPublic,
PkceEnabled: input.PkceEnabled,
CreatedByID: userID,
}
updateOIDCClientModelFromDto(&client, &input)
err := s.db.
WithContext(ctx).
@@ -577,11 +632,7 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
return model.OidcClient{}, err
}
client.Name = input.Name
client.CallbackURLs = input.CallbackURLs
client.LogoutCallbackURLs = input.LogoutCallbackURLs
client.IsPublic = input.IsPublic
client.PkceEnabled = input.IsPublic || input.PkceEnabled
updateOIDCClientModelFromDto(&client, &input)
err = tx.
WithContext(ctx).
@@ -599,6 +650,29 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
return client, nil
}
func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientCreateDto) {
// Base fields
client.Name = input.Name
client.CallbackURLs = input.CallbackURLs
client.LogoutCallbackURLs = input.LogoutCallbackURLs
client.IsPublic = input.IsPublic
// PKCE is required for public clients
client.PkceEnabled = input.IsPublic || input.PkceEnabled
// Credentials
if len(input.Credentials.FederatedIdentities) > 0 {
client.Credentials.FederatedIdentities = make([]model.OidcClientFederatedIdentity, len(input.Credentials.FederatedIdentities))
for i, fi := range input.Credentials.FederatedIdentities {
client.Credentials.FederatedIdentities[i] = model.OidcClientFederatedIdentity{
Issuer: fi.Issuer,
Audience: fi.Audience,
Subject: fi.Subject,
JWKS: fi.JWKS,
}
}
}
}
func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
var client model.OidcClient
err := s.db.
@@ -1079,7 +1153,10 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.
}
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, s.db)
client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
ClientID: input.ClientID,
ClientSecret: input.ClientSecret,
})
if err != nil {
return nil, err
}
@@ -1305,33 +1382,140 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
return err
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) {
// First, ensure we have a valid client ID
if clientID == "" {
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
if input.ClientID == "" {
return nil, &common.OidcMissingClientCredentialsError{}
}
// Load the OIDC client's configuration
var client model.OidcClient
err := tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
First(&client, "id = ?", input.ClientID).
Error
if err != nil {
return model.OidcClient{}, err
return nil, err
}
// If we have a client secret, we validate it
// Otherwise, we require the client to be public
if clientSecret != "" {
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
// We have 3 options
// If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only
switch {
// First, if we have a client secret, we validate it
case input.ClientSecret != "":
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(input.ClientSecret))
if err != nil {
return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
return nil, &common.OidcClientSecretInvalidError{}
}
return &client, nil
// Next, check if we want to use client assertions from federated identities
case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "":
err = s.verifyClientAssertionFromFederatedIdentities(ctx, &client, input)
if err != nil {
log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
return nil, &common.OidcClientAssertionInvalidError{}
}
return &client, nil
// There's no credentials
// This is allowed only if the client is public
case client.IsPublic:
return &client, nil
// If we're here, we have no credentials AND the client is not public, so credentials are required
default:
return nil, &common.OidcMissingClientCredentialsError{}
}
}
func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set, err error) {
// Check if we have already registered the URL
if !s.jwkCache.IsRegistered(ctx, url) {
// We set a timeout because otherwise Register will keep trying in case of errors
registerCtx, registerCancel := context.WithTimeout(ctx, 15*time.Second)
defer registerCancel()
// We need to register the URL
err = s.jwkCache.Register(
registerCtx,
url,
jwk.WithMaxInterval(24*time.Hour),
jwk.WithMinInterval(15*time.Minute),
jwk.WithWaitReady(true),
)
// In case of race conditions (two goroutines calling jwkCache.Register at the same time), it's possible we can get a conflict anyways, so we ignore that error
if err != nil && !errors.Is(err, httprc.ErrResourceAlreadyExists()) {
return nil, fmt.Errorf("failed to register JWK set: %w", err)
}
return client, nil
} else if !client.IsPublic {
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
}
return client, nil
jwks, err := s.jwkCache.CachedSet(url)
if err != nil {
return nil, fmt.Errorf("failed to get cached JWK set: %w", err)
}
return jwks, nil
}
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error {
// First, parse the assertion JWT, without validating it, to check the issuer
assertion := []byte(input.ClientAssertion)
insecureToken, err := jwt.ParseInsecure(assertion)
if err != nil {
return fmt.Errorf("failed to parse client assertion JWT: %w", err)
}
issuer, _ := insecureToken.Issuer()
if issuer == "" {
return errors.New("client assertion does not contain an issuer claim")
}
// Ensure that this client is federated with the one that issued the token
ocfi, ok := client.Credentials.FederatedIdentityForIssuer(issuer)
if !ok {
return fmt.Errorf("client assertion is not from an allowed issuer: %s", issuer)
}
// Get the JWK set for the issuer
jwksURL := ocfi.JWKS
if jwksURL == "" {
// Default URL is from the issuer
if strings.HasSuffix(issuer, "/") {
jwksURL = issuer + ".well-known/jwks.json"
} else {
jwksURL = issuer + "/.well-known/jwks.json"
}
}
jwks, err := s.jwkSetForURL(ctx, jwksURL)
if err != nil {
return fmt.Errorf("failed to get JWK set for issuer '%s': %w", issuer, err)
}
// Set default audience and subject if missing
audience := ocfi.Audience
if audience == "" {
// Default to the Pocket ID's URL
audience = common.EnvConfig.AppURL
}
subject := ocfi.Subject
if subject == "" {
// Default to the client ID, per RFC 7523
subject = client.ID
}
// Now re-parse the token with proper validation
// (Note: we don't use jwt.WithIssuer() because that would be redundant)
_, err = jwt.Parse(assertion,
jwt.WithValidate(true),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true), jws.WithUseDefault(true)),
jwt.WithAudience(audience),
jwt.WithSubject(subject),
)
if err != nil {
return fmt.Errorf("client assertion is not valid: %w", err)
}
// If we're here, the assertion is valid
return nil
}

View File

@@ -0,0 +1,365 @@
package service
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/json"
"net/http"
"testing"
"time"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
)
// generateTestECDSAKey creates an ECDSA key for testing
func generateTestECDSAKey(t *testing.T) (jwk.Key, []byte) {
t.Helper()
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
privateJwk, err := jwk.Import(privateKey)
require.NoError(t, err)
err = privateJwk.Set(jwk.KeyIDKey, "test-key-1")
require.NoError(t, err)
err = privateJwk.Set(jwk.AlgorithmKey, "ES256")
require.NoError(t, err)
err = privateJwk.Set("use", "sig")
require.NoError(t, err)
publicJwk, err := jwk.PublicKeyOf(privateJwk)
require.NoError(t, err)
// Create a JWK Set with the public key
jwkSet := jwk.NewSet()
err = jwkSet.AddKey(publicJwk)
require.NoError(t, err)
jwkSetJSON, err := json.Marshal(jwkSet)
require.NoError(t, err)
return privateJwk, jwkSetJSON
}
func TestOidcService_jwkSetForURL(t *testing.T) {
// Generate a test key for JWKS
_, jwkSetJSON1 := generateTestECDSAKey(t)
_, jwkSetJSON2 := generateTestECDSAKey(t)
// Create a mock HTTP client with responses for different URLs
const (
url1 = "https://example.com/.well-known/jwks.json"
url2 = "https://other-issuer.com/jwks"
)
mockResponses := map[string]*http.Response{
//nolint:bodyclose
url1: NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
//nolint:bodyclose
url2: NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
}
httpClient := &http.Client{
Transport: &MockRoundTripper{
Responses: mockResponses,
},
}
// Create the OidcService with our mock client
s := &OidcService{
httpClient: httpClient,
}
var err error
s.jwkCache, err = s.getJWKCache(t.Context())
require.NoError(t, err)
t.Run("Fetches and caches JWK set", func(t *testing.T) {
jwks, err := s.jwkSetForURL(t.Context(), url1)
require.NoError(t, err)
require.NotNil(t, jwks)
// Verify the JWK set contains our key
require.Equal(t, 1, jwks.Len())
})
t.Run("Fails with invalid URL", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
_, err := s.jwkSetForURL(ctx, "https://bad-url.com")
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("Safe for concurrent use", func(t *testing.T) {
const concurrency = 20
// Channel to collect errors
errChan := make(chan error, concurrency)
// Start concurrent requests
for range concurrency {
go func() {
jwks, err := s.jwkSetForURL(t.Context(), url2)
if err != nil {
errChan <- err
return
}
// Verify the JWK set is valid
if jwks == nil || jwks.Len() != 1 {
errChan <- assert.AnError
return
}
errChan <- nil
}()
}
// Check for errors
for range concurrency {
assert.NoError(t, <-errChan, "Concurrent JWK set fetching should not produce errors")
}
})
}
func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
const (
federatedClientIssuer = "https://external-idp.com"
federatedClientAudience = "https://pocket-id.com"
federatedClientSubject = "123456abcdef"
federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
)
var err error
// Create a test database
db := newDatabaseForTest(t)
// Create two JWKs for testing
privateJWK, jwkSetJSON := generateTestECDSAKey(t)
require.NoError(t, err)
privateJWKDefaults, jwkSetJSONDefaults := generateTestECDSAKey(t)
require.NoError(t, err)
// Create a mock HTTP client with custom transport to return the JWKS
httpClient := &http.Client{
Transport: &MockRoundTripper{
Responses: map[string]*http.Response{
//nolint:bodyclose
federatedClientIssuer + "/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSON)),
//nolint:bodyclose
federatedClientIssuerDefaults + ".well-known/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
},
},
}
// Init the OidcService
s := &OidcService{
db: db,
httpClient: httpClient,
}
s.jwkCache, err = s.getJWKCache(t.Context())
require.NoError(t, err)
// Create the test clients
// 1. Confidential client
confidentialClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
Name: "Confidential Client",
CallbackURLs: []string{"https://example.com/callback"},
}, "test-user-id")
require.NoError(t, err)
// Create a client secret for the confidential client
confidentialSecret, err := s.CreateClientSecret(t.Context(), confidentialClient.ID)
require.NoError(t, err)
// 2. Public client
publicClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
Name: "Public Client",
CallbackURLs: []string{"https://example.com/callback"},
IsPublic: true,
}, "test-user-id")
require.NoError(t, err)
// 3. Confidential client with federated identity
federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
Name: "Federated Client",
CallbackURLs: []string{"https://example.com/callback"},
Credentials: dto.OidcClientCredentialsDto{
FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
{
Issuer: federatedClientIssuer,
Audience: federatedClientAudience,
Subject: federatedClientSubject,
JWKS: federatedClientIssuer + "/jwks.json",
},
{Issuer: federatedClientIssuerDefaults},
},
},
}, "test-user-id")
require.NoError(t, err)
// Test cases for confidential client (using client secret)
t.Run("Confidential client", func(t *testing.T) {
t.Run("Succeeds with valid secret", func(t *testing.T) {
// Test with valid client credentials
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
ClientID: confidentialClient.ID,
ClientSecret: confidentialSecret,
})
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, confidentialClient.ID, client.ID)
})
t.Run("Fails with invalid secret", func(t *testing.T) {
// Test with invalid client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
ClientID: confidentialClient.ID,
ClientSecret: "invalid-secret",
})
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
assert.Nil(t, client)
})
t.Run("Fails with missing secret", func(t *testing.T) {
// Test with missing client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
ClientID: confidentialClient.ID,
})
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
assert.Nil(t, client)
})
})
// Test cases for public client
t.Run("Public client", func(t *testing.T) {
t.Run("Succeeds with no credentials", func(t *testing.T) {
// Public clients don't require client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
ClientID: publicClient.ID,
})
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, publicClient.ID, client.ID)
})
})
// Test cases for federated client using JWT assertion
t.Run("Federated client", func(t *testing.T) {
t.Run("Succeeds with valid JWT", func(t *testing.T) {
// Create JWT for federated identity
token, err := jwt.NewBuilder().
Issuer(federatedClientIssuer).
Audience([]string{federatedClientAudience}).
Subject(federatedClientSubject).
IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute)).
Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
require.NoError(t, err)
// Test with valid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),
})
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, federatedClient.ID, client.ID)
})
t.Run("Fails with malformed JWT", func(t *testing.T) {
// Test with invalid JWT assertion (just a random string)
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: "invalid.jwt.token",
})
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
assert.Nil(t, client)
})
testBadJWT := func(builderFn func(builder *jwt.Builder)) func(t *testing.T) {
return func(t *testing.T) {
// Populate all claims with valid values
builder := jwt.NewBuilder().
Issuer(federatedClientIssuer).
Audience([]string{federatedClientAudience}).
Subject(federatedClientSubject).
IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute))
// Call builderFn to override the claims
builderFn(builder)
token, err := builder.Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
require.NoError(t, err)
// Test with invalid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),
})
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
require.Nil(t, client)
}
}
t.Run("Fails with expired JWT", testBadJWT(func(builder *jwt.Builder) {
builder.Expiration(time.Now().Add(-30 * time.Minute))
}))
t.Run("Fails with wrong issuer in JWT", testBadJWT(func(builder *jwt.Builder) {
builder.Issuer("https://bad-issuer.com")
}))
t.Run("Fails with wrong audience in JWT", testBadJWT(func(builder *jwt.Builder) {
builder.Audience([]string{"bad-audience"})
}))
t.Run("Fails with wrong subject in JWT", testBadJWT(func(builder *jwt.Builder) {
builder.Subject("bad-subject")
}))
t.Run("Uses default values for audience and subject", func(t *testing.T) {
// Create JWT for federated identity
token, err := jwt.NewBuilder().
Issuer(federatedClientIssuerDefaults).
Audience([]string{common.EnvConfig.AppURL}).
Subject(federatedClient.ID).
IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute)).
Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWKDefaults))
require.NoError(t, err)
// Test with valid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),
})
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, federatedClient.ID, client.ID)
})
})
}

View File

@@ -0,0 +1,97 @@
package service
import (
"io"
"net/http"
"strings"
"testing"
"time"
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/glebarez/sqlite"
"github.com/golang-migrate/migrate/v4"
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/resources"
)
func newDatabaseForTest(t *testing.T) *gorm.DB {
t.Helper()
// Get a name for this in-memory database that is specific to the test
dbName := utils.CreateSha256Hash(t.Name())
// Connect to a new in-memory SQL database
db, err := gorm.Open(
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
&gorm.Config{
TranslateError: true,
Logger: logger.New(
testLoggerAdapter{t: t},
logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: false,
ParameterizedQueries: false,
Colorful: false,
},
),
})
require.NoError(t, err, "Failed to connect to test database")
// Perform migrations with the embedded migrations
sqlDB, err := db.DB()
require.NoError(t, err, "Failed to get sql.DB")
driver, err := sqliteMigrate.WithInstance(sqlDB, &sqliteMigrate.Config{})
require.NoError(t, err, "Failed to create migration driver")
source, err := iofs.New(resources.FS, "migrations/sqlite")
require.NoError(t, err, "Failed to create embedded migration source")
m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver)
require.NoError(t, err, "Failed to create migration instance")
err = m.Up()
require.NoError(t, err, "Failed to perform migrations")
return db
}
// Implements gorm's logger.Writer interface
type testLoggerAdapter struct {
t *testing.T
}
func (l testLoggerAdapter) Printf(format string, args ...any) {
l.t.Logf(format, args...)
}
// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL
type MockRoundTripper struct {
Err error
Responses map[string]*http.Response
}
// RoundTrip implements the http.RoundTripper interface
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// Check if we have a specific response for this URL
for url, resp := range m.Responses {
if req.URL.String() == url {
return resp, nil
}
}
return NewMockResponse(http.StatusNotFound, ""), nil
}
// NewMockResponse creates an http.Response with the given status code and body
func NewMockResponse(statusCode int, body string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
}
}

View File

@@ -0,0 +1,69 @@
package utils
import (
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
)
const (
// KeyUsageSigning is the usage for the private keys, for the "use" property
KeyUsageSigning = "sig"
)
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
// It also populates additional fields such as the key ID, usage, and alg.
func ImportRawKey(rawKey any) (jwk.Key, error) {
key, err := jwk.Import(rawKey)
if err != nil {
return nil, fmt.Errorf("failed to import generated private key: %w", err)
}
// Generate the key ID
kid, err := generateRandomKeyID()
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
_ = key.Set(jwk.KeyIDKey, kid)
// Set other required fields
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
EnsureAlgInKey(key)
return key, nil
}
// generateRandomKeyID generates a random key ID.
func generateRandomKeyID() (string, error) {
buf := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("failed to read random bytes: %w", err)
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
func EnsureAlgInKey(key jwk.Key) {
_, ok := key.Algorithm()
if ok {
// Algorithm is already set
return
}
switch key.KeyType() {
case jwa.RSA():
// Default to RS256 for RSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
case jwa.EC():
// Default to ES256 for ECDSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
case jwa.OKP():
// Default to EdDSA for OKP keys
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
}
}

View File

@@ -0,0 +1 @@
ALTER TABLE oidc_clients DROP COLUMN credentials;

View File

@@ -0,0 +1 @@
ALTER TABLE oidc_clients ADD COLUMN credentials JSONB NULL;

View File

@@ -0,0 +1 @@
ALTER TABLE oidc_clients DROP COLUMN credentials;

View File

@@ -0,0 +1 @@
ALTER TABLE oidc_clients ADD COLUMN credentials TEXT NULL;

View File

@@ -348,6 +348,12 @@
"the_device_has_been_authorized": "The device has been authorized.",
"enter_code_displayed_in_previous_step": "Enter the code that was displayed in the previous step.",
"authorize": "Authorize",
"federated_identities": "Federated Identities",
"federated_identities_description": "Using federated identities, you can authenticate OIDC clients using JWT tokens issued by third-party authorities.",
"add_federated_identity": "Add Federated Identity",
"add_another_federated_identity": "Add another federated identity",
"oidc_allowed_group_count": "Allowed Group Count",
"unrestricted": "Unrestricted"
"unrestricted": "Unrestricted",
"show_advanced_options": "Show Advanced Options",
"hide_advanced_options": "Hide Advanced Options"
}

View File

@@ -6,11 +6,23 @@ export type OidcClientMetaData = {
hasLogo: boolean;
};
export type OidcClientFederatedIdentity = {
issuer: string;
subject?: string;
audience?: string;
jwks?: string;
};
export type OidcClientCredentials = {
federatedIdentities: OidcClientFederatedIdentity[];
};
export type OidcClient = OidcClientMetaData & {
callbackURLs: string[]; // No longer requires at least one URL
logoutCallbackURLs: string[];
isPublic: boolean;
pkceEnabled: boolean;
credentials?: OidcClientCredentials;
};
export type OidcClientWithAllowedUserGroups = OidcClient & {

View File

@@ -8,6 +8,7 @@
import * as Card from '$lib/components/ui/card';
import Label from '$lib/components/ui/label/label.svelte';
import UserGroupSelection from '$lib/components/user-group-selection.svelte';
import { m } from '$lib/paraglide/messages';
import OidcService from '$lib/services/oidc-service';
import clientSecretStore from '$lib/stores/client-secret-store';
import type { OidcClientCreateWithLogo } from '$lib/types/oidc.type';
@@ -16,7 +17,6 @@
import { toast } from 'svelte-sonner';
import { slide } from 'svelte/transition';
import OidcForm from '../oidc-client-form.svelte';
import { m } from '$lib/paraglide/messages';
let { data } = $props();
let client = $state({
@@ -166,7 +166,7 @@
</Card.Content>
</Card.Root>
<Card.Root>
<Card.Content class="p-5">
<Card.Content>
<OidcForm existingClient={client} callback={updateClient} />
</Card.Content>
</Card.Root>

View File

@@ -0,0 +1,128 @@
<script lang="ts">
import FormInput from '$lib/components/form/form-input.svelte';
import { Button } from '$lib/components/ui/button';
import { Input } from '$lib/components/ui/input';
import Label from '$lib/components/ui/label/label.svelte';
import { m } from '$lib/paraglide/messages';
import type { OidcClient, OidcClientFederatedIdentity } from '$lib/types/oidc.type';
import { LucideMinus, LucidePlus } from '@lucide/svelte';
import type { Snippet } from 'svelte';
import type { HTMLAttributes } from 'svelte/elements';
let {
client,
federatedIdentities = $bindable([]),
error = $bindable(null),
...restProps
}: HTMLAttributes<HTMLDivElement> & {
client?: OidcClient;
federatedIdentities: OidcClientFederatedIdentity[];
error?: string | null;
children?: Snippet;
} = $props();
function addFederatedIdentity() {
federatedIdentities = [
...federatedIdentities,
{
issuer: '',
subject: '',
audience: '',
jwks: ''
}
];
}
function removeFederatedIdentity(index: number) {
federatedIdentities = federatedIdentities.filter((_, i) => i !== index);
}
function updateFederatedIdentity(
index: number,
field: keyof OidcClientFederatedIdentity,
value: string
) {
federatedIdentities[index] = {
...federatedIdentities[index],
[field]: value
};
}
</script>
<div {...restProps}>
<FormInput label={m.federated_identities()} description={m.federated_identities_description()}>
<div class="space-y-4">
{#each federatedIdentities as identity, i}
<div class="space-y-3 rounded-lg border p-4">
<div class="flex items-center justify-between">
<Label class="text-sm font-medium">Identity {i + 1}</Label>
{#if federatedIdentities.length > 0}
<Button
variant="outline"
size="sm"
onclick={() => removeFederatedIdentity(i)}
aria-label="Remove federated identity"
>
<LucideMinus class="size-4" />
</Button>
{/if}
</div>
<div class="grid grid-cols-1 gap-3 md:grid-cols-2">
<div>
<Label for="issuer-{i}" class="text-xs">Issuer (Required)</Label>
<Input
id="issuer-{i}"
placeholder="https://example.com/"
value={identity.issuer}
oninput={(e) => updateFederatedIdentity(i, 'issuer', e.currentTarget.value)}
required
/>
</div>
<div>
<Label for="subject-{i}" class="text-xs">Subject (Optional)</Label>
<Input
id="subject-{i}"
placeholder="Defaults to the client ID: {client?.id}"
value={identity.subject || ''}
oninput={(e) => updateFederatedIdentity(i, 'subject', e.currentTarget.value)}
/>
</div>
<div>
<Label for="audience-{i}" class="text-xs">Audience (Optional)</Label>
<Input
id="audience-{i}"
placeholder="Defaults to the Pocket ID URL"
value={identity.audience || ''}
oninput={(e) => updateFederatedIdentity(i, 'audience', e.currentTarget.value)}
/>
</div>
<div>
<Label for="jwks-{i}" class="text-xs">JWKS URL (Optional)</Label>
<Input
id="jwks-{i}"
placeholder="Defaults to {identity.issuer || '<issuer>'}/.well-known/jwks.json"
value={identity.jwks || ''}
oninput={(e) => updateFederatedIdentity(i, 'jwks', e.currentTarget.value)}
/>
</div>
</div>
</div>
{/each}
</div>
</FormInput>
{#if error}
<p class="text-destructive mt-1 text-xs">{error}</p>
{/if}
<Button class="mt-3" variant="secondary" size="sm" onclick={addFederatedIdentity} type="button">
<LucidePlus class="mr-1 size-4" />
{federatedIdentities.length === 0
? m.add_federated_identity()
: m.add_another_federated_identity()}
</Button>
</div>

View File

@@ -5,14 +5,14 @@
import { Button } from '$lib/components/ui/button';
import Label from '$lib/components/ui/label/label.svelte';
import { m } from '$lib/paraglide/messages';
import type {
OidcClient,
OidcClientCreate,
OidcClientCreateWithLogo
} from '$lib/types/oidc.type';
import type { OidcClient, OidcClientCreateWithLogo } from '$lib/types/oidc.type';
import { preventDefault } from '$lib/utils/event-util';
import { createForm } from '$lib/utils/form-util';
import { cn } from '$lib/utils/style';
import { LucideChevronDown } from '@lucide/svelte';
import { slide } from 'svelte/transition';
import { z } from 'zod';
import FederatedIdentitiesInput from './federated-identities-input.svelte';
import OidcCallbackUrlInput from './oidc-callback-url-input.svelte';
let {
@@ -24,17 +24,21 @@
} = $props();
let isLoading = $state(false);
let showAdvancedOptions = $state(false);
let logo = $state<File | null | undefined>();
let logoDataURL: string | null = $state(
existingClient?.hasLogo ? `/api/oidc/clients/${existingClient!.id}/logo` : null
);
const client: OidcClientCreate = {
const client = {
name: existingClient?.name || '',
callbackURLs: existingClient?.callbackURLs || [],
logoutCallbackURLs: existingClient?.logoutCallbackURLs || [],
isPublic: existingClient?.isPublic || false,
pkceEnabled: existingClient?.pkceEnabled || false
pkceEnabled: existingClient?.pkceEnabled || false,
credentials: {
federatedIdentities: existingClient?.credentials?.federatedIdentities || []
}
};
const formSchema = z.object({
@@ -42,7 +46,17 @@
callbackURLs: z.array(z.string().nonempty()).default([]),
logoutCallbackURLs: z.array(z.string().nonempty()),
isPublic: z.boolean(),
pkceEnabled: z.boolean()
pkceEnabled: z.boolean(),
credentials: z.object({
federatedIdentities: z.array(
z.object({
issuer: z.string().url(),
subject: z.string().optional(),
audience: z.string().optional(),
jwks: z.string().url().optional().or(z.literal(''))
})
)
})
});
type FormSchema = typeof formSchema;
@@ -139,8 +153,31 @@
</div>
</div>
</div>
<div class="w-full"></div>
<div class="mt-5 flex justify-end">
<Button {isLoading} type="submit">{m.save()}</Button>
{#if showAdvancedOptions}
<div class="mt-5 md:col-span-2" transition:slide={{ duration: 200 }}>
<FederatedIdentitiesInput
client={existingClient}
bind:federatedIdentities={$inputs.credentials.value.federatedIdentities}
bind:error={$inputs.credentials.error}
/>
</div>
{/if}
<div class="relative mt-5 flex justify-center">
<Button
variant="ghost"
class="text-muted-foregroun"
onclick={() => (showAdvancedOptions = !showAdvancedOptions)}
>
{showAdvancedOptions ? m.hide_advanced_options() : m.show_advanced_options()}
<LucideChevronDown
class={cn(
'size-5 transition-transform duration-200',
showAdvancedOptions && 'rotate-180 transform'
)}
/>
</Button>
<Button {isLoading} type="submit" class="absolute right-0">{m.save()}</Button>
</div>
</form>

View File

@@ -35,6 +35,17 @@ export const oidcClients = {
callbackUrl: 'http://immich/auth/callback',
secret: 'PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x'
},
federated: {
id: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
name: 'Federated',
callbackUrl: 'http://federated/auth/callback',
federatedJWT: {
issuer: 'https://external-idp.local',
audience: 'api://PocketID',
subject: 'c48232ff-ff65-45ed-ae96-7afa8a9b443b',
},
accessCodes: ['federated']
},
pingvinShare: {
name: 'Pingvin Share',
callbackUrl: 'http://pingvin.share/auth/callback',

View File

@@ -7,6 +7,7 @@
"devDependencies": {
"@playwright/test": "^1.52.0",
"@types/node": "^22.15.21",
"dotenv": "^16.5.0",
"jose": "^6.0.11"
}
},
@@ -36,6 +37,19 @@
"undici-types": "~6.21.0"
}
},
"node_modules/dotenv": {
"version": "16.5.0",
"resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.5.0.tgz",
"integrity": "sha512-m/C+AwOAr9/W1UOIZUo232ejMNnJAJtYQjUbHoNTBNTJSvqzzDh7vnrei3o3r3m9blf6ZoDkvcw0VmozNRFJxg==",
"dev": true,
"license": "BSD-2-Clause",
"engines": {
"node": ">=12"
},
"funding": {
"url": "https://dotenvx.com"
}
},
"node_modules/fsevents": {
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",

View File

@@ -3,6 +3,7 @@
"devDependencies": {
"@playwright/test": "^1.52.0",
"@types/node": "^22.15.21",
"jose": "^6.0.11"
"jose": "^6.0.11",
"dotenv": "^16.5.0"
}
}

View File

@@ -1,30 +1,31 @@
import { defineConfig, devices } from '@playwright/test';
import { defineConfig, devices } from "@playwright/test";
import "dotenv/config";
/**
* See https://playwright.dev/docs/test-configuration.
*/
export default defineConfig({
outputDir: './.output',
timeout: 10000,
testDir: './specs',
fullyParallel: false,
forbidOnly: !!process.env.CI,
retries: process.env.CI ? 1 : 0,
workers: 1,
reporter: process.env.CI
? [['html', { outputFolder: '.report' }], ['github']]
: [['line'], ['html', { open: 'never', outputFolder: '.report' }]],
use: {
baseURL: process.env.APP_URL ?? 'http://localhost:1411',
video: 'retain-on-failure',
trace: 'on-first-retry'
},
projects: [
{ name: 'setup', testMatch: /.*\.setup\.ts/ },
{
name: 'chromium',
use: { ...devices['Desktop Chrome'], storageState: '.auth/user.json' },
dependencies: ['setup']
}
]
outputDir: "./.output",
timeout: 10000,
testDir: "./specs",
fullyParallel: false,
forbidOnly: !!process.env.CI,
retries: process.env.CI ? 1 : 0,
workers: 1,
reporter: process.env.CI
? [["html", { outputFolder: ".report" }], ["github"]]
: [["line"], ["html", { open: "never", outputFolder: ".report" }]],
use: {
baseURL: process.env.APP_URL ?? "http://localhost:1411",
video: "retain-on-failure",
trace: "on-first-retry",
},
projects: [
{ name: "setup", testMatch: /.*\.setup\.ts/ },
{
name: "chromium",
use: { ...devices["Desktop Chrome"], storageState: ".auth/user.json" },
dependencies: ["setup"],
},
],
});

View File

@@ -4,6 +4,8 @@ import { cleanupBackend } from '../utils/cleanup.util';
test.beforeEach(cleanupBackend);
test.describe('LDAP Integration', () => {
test.skip(process.env.SKIP_LDAP_TESTS === "true", 'Skipping LDAP tests due to SKIP_LDAP_TESTS environment variable');
test('LDAP configuration is working properly', async ({ page }) => {
await page.goto('/settings/admin/application-configuration');

View File

@@ -2,7 +2,7 @@ import test, { expect } from "@playwright/test";
import { oidcClients, refreshTokens, users } from "../data";
import { cleanupBackend } from "../utils/cleanup.util";
import { generateIdToken, generateOauthAccessToken } from "../utils/jwt.util";
import oidcUtil from "../utils/oidc.util";
import * as oidcUtil from "../utils/oidc.util";
import passkeyUtil from "../utils/passkey.util";
test.beforeEach(cleanupBackend);
@@ -449,3 +449,40 @@ test("Authorize new client with device authorization with user group not allowed
.filter({ hasText: "You're not allowed to access this service." })
).toBeVisible();
});
test("Federated identity fails with invalid client assertion", async ({
page,
}) => {
const client = oidcClients.federated;
const res = await oidcUtil.exchangeCode(page, {
client_assertion_type: 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
grant_type: 'authorization_code',
redirect_uri: client.callbackUrl,
code: client.accessCodes[0],
client_id: client.id,
client_assertion:'not-an-assertion',
});
expect(res?.error).toBe('Invalid client assertion');
});
test("Authorize existing client with federated identity", async ({
page,
}) => {
const client = oidcClients.federated;
const clientAssertion = await oidcUtil.getClientAssertion(page, client.federatedJWT);
const res = await oidcUtil.exchangeCode(page, {
client_assertion_type: 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
grant_type: 'authorization_code',
redirect_uri: client.callbackUrl,
code: client.accessCodes[0],
client_id: client.id,
client_assertion: clientAssertion,
});
expect(res.access_token).not.toBeNull;
expect(res.expires_in).not.toBeNull;
expect(res.token_type).toBe('Bearer');
});

View File

@@ -1,5 +1,6 @@
{
"compilerOptions": {
"baseUrl": "."
"baseUrl": ".",
"lib": ["ES2022"]
}
}

View File

@@ -1,12 +1,15 @@
import playwrightConfig from "../playwright.config";
export async function cleanupBackend() {
const response = await fetch(
playwrightConfig.use!.baseURL + "/api/test/reset",
{
method: "POST",
}
);
const url = new URL("/api/test/reset", playwrightConfig.use!.baseURL);
if (process.env.SKIP_LDAP_TESTS === "true") {
url.searchParams.append("skip-ldap", "true");
}
const response = await fetch(url, {
method: "POST",
});
if (!response.ok) {
throw new Error(

View File

@@ -1,7 +1,7 @@
import type { Page } from '@playwright/test';
async function getUserCode(page: Page, clientId: string, clientSecret: string) {
const response = await page.request
export async function getUserCode(page: Page, clientId: string, clientSecret: string): Promise<string> {
return page.request
.post('/api/oidc/device/authorize', {
headers: {
'Content-Type': 'application/x-www-form-urlencoded'
@@ -12,11 +12,29 @@ async function getUserCode(page: Page, clientId: string, clientSecret: string) {
scope: 'openid profile email'
}
})
.then((r) => r.json());
return response.user_code;
.then((r) => r.json())
.then((r) => r.user_code);
}
export default {
getUserCode
};
export async function exchangeCode(page: Page, params: Record<string,string>): Promise<{access_token?: string, token_type?: string, expires_in?: number, error?: string}> {
return page.request
.post('/api/oidc/token', {
headers: {
'Content-Type': 'application/x-www-form-urlencoded'
},
form: params,
})
.then((r) => r.json());
}
export async function getClientAssertion(page: Page, data: {issuer: string, audience: string, subject: string}): Promise<string> {
return page.request
.post('/api/externalidp/sign', {
data: {
iss: data.issuer,
aud: data.audience,
sub: data.subject,
},
})
.then((r) => r.text());
}