mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-06 05:12:57 +03:00
feat: JWT bearer assertions for client authentication (#566)
Co-authored-by: Kyle Mendell <ksm@ofkm.us> Co-authored-by: Kyle Mendell <kmendell@ofkm.us> Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
035b2c022b
commit
05bfe00924
1
.gitignore
vendored
1
.gitignore
vendored
@@ -10,6 +10,7 @@ node_modules
|
||||
/frontend/build
|
||||
/backend/bin
|
||||
pocket-id
|
||||
/tests/test-results/*.json
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
|
||||
@@ -20,7 +20,8 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/hashicorp/go-uuid v1.0.3
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.1
|
||||
github.com/mileusna/useragent v1.3.5
|
||||
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
@@ -32,7 +33,7 @@ require (
|
||||
go.opentelemetry.io/otel/sdk v1.35.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.35.0
|
||||
go.opentelemetry.io/otel/trace v1.35.0
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/crypto v0.37.0
|
||||
golang.org/x/image v0.24.0
|
||||
golang.org/x/time v0.9.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
@@ -77,9 +78,8 @@ require (
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
|
||||
github.com/lestrrat-go/blackmagic v1.0.3 // indirect
|
||||
github.com/lestrrat-go/httpcc v1.0.1 // indirect
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 // indirect
|
||||
github.com/lestrrat-go/option v1.0.1 // indirect
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
@@ -123,7 +123,7 @@ require (
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sync v0.14.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect
|
||||
google.golang.org/grpc v1.71.0 // indirect
|
||||
|
||||
@@ -164,14 +164,14 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
|
||||
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
|
||||
github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtnSINtDWs=
|
||||
github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw=
|
||||
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
|
||||
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA=
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms=
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1 h1:Iqjb8JvWjh34Jv8DeM2wQ1aG5fzFBzwQu7rlqwuJB0I=
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc=
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2 h1:SDxjGoH7qj0nBXVrcrxX8eD94wEnjR+EEuqqmeqQYlY=
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2/go.mod h1:Nwo81sMxE0DcvTB+rJyynNhv/DUu2yZErV7sscw9pHE=
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.1 h1:fH3T748FCMbXoF9UXXNS9i0q6PpYyJZK/rKSbkt2guY=
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.1/go.mod h1:XP2WqxMOSzHSyf3pfibCcfsLqbomxakAnNqiuaH8nwo=
|
||||
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
|
||||
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
@@ -309,8 +309,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI=
|
||||
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
|
||||
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
@@ -377,8 +377,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
|
||||
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
|
||||
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
|
||||
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -14,7 +16,12 @@ import (
|
||||
func init() {
|
||||
registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services){
|
||||
func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) {
|
||||
testService := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
|
||||
testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize test service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
controller.NewTestController(apiGroup, testService)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -26,15 +26,14 @@ type services struct {
|
||||
}
|
||||
|
||||
// Initializes all services
|
||||
// The context should be used by services only for initialization, and not for running
|
||||
func initServices(initCtx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
|
||||
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
|
||||
svc = &services{}
|
||||
|
||||
svc.appConfigService = service.NewAppConfigService(initCtx, db)
|
||||
svc.appConfigService = service.NewAppConfigService(ctx, db)
|
||||
|
||||
svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create email service: %w", err)
|
||||
return nil, fmt.Errorf("failed to create email service: %w", err)
|
||||
}
|
||||
|
||||
svc.geoLiteService = service.NewGeoLiteService(httpClient)
|
||||
@@ -42,7 +41,12 @@ func initServices(initCtx context.Context, db *gorm.DB, httpClient *http.Client)
|
||||
svc.jwtService = service.NewJwtService(svc.appConfigService)
|
||||
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService)
|
||||
svc.customClaimService = service.NewCustomClaimService(db)
|
||||
svc.oidcService = service.NewOidcService(db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService)
|
||||
|
||||
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
|
||||
}
|
||||
|
||||
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
|
||||
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
|
||||
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
||||
|
||||
@@ -65,6 +65,11 @@ type OidcClientSecretInvalidError struct{}
|
||||
func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" }
|
||||
func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type OidcClientAssertionInvalidError struct{}
|
||||
|
||||
func (e *OidcClientAssertionInvalidError) Error() string { return "invalid client assertion" }
|
||||
func (e *OidcClientAssertionInvalidError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type OidcInvalidAuthorizationCodeError struct{}
|
||||
|
||||
func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" }
|
||||
|
||||
@@ -14,6 +14,9 @@ func NewTestController(group *gin.RouterGroup, testService *service.TestService)
|
||||
testController := &TestController{TestService: testService}
|
||||
|
||||
group.POST("/test/reset", testController.resetAndSeedHandler)
|
||||
|
||||
group.GET("/externalidp/jwks.json", testController.externalIdPJWKS)
|
||||
group.POST("/externalidp/sign", testController.externalIdPSignToken)
|
||||
}
|
||||
|
||||
type TestController struct {
|
||||
@@ -21,6 +24,15 @@ type TestController struct {
|
||||
}
|
||||
|
||||
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
||||
var baseURL string
|
||||
if c.Request.TLS != nil {
|
||||
baseURL = "https://" + c.Request.Host
|
||||
} else {
|
||||
baseURL = "http://" + c.Request.Host
|
||||
}
|
||||
|
||||
skipLdap := c.Query("skip-ldap") == "true"
|
||||
|
||||
if err := tc.TestService.ResetDatabase(); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -31,7 +43,7 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.TestService.SeedDatabase(); err != nil {
|
||||
if err := tc.TestService.SeedDatabase(baseURL); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -41,17 +53,50 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
if !skipLdap {
|
||||
if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tc.TestService.SetJWTKeys()
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (tc *TestController) externalIdPJWKS(c *gin.Context) {
|
||||
jwks, err := tc.TestService.GetExternalIdPJWKS()
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, jwks)
|
||||
}
|
||||
|
||||
func (tc *TestController) externalIdPSignToken(c *gin.Context) {
|
||||
var input struct {
|
||||
Aud string `json:"aud"`
|
||||
Iss string `json:"iss"`
|
||||
Sub string `json:"sub"`
|
||||
}
|
||||
err := c.ShouldBindJSON(&input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := tc.TestService.SignExternalIdPToken(input.Iss, input.Sub, input.Aud)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.WriteString(token)
|
||||
}
|
||||
|
||||
@@ -7,14 +7,14 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/middleware"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
|
||||
)
|
||||
|
||||
// NewOidcController creates a new controller for OIDC related endpoints
|
||||
@@ -124,11 +124,13 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex
|
||||
// @Tags OIDC
|
||||
// @Produce json
|
||||
// @Param client_id formData string false "Client ID (if not using Basic Auth)"
|
||||
// @Param client_secret formData string false "Client secret (if not using Basic Auth)"
|
||||
// @Param client_secret formData string false "Client secret (if not using Basic Auth or client assertions)"
|
||||
// @Param code formData string false "Authorization code (required for 'authorization_code' grant)"
|
||||
// @Param grant_type formData string true "Grant type ('authorization_code' or 'refresh_token')"
|
||||
// @Param code_verifier formData string false "PKCE code verifier (for authorization_code with PKCE)"
|
||||
// @Param refresh_token formData string false "Refresh token (required for 'refresh_token' grant)"
|
||||
// @Param client_assertion formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)"
|
||||
// @Param client_assertion_type formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)"
|
||||
// @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token"
|
||||
// @Router /api/oidc/token [post]
|
||||
func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
||||
@@ -363,12 +365,12 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||
|
||||
clientDto := dto.OidcClientWithAllowedUserGroupsDto{}
|
||||
err = dto.MapStruct(client, &clientDto)
|
||||
if err == nil {
|
||||
c.JSON(http.StatusOK, clientDto)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
_ = c.Error(err)
|
||||
c.JSON(http.StatusOK, clientDto)
|
||||
}
|
||||
|
||||
// listClientsHandler godoc
|
||||
|
||||
@@ -62,7 +62,60 @@ func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func mapField(sourceField reflect.Value, destField reflect.Value) error {
|
||||
// Handle pointer to struct in source
|
||||
if sourceField.Kind() == reflect.Ptr && !sourceField.IsNil() {
|
||||
switch {
|
||||
case sourceField.Elem().Kind() == reflect.Struct:
|
||||
switch {
|
||||
case destField.Kind() == reflect.Struct:
|
||||
// Map from pointer to struct -> struct
|
||||
return mapStructInternal(sourceField.Elem(), destField)
|
||||
case destField.Kind() == reflect.Ptr && destField.CanSet():
|
||||
// Map from pointer to struct -> pointer to struct
|
||||
if destField.IsNil() {
|
||||
destField.Set(reflect.New(destField.Type().Elem()))
|
||||
}
|
||||
return mapStructInternal(sourceField.Elem(), destField.Elem())
|
||||
}
|
||||
case destField.Kind() == reflect.Ptr &&
|
||||
destField.CanSet() &&
|
||||
sourceField.Elem().Type().AssignableTo(destField.Type().Elem()):
|
||||
// Handle primitive pointer types (e.g., *string to *string)
|
||||
if destField.IsNil() {
|
||||
destField.Set(reflect.New(destField.Type().Elem()))
|
||||
}
|
||||
destField.Elem().Set(sourceField.Elem())
|
||||
return nil
|
||||
case destField.Kind() != reflect.Ptr &&
|
||||
destField.CanSet() &&
|
||||
sourceField.Elem().Type().AssignableTo(destField.Type()):
|
||||
// Handle *T to T conversion for primitive types
|
||||
destField.Set(sourceField.Elem())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Handle pointer to struct in destination
|
||||
if destField.Kind() == reflect.Ptr && destField.CanSet() {
|
||||
switch {
|
||||
case sourceField.Kind() == reflect.Struct:
|
||||
// Map from struct -> pointer to struct
|
||||
if destField.IsNil() {
|
||||
destField.Set(reflect.New(destField.Type().Elem()))
|
||||
}
|
||||
return mapStructInternal(sourceField, destField.Elem())
|
||||
case !sourceField.IsZero() && sourceField.Type().AssignableTo(destField.Type().Elem()):
|
||||
// Handle T to *T conversion for primitive types
|
||||
if destField.IsNil() {
|
||||
destField.Set(reflect.New(destField.Type().Elem()))
|
||||
}
|
||||
destField.Elem().Set(sourceField)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case sourceField.Type() == destField.Type():
|
||||
destField.Set(sourceField)
|
||||
|
||||
@@ -8,10 +8,11 @@ type OidcClientMetaDataDto struct {
|
||||
|
||||
type OidcClientDto struct {
|
||||
OidcClientMetaDataDto
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
Credentials OidcClientCredentialsDto `json:"credentials"`
|
||||
}
|
||||
|
||||
type OidcClientWithAllowedUserGroupsDto struct {
|
||||
@@ -25,11 +26,23 @@ type OidcClientWithAllowedGroupsCountDto struct {
|
||||
}
|
||||
|
||||
type OidcClientCreateDto struct {
|
||||
Name string `json:"name" binding:"required,max=50"`
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
Name string `json:"name" binding:"required,max=50"`
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
Credentials OidcClientCredentialsDto `json:"credentials"`
|
||||
}
|
||||
|
||||
type OidcClientCredentialsDto struct {
|
||||
FederatedIdentities []OidcClientFederatedIdentityDto `json:"federatedIdentities,omitempty"`
|
||||
}
|
||||
|
||||
type OidcClientFederatedIdentityDto struct {
|
||||
Issuer string `json:"issuer"`
|
||||
Subject string `json:"subject,omitempty"`
|
||||
Audience string `json:"audience,omitempty"`
|
||||
JWKS string `json:"jwks,omitempty"`
|
||||
}
|
||||
|
||||
type AuthorizeOidcClientRequestDto struct {
|
||||
@@ -52,13 +65,15 @@ type AuthorizationRequiredDto struct {
|
||||
}
|
||||
|
||||
type OidcCreateTokensDto struct {
|
||||
GrantType string `form:"grant_type" binding:"required"`
|
||||
Code string `form:"code"`
|
||||
DeviceCode string `form:"device_code"`
|
||||
ClientID string `form:"client_id"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
CodeVerifier string `form:"code_verifier"`
|
||||
RefreshToken string `form:"refresh_token"`
|
||||
GrantType string `form:"grant_type" binding:"required"`
|
||||
Code string `form:"code"`
|
||||
DeviceCode string `form:"device_code"`
|
||||
ClientID string `form:"client_id"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
CodeVerifier string `form:"code_verifier"`
|
||||
RefreshToken string `form:"refresh_token"`
|
||||
ClientAssertion string `form:"client_assertion"`
|
||||
ClientAssertionType string `form:"client_assertion_type"`
|
||||
}
|
||||
|
||||
type OidcIntrospectDto struct {
|
||||
|
||||
@@ -5,8 +5,9 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"gorm.io/gorm"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
)
|
||||
|
||||
type UserAuthorizedOidcClient struct {
|
||||
@@ -45,6 +46,7 @@ type OidcClient struct {
|
||||
HasLogo bool `gorm:"-"`
|
||||
IsPublic bool
|
||||
PkceEnabled bool
|
||||
Credentials OidcClientCredentials
|
||||
|
||||
AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
|
||||
CreatedByID string
|
||||
@@ -71,9 +73,49 @@ func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
type OidcClientCredentials struct { //nolint:recvcheck
|
||||
FederatedIdentities []OidcClientFederatedIdentity `json:"federatedIdentities,omitempty"`
|
||||
}
|
||||
|
||||
type OidcClientFederatedIdentity struct {
|
||||
Issuer string `json:"issuer"`
|
||||
Subject string `json:"subject,omitempty"`
|
||||
Audience string `json:"audience,omitempty"`
|
||||
JWKS string `json:"jwks,omitempty"` // URL of the JWKS
|
||||
}
|
||||
|
||||
func (occ OidcClientCredentials) FederatedIdentityForIssuer(issuer string) (OidcClientFederatedIdentity, bool) {
|
||||
if issuer == "" {
|
||||
return OidcClientFederatedIdentity{}, false
|
||||
}
|
||||
|
||||
for _, fi := range occ.FederatedIdentities {
|
||||
if fi.Issuer == issuer {
|
||||
return fi, true
|
||||
}
|
||||
}
|
||||
|
||||
return OidcClientFederatedIdentity{}, false
|
||||
}
|
||||
|
||||
func (occ *OidcClientCredentials) Scan(value any) error {
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(v, occ)
|
||||
case string:
|
||||
return json.Unmarshal([]byte(v), occ)
|
||||
default:
|
||||
return fmt.Errorf("unsupported type: %T", value)
|
||||
}
|
||||
}
|
||||
|
||||
func (occ OidcClientCredentials) Value() (driver.Value, error) {
|
||||
return json.Marshal(occ)
|
||||
}
|
||||
|
||||
type UrlList []string //nolint:recvcheck
|
||||
|
||||
func (cu *UrlList) Scan(value interface{}) error {
|
||||
func (cu *UrlList) Scan(value any) error {
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(v, cu)
|
||||
|
||||
@@ -29,17 +29,17 @@ type AppConfigService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAppConfigService(initCtx context.Context, db *gorm.DB) *AppConfigService {
|
||||
func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
|
||||
service := &AppConfigService{
|
||||
db: db,
|
||||
}
|
||||
|
||||
err := service.LoadDbConfig(initCtx)
|
||||
err := service.LoadDbConfig(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize app config service: %v", err)
|
||||
}
|
||||
|
||||
err = service.initInstanceID(initCtx)
|
||||
err = service.initInstanceID(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize instance ID: %v", err)
|
||||
}
|
||||
|
||||
@@ -3,16 +3,10 @@ package service
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -28,7 +22,7 @@ func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
|
||||
|
||||
func TestLoadDbConfig(t *testing.T) {
|
||||
t.Run("empty config table", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
service := &AppConfigService{
|
||||
db: db,
|
||||
}
|
||||
@@ -42,7 +36,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("loads value from config table", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Populate the config table with some initial values
|
||||
err := db.
|
||||
@@ -72,7 +66,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("ignores unknown config keys", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Add an entry with a key that doesn't exist in the config struct
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
@@ -93,7 +87,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("loading config multiple times", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Initial state
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
@@ -135,7 +129,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
common.EnvConfig.UiConfigDisabled = true
|
||||
|
||||
// Create database with config that should be ignored
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
{Key: "appName", Value: "DB App"},
|
||||
{Key: "sessionDuration", Value: "120"},
|
||||
@@ -171,7 +165,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
common.EnvConfig.UiConfigDisabled = false
|
||||
|
||||
// Create database with config values that should take precedence
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
{Key: "appName", Value: "DB App"},
|
||||
{Key: "sessionDuration", Value: "120"},
|
||||
@@ -195,7 +189,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
|
||||
func TestUpdateAppConfigValues(t *testing.T) {
|
||||
t.Run("update single value", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -220,7 +214,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("update multiple values", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -264,7 +258,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("empty value resets to default", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -285,7 +279,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("error with odd number of arguments", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -301,7 +295,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("error with invalid key", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -319,7 +313,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
|
||||
func TestUpdateAppConfig(t *testing.T) {
|
||||
t.Run("updates configuration values from DTO", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -392,7 +386,7 @@ func TestUpdateAppConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("empty values reset to defaults", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config and modify some values
|
||||
service := &AppConfigService{
|
||||
@@ -457,7 +451,7 @@ func TestUpdateAppConfig(t *testing.T) {
|
||||
// Disable UI config
|
||||
common.EnvConfig.UiConfigDisabled = true
|
||||
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := newDatabaseForTest(t)
|
||||
service := &AppConfigService{
|
||||
db: db,
|
||||
}
|
||||
@@ -475,49 +469,3 @@ func TestUpdateAppConfig(t *testing.T) {
|
||||
require.ErrorAs(t, err, &uiConfigDisabledErr)
|
||||
})
|
||||
}
|
||||
|
||||
// Implements gorm's logger.Writer interface
|
||||
type testLoggerAdapter struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (l testLoggerAdapter) Printf(format string, args ...any) {
|
||||
l.t.Logf(format, args...)
|
||||
}
|
||||
|
||||
func newAppConfigTestDatabaseForTest(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
// Get a name for this in-memory database that is specific to the test
|
||||
dbName := utils.CreateSha256Hash(t.Name())
|
||||
|
||||
// Connect to a new in-memory SQL database
|
||||
db, err := gorm.Open(
|
||||
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
|
||||
&gorm.Config{
|
||||
TranslateError: true,
|
||||
Logger: logger.New(
|
||||
testLoggerAdapter{t: t},
|
||||
logger.Config{
|
||||
SlowThreshold: 200 * time.Millisecond,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: false,
|
||||
ParameterizedQueries: false,
|
||||
Colorful: false,
|
||||
},
|
||||
),
|
||||
})
|
||||
require.NoError(t, err, "Failed to connect to test database")
|
||||
|
||||
// Create the app_config_variables table
|
||||
err = db.Exec(`
|
||||
CREATE TABLE app_config_variables
|
||||
(
|
||||
key VARCHAR(100) NOT NULL PRIMARY KEY,
|
||||
value TEXT NOT NULL
|
||||
)
|
||||
`).Error
|
||||
require.NoError(t, err, "Failed to create test config table")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
@@ -16,6 +18,7 @@ import (
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
@@ -30,14 +33,43 @@ type TestService struct {
|
||||
jwtService *JwtService
|
||||
appConfigService *AppConfigService
|
||||
ldapService *LdapService
|
||||
externalIdPKey jwk.Key
|
||||
}
|
||||
|
||||
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) *TestService {
|
||||
return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService, ldapService: ldapService}
|
||||
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) (*TestService, error) {
|
||||
s := &TestService{
|
||||
db: db,
|
||||
appConfigService: appConfigService,
|
||||
jwtService: jwtService,
|
||||
ldapService: ldapService,
|
||||
}
|
||||
err := s.initExternalIdP()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize external IdP: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Initializes the "external IdP"
|
||||
// This creates a new "issuing authority" containing a public JWKS
|
||||
// It also stores the private key internally that will be used to issue JWTs
|
||||
func (s *TestService) initExternalIdP() error {
|
||||
// Generate a new ECDSA key
|
||||
rawKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
s.externalIdPKey, err = utils.ImportRawKey(rawKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to import private key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *TestService) SeedDatabase() error {
|
||||
func (s *TestService) SeedDatabase(baseURL string) error {
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
users := []model.User{
|
||||
{
|
||||
@@ -138,6 +170,26 @@ func (s *TestService) SeedDatabase() error {
|
||||
userGroups[1],
|
||||
},
|
||||
},
|
||||
{
|
||||
Base: model.Base{
|
||||
ID: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
|
||||
},
|
||||
Name: "Federated",
|
||||
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
||||
CallbackURLs: model.UrlList{"http://federated/auth/callback"},
|
||||
CreatedByID: users[1].ID,
|
||||
AllowedUserGroups: []model.UserGroup{},
|
||||
Credentials: model.OidcClientCredentials{
|
||||
FederatedIdentities: []model.OidcClientFederatedIdentity{
|
||||
{
|
||||
Issuer: "https://external-idp.local",
|
||||
Audience: "api://PocketID",
|
||||
Subject: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
|
||||
JWKS: baseURL + "/api/externalidp/jwks.json",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, client := range oidcClients {
|
||||
if err := tx.Create(&client).Error; err != nil {
|
||||
@@ -145,16 +197,28 @@ func (s *TestService) SeedDatabase() error {
|
||||
}
|
||||
}
|
||||
|
||||
authCode := model.OidcAuthorizationCode{
|
||||
Code: "auth-code",
|
||||
Scope: "openid profile",
|
||||
Nonce: "nonce",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
authCodes := []model.OidcAuthorizationCode{
|
||||
{
|
||||
Code: "auth-code",
|
||||
Scope: "openid profile",
|
||||
Nonce: "nonce",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
},
|
||||
{
|
||||
Code: "federated",
|
||||
Scope: "openid profile",
|
||||
Nonce: "nonce",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
|
||||
UserID: users[1].ID,
|
||||
ClientID: oidcClients[2].ID,
|
||||
},
|
||||
}
|
||||
if err := tx.Create(&authCode).Error; err != nil {
|
||||
return err
|
||||
for _, authCode := range authCodes {
|
||||
if err := tx.Create(&authCode).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
refreshToken := model.OidcRefreshToken{
|
||||
@@ -177,13 +241,22 @@ func (s *TestService) SeedDatabase() error {
|
||||
return err
|
||||
}
|
||||
|
||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
Scope: "openid profile email",
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
userAuthorizedClients := []model.UserAuthorizedOidcClient{
|
||||
{
|
||||
Scope: "openid profile email",
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
},
|
||||
{
|
||||
Scope: "openid profile email",
|
||||
UserID: users[1].ID,
|
||||
ClientID: oidcClients[2].ID,
|
||||
},
|
||||
}
|
||||
if err := tx.Create(&userAuthorizedClient).Error; err != nil {
|
||||
return err
|
||||
for _, userAuthorizedClient := range userAuthorizedClients {
|
||||
if err := tx.Create(&userAuthorizedClient).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// To generate a new key pair, run the following command:
|
||||
@@ -405,3 +478,41 @@ func (s *TestService) SetLdapTestConfig(ctx context.Context) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExternalIdPJWKS returns the JWKS for the "external IdP".
|
||||
func (s *TestService) GetExternalIdPJWKS() (jwk.Set, error) {
|
||||
pubKey, err := s.externalIdPKey.PublicKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
set := jwk.NewSet()
|
||||
err = set.AddKey(pubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add public key to set: %w", err)
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func (s *TestService) SignExternalIdPToken(iss, sub, aud string) (string, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(sub).
|
||||
Expiration(now.Add(time.Hour)).
|
||||
IssuedAt(now).
|
||||
Issuer(iss).
|
||||
Audience([]string{aud}).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
alg, _ := s.externalIdPKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.externalIdPKey))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
@@ -4,11 +4,9 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -372,7 +370,7 @@ func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
|
||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
EnsureAlgInKey(pubKey)
|
||||
utils.EnsureAlgInKey(pubKey)
|
||||
|
||||
return pubKey, nil
|
||||
}
|
||||
@@ -415,27 +413,6 @@ func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
|
||||
func EnsureAlgInKey(key jwk.Key) {
|
||||
_, ok := key.Algorithm()
|
||||
if ok {
|
||||
// Algorithm is already set
|
||||
return
|
||||
}
|
||||
|
||||
switch key.KeyType() {
|
||||
case jwa.RSA():
|
||||
// Default to RS256 for RSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||
case jwa.EC():
|
||||
// Default to ES256 for ECDSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
|
||||
case jwa.OKP():
|
||||
// Default to EdDSA for OKP keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
|
||||
// We generate RSA keys only
|
||||
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
|
||||
@@ -444,27 +421,7 @@ func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
|
||||
}
|
||||
|
||||
// Import the raw key
|
||||
return importRawKey(rawKey)
|
||||
}
|
||||
|
||||
func importRawKey(rawKey any) (jwk.Key, error) {
|
||||
key, err := jwk.Import(rawKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to import generated private key: %w", err)
|
||||
}
|
||||
|
||||
// Generate the key ID
|
||||
kid, err := generateRandomKeyID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||
}
|
||||
_ = key.Set(jwk.KeyIDKey, kid)
|
||||
|
||||
// Set other required fields
|
||||
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
|
||||
EnsureAlgInKey(key)
|
||||
|
||||
return key, err
|
||||
return utils.ImportRawKey(rawKey)
|
||||
}
|
||||
|
||||
// SaveKeyJWK saves a JWK to a file
|
||||
@@ -492,16 +449,6 @@ func SaveKeyJWK(key jwk.Key, path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateRandomKeyID generates a random key ID.
|
||||
func generateRandomKeyID() (string, error) {
|
||||
buf := make([]byte, 8)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read random bytes: %w", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// GetIsAdmin returns the value of the "isAdmin" claim in the token
|
||||
func GetIsAdmin(token jwt.Token) (bool, error) {
|
||||
if !token.Has(IsAdminClaim) {
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
func TestJwtService_Init(t *testing.T) {
|
||||
@@ -1218,7 +1219,7 @@ func TestTokenTypeValidator(t *testing.T) {
|
||||
func importKey(t *testing.T, privateKeyRaw any, path string) string {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := importRawKey(privateKeyRaw)
|
||||
privateKey, err := utils.ImportRawKey(privateKeyRaw)
|
||||
require.NoError(t, err, "Failed to import private key")
|
||||
|
||||
err = SaveKeyJWK(privateKey, filepath.Join(path, PrivateKeyFile))
|
||||
|
||||
@@ -3,18 +3,25 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/httprc/v3"
|
||||
"github.com/lestrrat-go/httprc/v3/errsink"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jws"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
@@ -31,6 +38,8 @@ const (
|
||||
GrantTypeAuthorizationCode = "authorization_code"
|
||||
GrantTypeRefreshToken = "refresh_token"
|
||||
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
|
||||
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
|
||||
)
|
||||
|
||||
type OidcService struct {
|
||||
@@ -39,16 +48,61 @@ type OidcService struct {
|
||||
appConfigService *AppConfigService
|
||||
auditLogService *AuditLogService
|
||||
customClaimService *CustomClaimService
|
||||
|
||||
httpClient *http.Client
|
||||
jwkCache *jwk.Cache
|
||||
}
|
||||
|
||||
func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService, customClaimService *CustomClaimService) *OidcService {
|
||||
return &OidcService{
|
||||
func NewOidcService(
|
||||
ctx context.Context,
|
||||
db *gorm.DB,
|
||||
jwtService *JwtService,
|
||||
appConfigService *AppConfigService,
|
||||
auditLogService *AuditLogService,
|
||||
customClaimService *CustomClaimService,
|
||||
) (s *OidcService, err error) {
|
||||
s = &OidcService{
|
||||
db: db,
|
||||
jwtService: jwtService,
|
||||
appConfigService: appConfigService,
|
||||
auditLogService: auditLogService,
|
||||
customClaimService: customClaimService,
|
||||
}
|
||||
|
||||
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
|
||||
s.jwkCache, err = s.getJWKCache(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) {
|
||||
// We need to create a custom HTTP client to set a timeout.
|
||||
client := s.httpClient
|
||||
if client == nil {
|
||||
client = &http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
|
||||
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
// Indicates a development-time error
|
||||
panic("Default transport is not of type *http.Transport")
|
||||
}
|
||||
transport := defaultTransport.Clone()
|
||||
transport.TLSClientConfig.MinVersion = tls.VersionTLS12
|
||||
client.Transport = transport
|
||||
}
|
||||
|
||||
// Create the JWKS cache
|
||||
return jwk.NewCache(ctx,
|
||||
httprc.NewClient(
|
||||
httprc.WithErrorSink(errsink.NewSlog(slog.Default())),
|
||||
httprc.WithHTTPClient(client),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
|
||||
@@ -198,7 +252,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -279,7 +333,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -357,7 +411,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -420,7 +474,10 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, s.db)
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
})
|
||||
if err != nil {
|
||||
return introspectDto, err
|
||||
}
|
||||
@@ -440,33 +497,35 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
|
||||
introspectDto.Active = true
|
||||
introspectDto.TokenType = "access_token"
|
||||
if token.Has("scope") {
|
||||
var asString string
|
||||
var asStrings []string
|
||||
var (
|
||||
asString string
|
||||
asStrings []string
|
||||
)
|
||||
if err := token.Get("scope", &asString); err == nil {
|
||||
introspectDto.Scope = asString
|
||||
} else if err := token.Get("scope", &asStrings); err == nil {
|
||||
introspectDto.Scope = strings.Join(asStrings, " ")
|
||||
}
|
||||
}
|
||||
if expiration, hasExpiration := token.Expiration(); hasExpiration {
|
||||
if expiration, ok := token.Expiration(); ok {
|
||||
introspectDto.Expiration = expiration.Unix()
|
||||
}
|
||||
if issuedAt, hasIssuedAt := token.IssuedAt(); hasIssuedAt {
|
||||
if issuedAt, ok := token.IssuedAt(); ok {
|
||||
introspectDto.IssuedAt = issuedAt.Unix()
|
||||
}
|
||||
if notBefore, hasNotBefore := token.NotBefore(); hasNotBefore {
|
||||
if notBefore, ok := token.NotBefore(); ok {
|
||||
introspectDto.NotBefore = notBefore.Unix()
|
||||
}
|
||||
if subject, hasSubject := token.Subject(); hasSubject {
|
||||
if subject, ok := token.Subject(); ok {
|
||||
introspectDto.Subject = subject
|
||||
}
|
||||
if audience, hasAudience := token.Audience(); hasAudience {
|
||||
if audience, ok := token.Audience(); ok {
|
||||
introspectDto.Audience = audience
|
||||
}
|
||||
if issuer, hasIssuer := token.Issuer(); hasIssuer {
|
||||
if issuer, ok := token.Issuer(); ok {
|
||||
introspectDto.Issuer = issuer
|
||||
}
|
||||
if identifier, hasIdentifier := token.JwtID(); hasIdentifier {
|
||||
if identifier, ok := token.JwtID(); ok {
|
||||
introspectDto.Identifier = identifier
|
||||
}
|
||||
|
||||
@@ -542,13 +601,9 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
|
||||
|
||||
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||
client := model.OidcClient{
|
||||
Name: input.Name,
|
||||
CallbackURLs: input.CallbackURLs,
|
||||
LogoutCallbackURLs: input.LogoutCallbackURLs,
|
||||
CreatedByID: userID,
|
||||
IsPublic: input.IsPublic,
|
||||
PkceEnabled: input.PkceEnabled,
|
||||
CreatedByID: userID,
|
||||
}
|
||||
updateOIDCClientModelFromDto(&client, &input)
|
||||
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
@@ -577,11 +632,7 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
client.Name = input.Name
|
||||
client.CallbackURLs = input.CallbackURLs
|
||||
client.LogoutCallbackURLs = input.LogoutCallbackURLs
|
||||
client.IsPublic = input.IsPublic
|
||||
client.PkceEnabled = input.IsPublic || input.PkceEnabled
|
||||
updateOIDCClientModelFromDto(&client, &input)
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
@@ -599,6 +650,29 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientCreateDto) {
|
||||
// Base fields
|
||||
client.Name = input.Name
|
||||
client.CallbackURLs = input.CallbackURLs
|
||||
client.LogoutCallbackURLs = input.LogoutCallbackURLs
|
||||
client.IsPublic = input.IsPublic
|
||||
// PKCE is required for public clients
|
||||
client.PkceEnabled = input.IsPublic || input.PkceEnabled
|
||||
|
||||
// Credentials
|
||||
if len(input.Credentials.FederatedIdentities) > 0 {
|
||||
client.Credentials.FederatedIdentities = make([]model.OidcClientFederatedIdentity, len(input.Credentials.FederatedIdentities))
|
||||
for i, fi := range input.Credentials.FederatedIdentities {
|
||||
client.Credentials.FederatedIdentities[i] = model.OidcClientFederatedIdentity{
|
||||
Issuer: fi.Issuer,
|
||||
Audience: fi.Audience,
|
||||
Subject: fi.Subject,
|
||||
JWKS: fi.JWKS,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
|
||||
var client model.OidcClient
|
||||
err := s.db.
|
||||
@@ -1079,7 +1153,10 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, s.db)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: input.ClientID,
|
||||
ClientSecret: input.ClientSecret,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1305,33 +1382,140 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) {
|
||||
// First, ensure we have a valid client ID
|
||||
if clientID == "" {
|
||||
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||
if input.ClientID == "" {
|
||||
return nil, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
// Load the OIDC client's configuration
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
First(&client, "id = ?", input.ClientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If we have a client secret, we validate it
|
||||
// Otherwise, we require the client to be public
|
||||
if clientSecret != "" {
|
||||
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
// We have 3 options
|
||||
// If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only
|
||||
switch {
|
||||
// First, if we have a client secret, we validate it
|
||||
case input.ClientSecret != "":
|
||||
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(input.ClientSecret))
|
||||
if err != nil {
|
||||
return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
|
||||
return nil, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
return &client, nil
|
||||
|
||||
// Next, check if we want to use client assertions from federated identities
|
||||
case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "":
|
||||
err = s.verifyClientAssertionFromFederatedIdentities(ctx, &client, input)
|
||||
if err != nil {
|
||||
log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
|
||||
return nil, &common.OidcClientAssertionInvalidError{}
|
||||
}
|
||||
return &client, nil
|
||||
|
||||
// There's no credentials
|
||||
// This is allowed only if the client is public
|
||||
case client.IsPublic:
|
||||
return &client, nil
|
||||
|
||||
// If we're here, we have no credentials AND the client is not public, so credentials are required
|
||||
default:
|
||||
return nil, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set, err error) {
|
||||
// Check if we have already registered the URL
|
||||
if !s.jwkCache.IsRegistered(ctx, url) {
|
||||
// We set a timeout because otherwise Register will keep trying in case of errors
|
||||
registerCtx, registerCancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer registerCancel()
|
||||
// We need to register the URL
|
||||
err = s.jwkCache.Register(
|
||||
registerCtx,
|
||||
url,
|
||||
jwk.WithMaxInterval(24*time.Hour),
|
||||
jwk.WithMinInterval(15*time.Minute),
|
||||
jwk.WithWaitReady(true),
|
||||
)
|
||||
// In case of race conditions (two goroutines calling jwkCache.Register at the same time), it's possible we can get a conflict anyways, so we ignore that error
|
||||
if err != nil && !errors.Is(err, httprc.ErrResourceAlreadyExists()) {
|
||||
return nil, fmt.Errorf("failed to register JWK set: %w", err)
|
||||
}
|
||||
return client, nil
|
||||
} else if !client.IsPublic {
|
||||
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
return client, nil
|
||||
jwks, err := s.jwkCache.CachedSet(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cached JWK set: %w", err)
|
||||
}
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error {
|
||||
// First, parse the assertion JWT, without validating it, to check the issuer
|
||||
assertion := []byte(input.ClientAssertion)
|
||||
insecureToken, err := jwt.ParseInsecure(assertion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse client assertion JWT: %w", err)
|
||||
}
|
||||
|
||||
issuer, _ := insecureToken.Issuer()
|
||||
if issuer == "" {
|
||||
return errors.New("client assertion does not contain an issuer claim")
|
||||
}
|
||||
|
||||
// Ensure that this client is federated with the one that issued the token
|
||||
ocfi, ok := client.Credentials.FederatedIdentityForIssuer(issuer)
|
||||
if !ok {
|
||||
return fmt.Errorf("client assertion is not from an allowed issuer: %s", issuer)
|
||||
}
|
||||
|
||||
// Get the JWK set for the issuer
|
||||
jwksURL := ocfi.JWKS
|
||||
if jwksURL == "" {
|
||||
// Default URL is from the issuer
|
||||
if strings.HasSuffix(issuer, "/") {
|
||||
jwksURL = issuer + ".well-known/jwks.json"
|
||||
} else {
|
||||
jwksURL = issuer + "/.well-known/jwks.json"
|
||||
}
|
||||
}
|
||||
jwks, err := s.jwkSetForURL(ctx, jwksURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWK set for issuer '%s': %w", issuer, err)
|
||||
}
|
||||
|
||||
// Set default audience and subject if missing
|
||||
audience := ocfi.Audience
|
||||
if audience == "" {
|
||||
// Default to the Pocket ID's URL
|
||||
audience = common.EnvConfig.AppURL
|
||||
}
|
||||
subject := ocfi.Subject
|
||||
if subject == "" {
|
||||
// Default to the client ID, per RFC 7523
|
||||
subject = client.ID
|
||||
}
|
||||
|
||||
// Now re-parse the token with proper validation
|
||||
// (Note: we don't use jwt.WithIssuer() because that would be redundant)
|
||||
_, err = jwt.Parse(assertion,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true), jws.WithUseDefault(true)),
|
||||
jwt.WithAudience(audience),
|
||||
jwt.WithSubject(subject),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("client assertion is not valid: %w", err)
|
||||
}
|
||||
|
||||
// If we're here, the assertion is valid
|
||||
return nil
|
||||
}
|
||||
|
||||
365
backend/internal/service/oidc_service_test.go
Normal file
365
backend/internal/service/oidc_service_test.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
)
|
||||
|
||||
// generateTestECDSAKey creates an ECDSA key for testing
|
||||
func generateTestECDSAKey(t *testing.T) (jwk.Key, []byte) {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
privateJwk, err := jwk.Import(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = privateJwk.Set(jwk.KeyIDKey, "test-key-1")
|
||||
require.NoError(t, err)
|
||||
err = privateJwk.Set(jwk.AlgorithmKey, "ES256")
|
||||
require.NoError(t, err)
|
||||
err = privateJwk.Set("use", "sig")
|
||||
require.NoError(t, err)
|
||||
|
||||
publicJwk, err := jwk.PublicKeyOf(privateJwk)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a JWK Set with the public key
|
||||
jwkSet := jwk.NewSet()
|
||||
err = jwkSet.AddKey(publicJwk)
|
||||
require.NoError(t, err)
|
||||
jwkSetJSON, err := json.Marshal(jwkSet)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateJwk, jwkSetJSON
|
||||
}
|
||||
|
||||
func TestOidcService_jwkSetForURL(t *testing.T) {
|
||||
// Generate a test key for JWKS
|
||||
_, jwkSetJSON1 := generateTestECDSAKey(t)
|
||||
_, jwkSetJSON2 := generateTestECDSAKey(t)
|
||||
|
||||
// Create a mock HTTP client with responses for different URLs
|
||||
const (
|
||||
url1 = "https://example.com/.well-known/jwks.json"
|
||||
url2 = "https://other-issuer.com/jwks"
|
||||
)
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
url1: NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
|
||||
//nolint:bodyclose
|
||||
url2: NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
|
||||
// Create the OidcService with our mock client
|
||||
s := &OidcService{
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
var err error
|
||||
s.jwkCache, err = s.getJWKCache(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Fetches and caches JWK set", func(t *testing.T) {
|
||||
jwks, err := s.jwkSetForURL(t.Context(), url1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, jwks)
|
||||
|
||||
// Verify the JWK set contains our key
|
||||
require.Equal(t, 1, jwks.Len())
|
||||
})
|
||||
|
||||
t.Run("Fails with invalid URL", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
_, err := s.jwkSetForURL(ctx, "https://bad-url.com")
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
})
|
||||
|
||||
t.Run("Safe for concurrent use", func(t *testing.T) {
|
||||
const concurrency = 20
|
||||
|
||||
// Channel to collect errors
|
||||
errChan := make(chan error, concurrency)
|
||||
|
||||
// Start concurrent requests
|
||||
for range concurrency {
|
||||
go func() {
|
||||
jwks, err := s.jwkSetForURL(t.Context(), url2)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the JWK set is valid
|
||||
if jwks == nil || jwks.Len() != 1 {
|
||||
errChan <- assert.AnError
|
||||
return
|
||||
}
|
||||
|
||||
errChan <- nil
|
||||
}()
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
for range concurrency {
|
||||
assert.NoError(t, <-errChan, "Concurrent JWK set fetching should not produce errors")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
const (
|
||||
federatedClientIssuer = "https://external-idp.com"
|
||||
federatedClientAudience = "https://pocket-id.com"
|
||||
federatedClientSubject = "123456abcdef"
|
||||
federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
|
||||
)
|
||||
|
||||
var err error
|
||||
// Create a test database
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Create two JWKs for testing
|
||||
privateJWK, jwkSetJSON := generateTestECDSAKey(t)
|
||||
require.NoError(t, err)
|
||||
privateJWKDefaults, jwkSetJSONDefaults := generateTestECDSAKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a mock HTTP client with custom transport to return the JWKS
|
||||
httpClient := &http.Client{
|
||||
Transport: &MockRoundTripper{
|
||||
Responses: map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
federatedClientIssuer + "/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSON)),
|
||||
//nolint:bodyclose
|
||||
federatedClientIssuerDefaults + ".well-known/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Init the OidcService
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
s.jwkCache, err = s.getJWKCache(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create the test clients
|
||||
// 1. Confidential client
|
||||
confidentialClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||
Name: "Confidential Client",
|
||||
CallbackURLs: []string{"https://example.com/callback"},
|
||||
}, "test-user-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a client secret for the confidential client
|
||||
confidentialSecret, err := s.CreateClientSecret(t.Context(), confidentialClient.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Public client
|
||||
publicClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||
Name: "Public Client",
|
||||
CallbackURLs: []string{"https://example.com/callback"},
|
||||
IsPublic: true,
|
||||
}, "test-user-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3. Confidential client with federated identity
|
||||
federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||
Name: "Federated Client",
|
||||
CallbackURLs: []string{"https://example.com/callback"},
|
||||
Credentials: dto.OidcClientCredentialsDto{
|
||||
FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
|
||||
{
|
||||
Issuer: federatedClientIssuer,
|
||||
Audience: federatedClientAudience,
|
||||
Subject: federatedClientSubject,
|
||||
JWKS: federatedClientIssuer + "/jwks.json",
|
||||
},
|
||||
{Issuer: federatedClientIssuerDefaults},
|
||||
},
|
||||
},
|
||||
}, "test-user-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test cases for confidential client (using client secret)
|
||||
t.Run("Confidential client", func(t *testing.T) {
|
||||
t.Run("Succeeds with valid secret", func(t *testing.T) {
|
||||
// Test with valid client credentials
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: confidentialClient.ID,
|
||||
ClientSecret: confidentialSecret,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, confidentialClient.ID, client.ID)
|
||||
})
|
||||
|
||||
t.Run("Fails with invalid secret", func(t *testing.T) {
|
||||
// Test with invalid client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: confidentialClient.ID,
|
||||
ClientSecret: "invalid-secret",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
|
||||
assert.Nil(t, client)
|
||||
})
|
||||
|
||||
t.Run("Fails with missing secret", func(t *testing.T) {
|
||||
// Test with missing client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: confidentialClient.ID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
||||
assert.Nil(t, client)
|
||||
})
|
||||
})
|
||||
|
||||
// Test cases for public client
|
||||
t.Run("Public client", func(t *testing.T) {
|
||||
t.Run("Succeeds with no credentials", func(t *testing.T) {
|
||||
// Public clients don't require client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: publicClient.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, publicClient.ID, client.ID)
|
||||
})
|
||||
})
|
||||
|
||||
// Test cases for federated client using JWT assertion
|
||||
t.Run("Federated client", func(t *testing.T) {
|
||||
t.Run("Succeeds with valid JWT", func(t *testing.T) {
|
||||
// Create JWT for federated identity
|
||||
token, err := jwt.NewBuilder().
|
||||
Issuer(federatedClientIssuer).
|
||||
Audience([]string{federatedClientAudience}).
|
||||
Subject(federatedClientSubject).
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute)).
|
||||
Build()
|
||||
require.NoError(t, err)
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with valid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, federatedClient.ID, client.ID)
|
||||
})
|
||||
|
||||
t.Run("Fails with malformed JWT", func(t *testing.T) {
|
||||
// Test with invalid JWT assertion (just a random string)
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: "invalid.jwt.token",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
||||
assert.Nil(t, client)
|
||||
})
|
||||
|
||||
testBadJWT := func(builderFn func(builder *jwt.Builder)) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
// Populate all claims with valid values
|
||||
builder := jwt.NewBuilder().
|
||||
Issuer(federatedClientIssuer).
|
||||
Audience([]string{federatedClientAudience}).
|
||||
Subject(federatedClientSubject).
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute))
|
||||
|
||||
// Call builderFn to override the claims
|
||||
builderFn(builder)
|
||||
|
||||
token, err := builder.Build()
|
||||
require.NoError(t, err)
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with invalid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
||||
require.Nil(t, client)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Fails with expired JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Expiration(time.Now().Add(-30 * time.Minute))
|
||||
}))
|
||||
|
||||
t.Run("Fails with wrong issuer in JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Issuer("https://bad-issuer.com")
|
||||
}))
|
||||
|
||||
t.Run("Fails with wrong audience in JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Audience([]string{"bad-audience"})
|
||||
}))
|
||||
|
||||
t.Run("Fails with wrong subject in JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Subject("bad-subject")
|
||||
}))
|
||||
|
||||
t.Run("Uses default values for audience and subject", func(t *testing.T) {
|
||||
// Create JWT for federated identity
|
||||
token, err := jwt.NewBuilder().
|
||||
Issuer(federatedClientIssuerDefaults).
|
||||
Audience([]string{common.EnvConfig.AppURL}).
|
||||
Subject(federatedClient.ID).
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute)).
|
||||
Build()
|
||||
require.NoError(t, err)
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWKDefaults))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with valid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, federatedClient.ID, client.ID)
|
||||
})
|
||||
})
|
||||
}
|
||||
97
backend/internal/service/testutils_test.go
Normal file
97
backend/internal/service/testutils_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
)
|
||||
|
||||
func newDatabaseForTest(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
// Get a name for this in-memory database that is specific to the test
|
||||
dbName := utils.CreateSha256Hash(t.Name())
|
||||
|
||||
// Connect to a new in-memory SQL database
|
||||
db, err := gorm.Open(
|
||||
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
|
||||
&gorm.Config{
|
||||
TranslateError: true,
|
||||
Logger: logger.New(
|
||||
testLoggerAdapter{t: t},
|
||||
logger.Config{
|
||||
SlowThreshold: 200 * time.Millisecond,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: false,
|
||||
ParameterizedQueries: false,
|
||||
Colorful: false,
|
||||
},
|
||||
),
|
||||
})
|
||||
require.NoError(t, err, "Failed to connect to test database")
|
||||
|
||||
// Perform migrations with the embedded migrations
|
||||
sqlDB, err := db.DB()
|
||||
require.NoError(t, err, "Failed to get sql.DB")
|
||||
driver, err := sqliteMigrate.WithInstance(sqlDB, &sqliteMigrate.Config{})
|
||||
require.NoError(t, err, "Failed to create migration driver")
|
||||
source, err := iofs.New(resources.FS, "migrations/sqlite")
|
||||
require.NoError(t, err, "Failed to create embedded migration source")
|
||||
m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver)
|
||||
require.NoError(t, err, "Failed to create migration instance")
|
||||
err = m.Up()
|
||||
require.NoError(t, err, "Failed to perform migrations")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// Implements gorm's logger.Writer interface
|
||||
type testLoggerAdapter struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (l testLoggerAdapter) Printf(format string, args ...any) {
|
||||
l.t.Logf(format, args...)
|
||||
}
|
||||
|
||||
// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL
|
||||
type MockRoundTripper struct {
|
||||
Err error
|
||||
Responses map[string]*http.Response
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Check if we have a specific response for this URL
|
||||
for url, resp := range m.Responses {
|
||||
if req.URL.String() == url {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return NewMockResponse(http.StatusNotFound, ""), nil
|
||||
}
|
||||
|
||||
// NewMockResponse creates an http.Response with the given status code and body
|
||||
func NewMockResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
69
backend/internal/utils/jwk_util.go
Normal file
69
backend/internal/utils/jwk_util.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
)
|
||||
|
||||
const (
|
||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||
KeyUsageSigning = "sig"
|
||||
)
|
||||
|
||||
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
|
||||
// It also populates additional fields such as the key ID, usage, and alg.
|
||||
func ImportRawKey(rawKey any) (jwk.Key, error) {
|
||||
key, err := jwk.Import(rawKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to import generated private key: %w", err)
|
||||
}
|
||||
|
||||
// Generate the key ID
|
||||
kid, err := generateRandomKeyID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||
}
|
||||
_ = key.Set(jwk.KeyIDKey, kid)
|
||||
|
||||
// Set other required fields
|
||||
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
|
||||
EnsureAlgInKey(key)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// generateRandomKeyID generates a random key ID.
|
||||
func generateRandomKeyID() (string, error) {
|
||||
buf := make([]byte, 8)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read random bytes: %w", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
|
||||
func EnsureAlgInKey(key jwk.Key) {
|
||||
_, ok := key.Algorithm()
|
||||
if ok {
|
||||
// Algorithm is already set
|
||||
return
|
||||
}
|
||||
|
||||
switch key.KeyType() {
|
||||
case jwa.RSA():
|
||||
// Default to RS256 for RSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||
case jwa.EC():
|
||||
// Default to ES256 for ECDSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
|
||||
case jwa.OKP():
|
||||
// Default to EdDSA for OKP keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE oidc_clients DROP COLUMN credentials;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE oidc_clients ADD COLUMN credentials JSONB NULL;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE oidc_clients DROP COLUMN credentials;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE oidc_clients ADD COLUMN credentials TEXT NULL;
|
||||
@@ -348,6 +348,12 @@
|
||||
"the_device_has_been_authorized": "The device has been authorized.",
|
||||
"enter_code_displayed_in_previous_step": "Enter the code that was displayed in the previous step.",
|
||||
"authorize": "Authorize",
|
||||
"federated_identities": "Federated Identities",
|
||||
"federated_identities_description": "Using federated identities, you can authenticate OIDC clients using JWT tokens issued by third-party authorities.",
|
||||
"add_federated_identity": "Add Federated Identity",
|
||||
"add_another_federated_identity": "Add another federated identity",
|
||||
"oidc_allowed_group_count": "Allowed Group Count",
|
||||
"unrestricted": "Unrestricted"
|
||||
"unrestricted": "Unrestricted",
|
||||
"show_advanced_options": "Show Advanced Options",
|
||||
"hide_advanced_options": "Hide Advanced Options"
|
||||
}
|
||||
|
||||
@@ -6,11 +6,23 @@ export type OidcClientMetaData = {
|
||||
hasLogo: boolean;
|
||||
};
|
||||
|
||||
export type OidcClientFederatedIdentity = {
|
||||
issuer: string;
|
||||
subject?: string;
|
||||
audience?: string;
|
||||
jwks?: string;
|
||||
};
|
||||
|
||||
export type OidcClientCredentials = {
|
||||
federatedIdentities: OidcClientFederatedIdentity[];
|
||||
};
|
||||
|
||||
export type OidcClient = OidcClientMetaData & {
|
||||
callbackURLs: string[]; // No longer requires at least one URL
|
||||
logoutCallbackURLs: string[];
|
||||
isPublic: boolean;
|
||||
pkceEnabled: boolean;
|
||||
credentials?: OidcClientCredentials;
|
||||
};
|
||||
|
||||
export type OidcClientWithAllowedUserGroups = OidcClient & {
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
import * as Card from '$lib/components/ui/card';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import UserGroupSelection from '$lib/components/user-group-selection.svelte';
|
||||
import { m } from '$lib/paraglide/messages';
|
||||
import OidcService from '$lib/services/oidc-service';
|
||||
import clientSecretStore from '$lib/stores/client-secret-store';
|
||||
import type { OidcClientCreateWithLogo } from '$lib/types/oidc.type';
|
||||
@@ -16,7 +17,6 @@
|
||||
import { toast } from 'svelte-sonner';
|
||||
import { slide } from 'svelte/transition';
|
||||
import OidcForm from '../oidc-client-form.svelte';
|
||||
import { m } from '$lib/paraglide/messages';
|
||||
|
||||
let { data } = $props();
|
||||
let client = $state({
|
||||
@@ -166,7 +166,7 @@
|
||||
</Card.Content>
|
||||
</Card.Root>
|
||||
<Card.Root>
|
||||
<Card.Content class="p-5">
|
||||
<Card.Content>
|
||||
<OidcForm existingClient={client} callback={updateClient} />
|
||||
</Card.Content>
|
||||
</Card.Root>
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
<script lang="ts">
|
||||
import FormInput from '$lib/components/form/form-input.svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import { m } from '$lib/paraglide/messages';
|
||||
import type { OidcClient, OidcClientFederatedIdentity } from '$lib/types/oidc.type';
|
||||
import { LucideMinus, LucidePlus } from '@lucide/svelte';
|
||||
import type { Snippet } from 'svelte';
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
|
||||
let {
|
||||
client,
|
||||
federatedIdentities = $bindable([]),
|
||||
error = $bindable(null),
|
||||
...restProps
|
||||
}: HTMLAttributes<HTMLDivElement> & {
|
||||
client?: OidcClient;
|
||||
federatedIdentities: OidcClientFederatedIdentity[];
|
||||
error?: string | null;
|
||||
children?: Snippet;
|
||||
} = $props();
|
||||
|
||||
function addFederatedIdentity() {
|
||||
federatedIdentities = [
|
||||
...federatedIdentities,
|
||||
{
|
||||
issuer: '',
|
||||
subject: '',
|
||||
audience: '',
|
||||
jwks: ''
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
function removeFederatedIdentity(index: number) {
|
||||
federatedIdentities = federatedIdentities.filter((_, i) => i !== index);
|
||||
}
|
||||
|
||||
function updateFederatedIdentity(
|
||||
index: number,
|
||||
field: keyof OidcClientFederatedIdentity,
|
||||
value: string
|
||||
) {
|
||||
federatedIdentities[index] = {
|
||||
...federatedIdentities[index],
|
||||
[field]: value
|
||||
};
|
||||
}
|
||||
</script>
|
||||
|
||||
<div {...restProps}>
|
||||
<FormInput label={m.federated_identities()} description={m.federated_identities_description()}>
|
||||
<div class="space-y-4">
|
||||
{#each federatedIdentities as identity, i}
|
||||
<div class="space-y-3 rounded-lg border p-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<Label class="text-sm font-medium">Identity {i + 1}</Label>
|
||||
{#if federatedIdentities.length > 0}
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onclick={() => removeFederatedIdentity(i)}
|
||||
aria-label="Remove federated identity"
|
||||
>
|
||||
<LucideMinus class="size-4" />
|
||||
</Button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-1 gap-3 md:grid-cols-2">
|
||||
<div>
|
||||
<Label for="issuer-{i}" class="text-xs">Issuer (Required)</Label>
|
||||
<Input
|
||||
id="issuer-{i}"
|
||||
placeholder="https://example.com/"
|
||||
value={identity.issuer}
|
||||
oninput={(e) => updateFederatedIdentity(i, 'issuer', e.currentTarget.value)}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label for="subject-{i}" class="text-xs">Subject (Optional)</Label>
|
||||
<Input
|
||||
id="subject-{i}"
|
||||
placeholder="Defaults to the client ID: {client?.id}"
|
||||
value={identity.subject || ''}
|
||||
oninput={(e) => updateFederatedIdentity(i, 'subject', e.currentTarget.value)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label for="audience-{i}" class="text-xs">Audience (Optional)</Label>
|
||||
<Input
|
||||
id="audience-{i}"
|
||||
placeholder="Defaults to the Pocket ID URL"
|
||||
value={identity.audience || ''}
|
||||
oninput={(e) => updateFederatedIdentity(i, 'audience', e.currentTarget.value)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label for="jwks-{i}" class="text-xs">JWKS URL (Optional)</Label>
|
||||
<Input
|
||||
id="jwks-{i}"
|
||||
placeholder="Defaults to {identity.issuer || '<issuer>'}/.well-known/jwks.json"
|
||||
value={identity.jwks || ''}
|
||||
oninput={(e) => updateFederatedIdentity(i, 'jwks', e.currentTarget.value)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</FormInput>
|
||||
|
||||
{#if error}
|
||||
<p class="text-destructive mt-1 text-xs">{error}</p>
|
||||
{/if}
|
||||
|
||||
<Button class="mt-3" variant="secondary" size="sm" onclick={addFederatedIdentity} type="button">
|
||||
<LucidePlus class="mr-1 size-4" />
|
||||
{federatedIdentities.length === 0
|
||||
? m.add_federated_identity()
|
||||
: m.add_another_federated_identity()}
|
||||
</Button>
|
||||
</div>
|
||||
@@ -5,14 +5,14 @@
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import { m } from '$lib/paraglide/messages';
|
||||
import type {
|
||||
OidcClient,
|
||||
OidcClientCreate,
|
||||
OidcClientCreateWithLogo
|
||||
} from '$lib/types/oidc.type';
|
||||
import type { OidcClient, OidcClientCreateWithLogo } from '$lib/types/oidc.type';
|
||||
import { preventDefault } from '$lib/utils/event-util';
|
||||
import { createForm } from '$lib/utils/form-util';
|
||||
import { cn } from '$lib/utils/style';
|
||||
import { LucideChevronDown } from '@lucide/svelte';
|
||||
import { slide } from 'svelte/transition';
|
||||
import { z } from 'zod';
|
||||
import FederatedIdentitiesInput from './federated-identities-input.svelte';
|
||||
import OidcCallbackUrlInput from './oidc-callback-url-input.svelte';
|
||||
|
||||
let {
|
||||
@@ -24,17 +24,21 @@
|
||||
} = $props();
|
||||
|
||||
let isLoading = $state(false);
|
||||
let showAdvancedOptions = $state(false);
|
||||
let logo = $state<File | null | undefined>();
|
||||
let logoDataURL: string | null = $state(
|
||||
existingClient?.hasLogo ? `/api/oidc/clients/${existingClient!.id}/logo` : null
|
||||
);
|
||||
|
||||
const client: OidcClientCreate = {
|
||||
const client = {
|
||||
name: existingClient?.name || '',
|
||||
callbackURLs: existingClient?.callbackURLs || [],
|
||||
logoutCallbackURLs: existingClient?.logoutCallbackURLs || [],
|
||||
isPublic: existingClient?.isPublic || false,
|
||||
pkceEnabled: existingClient?.pkceEnabled || false
|
||||
pkceEnabled: existingClient?.pkceEnabled || false,
|
||||
credentials: {
|
||||
federatedIdentities: existingClient?.credentials?.federatedIdentities || []
|
||||
}
|
||||
};
|
||||
|
||||
const formSchema = z.object({
|
||||
@@ -42,7 +46,17 @@
|
||||
callbackURLs: z.array(z.string().nonempty()).default([]),
|
||||
logoutCallbackURLs: z.array(z.string().nonempty()),
|
||||
isPublic: z.boolean(),
|
||||
pkceEnabled: z.boolean()
|
||||
pkceEnabled: z.boolean(),
|
||||
credentials: z.object({
|
||||
federatedIdentities: z.array(
|
||||
z.object({
|
||||
issuer: z.string().url(),
|
||||
subject: z.string().optional(),
|
||||
audience: z.string().optional(),
|
||||
jwks: z.string().url().optional().or(z.literal(''))
|
||||
})
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
type FormSchema = typeof formSchema;
|
||||
@@ -139,8 +153,31 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="w-full"></div>
|
||||
<div class="mt-5 flex justify-end">
|
||||
<Button {isLoading} type="submit">{m.save()}</Button>
|
||||
|
||||
{#if showAdvancedOptions}
|
||||
<div class="mt-5 md:col-span-2" transition:slide={{ duration: 200 }}>
|
||||
<FederatedIdentitiesInput
|
||||
client={existingClient}
|
||||
bind:federatedIdentities={$inputs.credentials.value.federatedIdentities}
|
||||
bind:error={$inputs.credentials.error}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="relative mt-5 flex justify-center">
|
||||
<Button
|
||||
variant="ghost"
|
||||
class="text-muted-foregroun"
|
||||
onclick={() => (showAdvancedOptions = !showAdvancedOptions)}
|
||||
>
|
||||
{showAdvancedOptions ? m.hide_advanced_options() : m.show_advanced_options()}
|
||||
<LucideChevronDown
|
||||
class={cn(
|
||||
'size-5 transition-transform duration-200',
|
||||
showAdvancedOptions && 'rotate-180 transform'
|
||||
)}
|
||||
/>
|
||||
</Button>
|
||||
<Button {isLoading} type="submit" class="absolute right-0">{m.save()}</Button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
@@ -35,6 +35,17 @@ export const oidcClients = {
|
||||
callbackUrl: 'http://immich/auth/callback',
|
||||
secret: 'PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x'
|
||||
},
|
||||
federated: {
|
||||
id: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
|
||||
name: 'Federated',
|
||||
callbackUrl: 'http://federated/auth/callback',
|
||||
federatedJWT: {
|
||||
issuer: 'https://external-idp.local',
|
||||
audience: 'api://PocketID',
|
||||
subject: 'c48232ff-ff65-45ed-ae96-7afa8a9b443b',
|
||||
},
|
||||
accessCodes: ['federated']
|
||||
},
|
||||
pingvinShare: {
|
||||
name: 'Pingvin Share',
|
||||
callbackUrl: 'http://pingvin.share/auth/callback',
|
||||
|
||||
14
tests/package-lock.json
generated
14
tests/package-lock.json
generated
@@ -7,6 +7,7 @@
|
||||
"devDependencies": {
|
||||
"@playwright/test": "^1.52.0",
|
||||
"@types/node": "^22.15.21",
|
||||
"dotenv": "^16.5.0",
|
||||
"jose": "^6.0.11"
|
||||
}
|
||||
},
|
||||
@@ -36,6 +37,19 @@
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
},
|
||||
"node_modules/dotenv": {
|
||||
"version": "16.5.0",
|
||||
"resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.5.0.tgz",
|
||||
"integrity": "sha512-m/C+AwOAr9/W1UOIZUo232ejMNnJAJtYQjUbHoNTBNTJSvqzzDh7vnrei3o3r3m9blf6ZoDkvcw0VmozNRFJxg==",
|
||||
"dev": true,
|
||||
"license": "BSD-2-Clause",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://dotenvx.com"
|
||||
}
|
||||
},
|
||||
"node_modules/fsevents": {
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"devDependencies": {
|
||||
"@playwright/test": "^1.52.0",
|
||||
"@types/node": "^22.15.21",
|
||||
"jose": "^6.0.11"
|
||||
"jose": "^6.0.11",
|
||||
"dotenv": "^16.5.0"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,30 +1,31 @@
|
||||
import { defineConfig, devices } from '@playwright/test';
|
||||
import { defineConfig, devices } from "@playwright/test";
|
||||
import "dotenv/config";
|
||||
|
||||
/**
|
||||
* See https://playwright.dev/docs/test-configuration.
|
||||
*/
|
||||
export default defineConfig({
|
||||
outputDir: './.output',
|
||||
timeout: 10000,
|
||||
testDir: './specs',
|
||||
fullyParallel: false,
|
||||
forbidOnly: !!process.env.CI,
|
||||
retries: process.env.CI ? 1 : 0,
|
||||
workers: 1,
|
||||
reporter: process.env.CI
|
||||
? [['html', { outputFolder: '.report' }], ['github']]
|
||||
: [['line'], ['html', { open: 'never', outputFolder: '.report' }]],
|
||||
use: {
|
||||
baseURL: process.env.APP_URL ?? 'http://localhost:1411',
|
||||
video: 'retain-on-failure',
|
||||
trace: 'on-first-retry'
|
||||
},
|
||||
projects: [
|
||||
{ name: 'setup', testMatch: /.*\.setup\.ts/ },
|
||||
{
|
||||
name: 'chromium',
|
||||
use: { ...devices['Desktop Chrome'], storageState: '.auth/user.json' },
|
||||
dependencies: ['setup']
|
||||
}
|
||||
]
|
||||
outputDir: "./.output",
|
||||
timeout: 10000,
|
||||
testDir: "./specs",
|
||||
fullyParallel: false,
|
||||
forbidOnly: !!process.env.CI,
|
||||
retries: process.env.CI ? 1 : 0,
|
||||
workers: 1,
|
||||
reporter: process.env.CI
|
||||
? [["html", { outputFolder: ".report" }], ["github"]]
|
||||
: [["line"], ["html", { open: "never", outputFolder: ".report" }]],
|
||||
use: {
|
||||
baseURL: process.env.APP_URL ?? "http://localhost:1411",
|
||||
video: "retain-on-failure",
|
||||
trace: "on-first-retry",
|
||||
},
|
||||
projects: [
|
||||
{ name: "setup", testMatch: /.*\.setup\.ts/ },
|
||||
{
|
||||
name: "chromium",
|
||||
use: { ...devices["Desktop Chrome"], storageState: ".auth/user.json" },
|
||||
dependencies: ["setup"],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
@@ -4,6 +4,8 @@ import { cleanupBackend } from '../utils/cleanup.util';
|
||||
test.beforeEach(cleanupBackend);
|
||||
|
||||
test.describe('LDAP Integration', () => {
|
||||
test.skip(process.env.SKIP_LDAP_TESTS === "true", 'Skipping LDAP tests due to SKIP_LDAP_TESTS environment variable');
|
||||
|
||||
test('LDAP configuration is working properly', async ({ page }) => {
|
||||
await page.goto('/settings/admin/application-configuration');
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import test, { expect } from "@playwright/test";
|
||||
import { oidcClients, refreshTokens, users } from "../data";
|
||||
import { cleanupBackend } from "../utils/cleanup.util";
|
||||
import { generateIdToken, generateOauthAccessToken } from "../utils/jwt.util";
|
||||
import oidcUtil from "../utils/oidc.util";
|
||||
import * as oidcUtil from "../utils/oidc.util";
|
||||
import passkeyUtil from "../utils/passkey.util";
|
||||
|
||||
test.beforeEach(cleanupBackend);
|
||||
@@ -449,3 +449,40 @@ test("Authorize new client with device authorization with user group not allowed
|
||||
.filter({ hasText: "You're not allowed to access this service." })
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test("Federated identity fails with invalid client assertion", async ({
|
||||
page,
|
||||
}) => {
|
||||
const client = oidcClients.federated;
|
||||
|
||||
const res = await oidcUtil.exchangeCode(page, {
|
||||
client_assertion_type: 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
|
||||
grant_type: 'authorization_code',
|
||||
redirect_uri: client.callbackUrl,
|
||||
code: client.accessCodes[0],
|
||||
client_id: client.id,
|
||||
client_assertion:'not-an-assertion',
|
||||
});
|
||||
|
||||
expect(res?.error).toBe('Invalid client assertion');
|
||||
});
|
||||
|
||||
test("Authorize existing client with federated identity", async ({
|
||||
page,
|
||||
}) => {
|
||||
const client = oidcClients.federated;
|
||||
const clientAssertion = await oidcUtil.getClientAssertion(page, client.federatedJWT);
|
||||
|
||||
const res = await oidcUtil.exchangeCode(page, {
|
||||
client_assertion_type: 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
|
||||
grant_type: 'authorization_code',
|
||||
redirect_uri: client.callbackUrl,
|
||||
code: client.accessCodes[0],
|
||||
client_id: client.id,
|
||||
client_assertion: clientAssertion,
|
||||
});
|
||||
|
||||
expect(res.access_token).not.toBeNull;
|
||||
expect(res.expires_in).not.toBeNull;
|
||||
expect(res.token_type).toBe('Bearer');
|
||||
});
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"baseUrl": "."
|
||||
"baseUrl": ".",
|
||||
"lib": ["ES2022"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import playwrightConfig from "../playwright.config";
|
||||
|
||||
export async function cleanupBackend() {
|
||||
const response = await fetch(
|
||||
playwrightConfig.use!.baseURL + "/api/test/reset",
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
);
|
||||
const url = new URL("/api/test/reset", playwrightConfig.use!.baseURL);
|
||||
|
||||
if (process.env.SKIP_LDAP_TESTS === "true") {
|
||||
url.searchParams.append("skip-ldap", "true");
|
||||
}
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { Page } from '@playwright/test';
|
||||
|
||||
async function getUserCode(page: Page, clientId: string, clientSecret: string) {
|
||||
const response = await page.request
|
||||
export async function getUserCode(page: Page, clientId: string, clientSecret: string): Promise<string> {
|
||||
return page.request
|
||||
.post('/api/oidc/device/authorize', {
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded'
|
||||
@@ -12,11 +12,29 @@ async function getUserCode(page: Page, clientId: string, clientSecret: string) {
|
||||
scope: 'openid profile email'
|
||||
}
|
||||
})
|
||||
.then((r) => r.json());
|
||||
|
||||
return response.user_code;
|
||||
.then((r) => r.json())
|
||||
.then((r) => r.user_code);
|
||||
}
|
||||
|
||||
export default {
|
||||
getUserCode
|
||||
};
|
||||
export async function exchangeCode(page: Page, params: Record<string,string>): Promise<{access_token?: string, token_type?: string, expires_in?: number, error?: string}> {
|
||||
return page.request
|
||||
.post('/api/oidc/token', {
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded'
|
||||
},
|
||||
form: params,
|
||||
})
|
||||
.then((r) => r.json());
|
||||
}
|
||||
|
||||
export async function getClientAssertion(page: Page, data: {issuer: string, audience: string, subject: string}): Promise<string> {
|
||||
return page.request
|
||||
.post('/api/externalidp/sign', {
|
||||
data: {
|
||||
iss: data.issuer,
|
||||
aud: data.audience,
|
||||
sub: data.subject,
|
||||
},
|
||||
})
|
||||
.then((r) => r.text());
|
||||
}
|
||||
Reference in New Issue
Block a user