mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-23 09:15:13 +03:00
feat: add support for ECDSA and EdDSA keys (#359)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
5c198c280c
commit
96876a99c5
20
.github/workflows/e2e-tests.yml
vendored
20
.github/workflows/e2e-tests.yml
vendored
@@ -69,6 +69,8 @@ jobs:
|
|||||||
-e APP_ENV=test \
|
-e APP_ENV=test \
|
||||||
pocket-id/pocket-id:test
|
pocket-id/pocket-id:test
|
||||||
|
|
||||||
|
docker logs -f pocket-id-sqlite &> /tmp/backend.log &
|
||||||
|
|
||||||
- name: Run Playwright tests
|
- name: Run Playwright tests
|
||||||
working-directory: ./frontend
|
working-directory: ./frontend
|
||||||
run: npx playwright test
|
run: npx playwright test
|
||||||
@@ -81,6 +83,14 @@ jobs:
|
|||||||
include-hidden-files: true
|
include-hidden-files: true
|
||||||
retention-days: 15
|
retention-days: 15
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
if: always()
|
||||||
|
with:
|
||||||
|
name: backend-sqlite
|
||||||
|
path: /tmp/backend.log
|
||||||
|
include-hidden-files: true
|
||||||
|
retention-days: 15
|
||||||
|
|
||||||
test-postgres:
|
test-postgres:
|
||||||
if: github.event.pull_request.head.ref != 'i18n_crowdin'
|
if: github.event.pull_request.head.ref != 'i18n_crowdin'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -143,6 +153,8 @@ jobs:
|
|||||||
-e POSTGRES_CONNECTION_STRING=postgresql://postgres:postgres@pocket-id-db:5432/pocket-id \
|
-e POSTGRES_CONNECTION_STRING=postgresql://postgres:postgres@pocket-id-db:5432/pocket-id \
|
||||||
pocket-id/pocket-id:test
|
pocket-id/pocket-id:test
|
||||||
|
|
||||||
|
docker logs -f pocket-id-postgres &> /tmp/backend.log &
|
||||||
|
|
||||||
- name: Run Playwright tests
|
- name: Run Playwright tests
|
||||||
working-directory: ./frontend
|
working-directory: ./frontend
|
||||||
run: npx playwright test
|
run: npx playwright test
|
||||||
@@ -154,3 +166,11 @@ jobs:
|
|||||||
path: frontend/tests/.report
|
path: frontend/tests/.report
|
||||||
include-hidden-files: true
|
include-hidden-files: true
|
||||||
retention-days: 15
|
retention-days: 15
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
if: always()
|
||||||
|
with:
|
||||||
|
name: backend-postgres
|
||||||
|
path: /tmp/backend.log
|
||||||
|
include-hidden-files: true
|
||||||
|
retention-days: 15
|
||||||
|
|||||||
@@ -14,11 +14,10 @@ require (
|
|||||||
github.com/go-ldap/ldap/v3 v3.4.10
|
github.com/go-ldap/ldap/v3 v3.4.10
|
||||||
github.com/go-playground/validator/v10 v10.24.0
|
github.com/go-playground/validator/v10 v10.24.0
|
||||||
github.com/go-webauthn/webauthn v0.11.2
|
github.com/go-webauthn/webauthn v0.11.2
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
|
||||||
github.com/golang-migrate/migrate/v4 v4.18.2
|
github.com/golang-migrate/migrate/v4 v4.18.2
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/lestrrat-go/jwx/v3 v3.0.0-alpha3
|
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1
|
||||||
github.com/mileusna/useragent v1.3.5
|
github.com/mileusna/useragent v1.3.5
|
||||||
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
|
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
@@ -45,6 +44,7 @@ require (
|
|||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-webauthn/x v0.1.16 // indirect
|
github.com/go-webauthn/x v0.1.16 // indirect
|
||||||
github.com/goccy/go-json v0.10.4 // indirect
|
github.com/goccy/go-json v0.10.4 // indirect
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
|
||||||
github.com/google/go-tpm v0.9.3 // indirect
|
github.com/google/go-tpm v0.9.3 // indirect
|
||||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||||
|
|||||||
@@ -145,8 +145,8 @@ github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZ
|
|||||||
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
|
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
|
||||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA=
|
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA=
|
||||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms=
|
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms=
|
||||||
github.com/lestrrat-go/jwx/v3 v3.0.0-alpha3 h1:HHT8iW+UcPBgBr5A3soZQQsL5cBor/u6BkLB+wzY/R0=
|
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1 h1:Iqjb8JvWjh34Jv8DeM2wQ1aG5fzFBzwQu7rlqwuJB0I=
|
||||||
github.com/lestrrat-go/jwx/v3 v3.0.0-alpha3/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc=
|
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc=
|
||||||
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
|
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
|
||||||
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
@@ -143,7 +144,7 @@ func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
|
|||||||
// @Success 200 {file} binary "Logo image"
|
// @Success 200 {file} binary "Logo image"
|
||||||
// @Router /api/application-configuration/logo [get]
|
// @Router /api/application-configuration/logo [get]
|
||||||
func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
|
func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
|
||||||
lightLogo := c.DefaultQuery("light", "true") == "true"
|
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
|
||||||
|
|
||||||
var imageName string
|
var imageName string
|
||||||
var imageType string
|
var imageType string
|
||||||
@@ -196,7 +197,7 @@ func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
|
|||||||
// @Security BearerAuth
|
// @Security BearerAuth
|
||||||
// @Router /api/application-configuration/logo [put]
|
// @Router /api/application-configuration/logo [put]
|
||||||
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
|
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
|
||||||
lightLogo := c.DefaultQuery("light", "true") == "true"
|
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
|
||||||
|
|
||||||
var imageName string
|
var imageName string
|
||||||
var imageType string
|
var imageType string
|
||||||
|
|||||||
@@ -195,22 +195,28 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
|||||||
// @Security OAuth2AccessToken
|
// @Security OAuth2AccessToken
|
||||||
// @Router /api/oidc/userinfo [get]
|
// @Router /api/oidc/userinfo [get]
|
||||||
func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
||||||
authHeaderSplit := strings.Split(c.GetHeader("Authorization"), " ")
|
_, authToken, ok := strings.Cut(c.GetHeader("Authorization"), " ")
|
||||||
if len(authHeaderSplit) != 2 {
|
if !ok || authToken == "" {
|
||||||
_ = c.Error(&common.MissingAccessToken{})
|
_ = c.Error(&common.MissingAccessToken{})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token := authHeaderSplit[1]
|
token, err := oc.jwtService.VerifyOauthAccessToken(authToken)
|
||||||
|
|
||||||
jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userID := jwtClaims.Subject
|
userID, ok := token.Subject()
|
||||||
clientId := jwtClaims.Audience[0]
|
if !ok {
|
||||||
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId)
|
_ = c.Error(&common.TokenInvalidError{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clientID, ok := token.Audience()
|
||||||
|
if !ok || len(clientID) != 1 {
|
||||||
|
_ = c.Error(&common.TokenInvalidError{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientID[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
|
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
|
||||||
@@ -228,7 +227,7 @@ func (uc *UserController) updateUserHandler(c *gin.Context) {
|
|||||||
// @Success 200 {object} dto.UserDto
|
// @Success 200 {object} dto.UserDto
|
||||||
// @Router /api/users/me [put]
|
// @Router /api/users/me [put]
|
||||||
func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
|
func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
|
||||||
if uc.appConfigService.DbConfig.AllowOwnAccountEdit.Value != "true" {
|
if !uc.appConfigService.DbConfig.AllowOwnAccountEdit.IsTrue() {
|
||||||
_ = c.Error(&common.AccountEditNotAllowedError{})
|
_ = c.Error(&common.AccountEditNotAllowedError{})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -391,8 +390,7 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionDurationInMinutesParsed, _ := strconv.Atoi(uc.appConfigService.DbConfig.SessionDuration.Value)
|
maxAge := int(uc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds())
|
||||||
maxAge := sessionDurationInMinutesParsed * 60
|
|
||||||
cookie.AddAccessTokenCookie(c, maxAge, token)
|
cookie.AddAccessTokenCookie(c, maxAge, token)
|
||||||
|
|
||||||
c.JSON(http.StatusOK, userDto)
|
c.JSON(http.StatusOK, userDto)
|
||||||
@@ -417,8 +415,7 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionDurationInMinutesParsed, _ := strconv.Atoi(uc.appConfigService.DbConfig.SessionDuration.Value)
|
maxAge := int(uc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds())
|
||||||
maxAge := sessionDurationInMinutesParsed * 60
|
|
||||||
cookie.AddAccessTokenCookie(c, maxAge, token)
|
cookie.AddAccessTokenCookie(c, maxAge, token)
|
||||||
|
|
||||||
c.JSON(http.StatusOK, userDto)
|
c.JSON(http.StatusOK, userDto)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-webauthn/webauthn/protocol"
|
"github.com/go-webauthn/webauthn/protocol"
|
||||||
@@ -107,8 +106,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionDurationInMinutesParsed, _ := strconv.Atoi(wc.appConfigService.DbConfig.SessionDuration.Value)
|
maxAge := int(wc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds())
|
||||||
maxAge := sessionDurationInMinutesParsed * 60
|
|
||||||
cookie.AddAccessTokenCookie(c, maxAge, token)
|
cookie.AddAccessTokenCookie(c, maxAge, token)
|
||||||
|
|
||||||
c.JSON(http.StatusOK, userDto)
|
c.JSON(http.StatusOK, userDto)
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||||
)
|
)
|
||||||
@@ -14,12 +18,21 @@ import (
|
|||||||
// @Tags Well Known
|
// @Tags Well Known
|
||||||
func NewWellKnownController(group *gin.RouterGroup, jwtService *service.JwtService) {
|
func NewWellKnownController(group *gin.RouterGroup, jwtService *service.JwtService) {
|
||||||
wkc := &WellKnownController{jwtService: jwtService}
|
wkc := &WellKnownController{jwtService: jwtService}
|
||||||
|
|
||||||
|
// Pre-compute the OIDC configuration document, which is static
|
||||||
|
var err error
|
||||||
|
wkc.oidcConfig, err = wkc.computeOIDCConfiguration()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to pre-compute OpenID Connect configuration document: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
group.GET("/.well-known/jwks.json", wkc.jwksHandler)
|
group.GET("/.well-known/jwks.json", wkc.jwksHandler)
|
||||||
group.GET("/.well-known/openid-configuration", wkc.openIDConfigurationHandler)
|
group.GET("/.well-known/openid-configuration", wkc.openIDConfigurationHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
type WellKnownController struct {
|
type WellKnownController struct {
|
||||||
jwtService *service.JwtService
|
jwtService *service.JwtService
|
||||||
|
oidcConfig []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// jwksHandler godoc
|
// jwksHandler godoc
|
||||||
@@ -46,8 +59,16 @@ func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
|
|||||||
// @Success 200 {object} object "OpenID Connect configuration"
|
// @Success 200 {object} object "OpenID Connect configuration"
|
||||||
// @Router /.well-known/openid-configuration [get]
|
// @Router /.well-known/openid-configuration [get]
|
||||||
func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) {
|
func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) {
|
||||||
|
c.Data(http.StatusOK, "application/json; charset=utf-8", wkc.oidcConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
|
||||||
appUrl := common.EnvConfig.AppURL
|
appUrl := common.EnvConfig.AppURL
|
||||||
config := map[string]interface{}{
|
alg, err := wkc.jwtService.GetKeyAlg()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get key algorithm: %w", err)
|
||||||
|
}
|
||||||
|
config := map[string]any{
|
||||||
"issuer": appUrl,
|
"issuer": appUrl,
|
||||||
"authorization_endpoint": appUrl + "/authorize",
|
"authorization_endpoint": appUrl + "/authorize",
|
||||||
"token_endpoint": appUrl + "/api/oidc/token",
|
"token_endpoint": appUrl + "/api/oidc/token",
|
||||||
@@ -59,7 +80,7 @@ func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) {
|
|||||||
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
|
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
|
||||||
"response_types_supported": []string{"code", "id_token"},
|
"response_types_supported": []string{"code", "id_token"},
|
||||||
"subject_types_supported": []string{"public"},
|
"subject_types_supported": []string{"public"},
|
||||||
"id_token_signing_alg_values_supported": []string{"RS256"},
|
"id_token_signing_alg_values_supported": []string{alg.String()},
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, config)
|
return json.Marshal(config)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ func RegisterLdapJobs(ldapService *service.LdapService, appConfigService *servic
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (j *LdapJobs) syncLdap() error {
|
func (j *LdapJobs) syncLdap() error {
|
||||||
if j.appConfigService.DbConfig.LdapEnabled.Value == "true" {
|
if j.appConfigService.DbConfig.LdapEnabled.IsTrue() {
|
||||||
return j.ldapService.SyncAll()
|
return j.ldapService.SyncAll()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware {
|
|||||||
|
|
||||||
func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
|
func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
|
||||||
userID, isAdmin, err := m.Verify(c, adminRequired)
|
userID, isAdmin, err := m.Verify(c, adminRequired)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Abort()
|
c.Abort()
|
||||||
@@ -33,27 +32,37 @@ func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (userID string, isAdmin bool, err error) {
|
func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject string, isAdmin bool, err error) {
|
||||||
// Extract the token from the cookie
|
// Extract the token from the cookie
|
||||||
token, err := c.Cookie(cookie.AccessTokenCookieName)
|
accessToken, err := c.Cookie(cookie.AccessTokenCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Try to extract the token from the Authorization header if it's not in the cookie
|
// Try to extract the token from the Authorization header if it's not in the cookie
|
||||||
authorizationHeaderSplit := strings.Split(c.GetHeader("Authorization"), " ")
|
var ok bool
|
||||||
if len(authorizationHeaderSplit) != 2 {
|
_, accessToken, ok = strings.Cut(c.GetHeader("Authorization"), " ")
|
||||||
|
if !ok || accessToken == "" {
|
||||||
return "", false, &common.NotSignedInError{}
|
return "", false, &common.NotSignedInError{}
|
||||||
}
|
}
|
||||||
token = authorizationHeaderSplit[1]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := m.jwtService.VerifyAccessToken(token)
|
token, err := m.jwtService.VerifyAccessToken(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", false, &common.NotSignedInError{}
|
return "", false, &common.NotSignedInError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
subject, ok := token.Subject()
|
||||||
|
if !ok {
|
||||||
|
_ = c.Error(&common.TokenInvalidError{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Check if the user is an admin
|
// Check if the user is an admin
|
||||||
if adminRequired && !claims.IsAdmin {
|
isAdmin, err = service.GetIsAdmin(token)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, &common.TokenInvalidError{}
|
||||||
|
}
|
||||||
|
if adminRequired && !isAdmin {
|
||||||
return "", false, &common.MissingPermissionError{}
|
return "", false, &common.MissingPermissionError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return claims.Subject, claims.IsAdmin, nil
|
return subject, isAdmin, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AppConfigVariable struct {
|
type AppConfigVariable struct {
|
||||||
@@ -13,11 +14,21 @@ type AppConfigVariable struct {
|
|||||||
DefaultValue string
|
DefaultValue string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsTrue returns true if the value is a truthy string, such as "true", "t", "yes", "1", etc.
|
||||||
func (a *AppConfigVariable) IsTrue() bool {
|
func (a *AppConfigVariable) IsTrue() bool {
|
||||||
ok, _ := strconv.ParseBool(a.Value)
|
ok, _ := strconv.ParseBool(a.Value)
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AsDurationMinutes returns the value as a time.Duration, interpreting the string as a whole number of minutes.
|
||||||
|
func (a *AppConfigVariable) AsDurationMinutes() time.Duration {
|
||||||
|
val, err := strconv.Atoi(a.Value)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return time.Duration(val) * time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
type AppConfig struct {
|
type AppConfig struct {
|
||||||
// General
|
// General
|
||||||
AppName AppConfigVariable
|
AppName AppConfigVariable
|
||||||
|
|||||||
60
backend/internal/model/app_config_test.go
Normal file
60
backend/internal/model/app_config_test.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAppConfigVariable_AsMinutesDuration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value string
|
||||||
|
expected time.Duration
|
||||||
|
expectedSeconds int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid positive integer",
|
||||||
|
value: "60",
|
||||||
|
expected: 60 * time.Minute,
|
||||||
|
expectedSeconds: 3600,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid zero integer",
|
||||||
|
value: "0",
|
||||||
|
expected: 0,
|
||||||
|
expectedSeconds: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative integer",
|
||||||
|
value: "-30",
|
||||||
|
expected: -30 * time.Minute,
|
||||||
|
expectedSeconds: -1800,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid non-integer",
|
||||||
|
value: "not-a-number",
|
||||||
|
expected: 0,
|
||||||
|
expectedSeconds: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
value: "",
|
||||||
|
expected: 0,
|
||||||
|
expectedSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
configVar := AppConfigVariable{
|
||||||
|
Value: tt.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := configVar.AsDurationMinutes()
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
assert.Equal(t, tt.expectedSeconds, int(result.Seconds()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -60,7 +60,7 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email
|
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email
|
||||||
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.Value == "true" && count <= 1 {
|
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.IsTrue() && count <= 1 {
|
||||||
go func() {
|
go func() {
|
||||||
var user model.User
|
var user model.User
|
||||||
s.db.Where("id = ?", userID).First(&user)
|
s.db.Where("id = ?", userID).First(&user)
|
||||||
|
|||||||
@@ -5,13 +5,6 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/emersion/go-sasl"
|
|
||||||
"github.com/emersion/go-smtp"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
htemplate "html/template"
|
htemplate "html/template"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"mime/quotedprintable"
|
"mime/quotedprintable"
|
||||||
@@ -20,6 +13,14 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
ttemplate "text/template"
|
ttemplate "text/template"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/emersion/go-sasl"
|
||||||
|
"github.com/emersion/go-smtp"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type EmailService struct {
|
type EmailService struct {
|
||||||
|
|||||||
@@ -11,13 +11,11 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
@@ -34,6 +32,13 @@ const (
|
|||||||
|
|
||||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||||
KeyUsageSigning = "sig"
|
KeyUsageSigning = "sig"
|
||||||
|
|
||||||
|
// IsAdminClaim is a boolean claim used in access tokens for admin users
|
||||||
|
// This may be omitted on non-admin tokens
|
||||||
|
IsAdminClaim = "isAdmin"
|
||||||
|
|
||||||
|
// Acceptable clock skew for verifying tokens
|
||||||
|
clockSkew = time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
type JwtService struct {
|
type JwtService struct {
|
||||||
@@ -61,11 +66,6 @@ func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) e
|
|||||||
return s.loadOrGenerateKey(keysPath)
|
return s.loadOrGenerateKey(keysPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccessTokenJWTClaims struct {
|
|
||||||
jwt.RegisteredClaims
|
|
||||||
IsAdmin bool `json:"isAdmin,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
|
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
|
||||||
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
|
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
|
||||||
var key jwk.Key
|
var key jwk.Key
|
||||||
@@ -170,133 +170,164 @@ func (s *JwtService) SetKey(privateKey jwk.Key) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||||
sessionDurationInMinutes, _ := strconv.Atoi(s.appConfigService.DbConfig.SessionDuration.Value)
|
now := time.Now()
|
||||||
claim := AccessTokenJWTClaims{
|
token, err := jwt.NewBuilder().
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
Subject(user.ID).
|
||||||
Subject: user.ID,
|
Expiration(now.Add(s.appConfigService.DbConfig.SessionDuration.AsDurationMinutes())).
|
||||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)),
|
IssuedAt(now).
|
||||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
Issuer(common.EnvConfig.AppURL).
|
||||||
Audience: jwt.ClaimStrings{common.EnvConfig.AppURL},
|
Build()
|
||||||
},
|
|
||||||
IsAdmin: user.IsAdmin,
|
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
|
||||||
token.Header["kid"] = s.keyId
|
|
||||||
|
|
||||||
var privateKeyRaw any
|
|
||||||
err := jwk.Export(s.privateKey, &privateKeyRaw)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to export private key object: %w", err)
|
return "", fmt.Errorf("failed to build token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
signed, err := token.SignedString(privateKeyRaw)
|
err = SetAudienceString(token, common.EnvConfig.AppURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = SetIsAdmin(token, user.IsAdmin)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
alg, _ := s.privateKey.Algorithm()
|
||||||
|
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return signed, nil
|
return string(signed), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
|
func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (any, error) {
|
alg, _ := s.privateKey.Algorithm()
|
||||||
return s.getPublicKeyRaw()
|
token, err := jwt.ParseString(
|
||||||
})
|
tokenString,
|
||||||
if err != nil || !token.Valid {
|
jwt.WithValidate(true),
|
||||||
return nil, errors.New("couldn't handle this token")
|
jwt.WithKey(alg, s.privateKey),
|
||||||
|
jwt.WithAcceptableSkew(clockSkew),
|
||||||
|
jwt.WithAudience(common.EnvConfig.AppURL),
|
||||||
|
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, isValid := token.Claims.(*AccessTokenJWTClaims)
|
return token, nil
|
||||||
if !isValid {
|
|
||||||
return nil, errors.New("can't parse claims")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !slices.Contains(claims.Audience, common.EnvConfig.AppURL) {
|
|
||||||
return nil, errors.New("audience doesn't match")
|
|
||||||
}
|
|
||||||
return claims, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID string, nonce string) (string, error) {
|
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
|
||||||
// Initialize with capacity for userClaims, + 4 fixed claims, + 2 claims which may be set in some cases, to avoid re-allocations
|
now := time.Now()
|
||||||
claims := make(jwt.MapClaims, len(userClaims)+6)
|
token, err := jwt.NewBuilder().
|
||||||
claims["aud"] = clientID
|
Expiration(now.Add(1 * time.Hour)).
|
||||||
claims["exp"] = jwt.NewNumericDate(time.Now().Add(1 * time.Hour))
|
IssuedAt(now).
|
||||||
claims["iat"] = jwt.NewNumericDate(time.Now())
|
Issuer(common.EnvConfig.AppURL).
|
||||||
claims["iss"] = common.EnvConfig.AppURL
|
Build()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to build token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = SetAudienceString(token, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
for k, v := range userClaims {
|
for k, v := range userClaims {
|
||||||
claims[k] = v
|
err = token.Set(k, v)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to set claim '%s': %w", k, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if nonce != "" {
|
if nonce != "" {
|
||||||
claims["nonce"] = nonce
|
err = token.Set("nonce", nonce)
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
||||||
token.Header["kid"] = s.keyId
|
|
||||||
|
|
||||||
var privateKeyRaw any
|
|
||||||
err := jwk.Export(s.privateKey, &privateKeyRaw)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to export private key object: %w", err)
|
return "", fmt.Errorf("failed to set claim 'nonce': %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return token.SignedString(privateKeyRaw)
|
alg, _ := s.privateKey.Algorithm()
|
||||||
|
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(signed), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) {
|
func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool) (jwt.Token, error) {
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
|
alg, _ := s.privateKey.Algorithm()
|
||||||
return s.getPublicKeyRaw()
|
|
||||||
}, jwt.WithIssuer(common.EnvConfig.AppURL))
|
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
|
opts := make([]jwt.ParseOption, 0)
|
||||||
return nil, errors.New("couldn't handle this token")
|
|
||||||
|
// These options are always present
|
||||||
|
opts = append(opts,
|
||||||
|
jwt.WithValidate(true),
|
||||||
|
jwt.WithKey(alg, s.privateKey),
|
||||||
|
jwt.WithAcceptableSkew(clockSkew),
|
||||||
|
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||||
|
)
|
||||||
|
|
||||||
|
// By default, jwt.Parse includes 3 default validators for "nbf", "iat", and "exp"
|
||||||
|
// In case we want to accept expired tokens (during logout), we need to set the validators explicitly without validating "exp"
|
||||||
|
if acceptExpiredTokens {
|
||||||
|
// This is equivalent to the default validators except it doesn't validate "exp"
|
||||||
|
opts = append(opts,
|
||||||
|
jwt.WithResetValidators(true),
|
||||||
|
jwt.WithValidator(jwt.IsIssuedAtValid()),
|
||||||
|
jwt.WithValidator(jwt.IsNbfValid()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
|
token, err := jwt.ParseString(tokenString, opts...)
|
||||||
if !isValid {
|
if err != nil {
|
||||||
return nil, errors.New("can't parse claims")
|
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return claims, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
||||||
claim := jwt.RegisteredClaims{
|
now := time.Now()
|
||||||
Subject: user.ID,
|
token, err := jwt.NewBuilder().
|
||||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
Subject(user.ID).
|
||||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
Expiration(now.Add(1 * time.Hour)).
|
||||||
Audience: jwt.ClaimStrings{clientID},
|
IssuedAt(now).
|
||||||
Issuer: common.EnvConfig.AppURL,
|
Issuer(common.EnvConfig.AppURL).
|
||||||
}
|
Build()
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
|
||||||
token.Header["kid"] = s.keyId
|
|
||||||
|
|
||||||
var privateKeyRaw any
|
|
||||||
err := jwk.Export(s.privateKey, &privateKeyRaw)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to export private key object: %w", err)
|
return "", fmt.Errorf("failed to build token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return token.SignedString(privateKeyRaw)
|
err = SetAudienceString(token, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
alg, _ := s.privateKey.Algorithm()
|
||||||
|
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(signed), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
|
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, error) {
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
|
alg, _ := s.privateKey.Algorithm()
|
||||||
return s.getPublicKeyRaw()
|
token, err := jwt.ParseString(
|
||||||
})
|
tokenString,
|
||||||
if err != nil || !token.Valid {
|
jwt.WithValidate(true),
|
||||||
return nil, errors.New("couldn't handle this token")
|
jwt.WithKey(alg, s.privateKey),
|
||||||
|
jwt.WithAcceptableSkew(clockSkew),
|
||||||
|
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
|
return token, nil
|
||||||
if !isValid {
|
|
||||||
return nil, errors.New("can't parse claims")
|
|
||||||
}
|
|
||||||
|
|
||||||
return claims, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPublicJWK returns the JSON Web Key (JWK) for the public key.
|
// GetPublicJWK returns the JSON Web Key (JWK) for the public key.
|
||||||
@@ -325,17 +356,18 @@ func (s *JwtService) GetPublicJWKSAsJSON() ([]byte, error) {
|
|||||||
return s.jwksEncoded, nil
|
return s.jwksEncoded, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JwtService) getPublicKeyRaw() (any, error) {
|
// GetKeyAlg returns the algorithm of the key
|
||||||
pubKey, err := s.privateKey.PublicKey()
|
func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) {
|
||||||
if err != nil {
|
if len(s.jwksEncoded) == 0 {
|
||||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
return nil, errors.New("key is not initialized")
|
||||||
}
|
}
|
||||||
var pubKeyRaw any
|
|
||||||
err = jwk.Export(pubKey, &pubKeyRaw)
|
alg, ok := s.privateKey.Algorithm()
|
||||||
if err != nil {
|
if !ok || alg == nil {
|
||||||
return nil, fmt.Errorf("failed to export raw public key: %w", err)
|
return nil, errors.New("failed to retrieve algorithm for key")
|
||||||
}
|
}
|
||||||
return pubKeyRaw, nil
|
|
||||||
|
return alg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
|
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
|
||||||
@@ -438,3 +470,28 @@ func generateRandomKeyID() (string, error) {
|
|||||||
}
|
}
|
||||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetIsAdmin returns the value of the "isAdmin" claim in the token
|
||||||
|
func GetIsAdmin(token jwt.Token) (bool, error) {
|
||||||
|
if !token.Has(IsAdminClaim) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
var isAdmin bool
|
||||||
|
err := token.Get(IsAdminClaim, &isAdmin)
|
||||||
|
return isAdmin, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIsAdmin sets the "isAdmin" claim in the token
|
||||||
|
func SetIsAdmin(token jwt.Token, isAdmin bool) error {
|
||||||
|
// Only set if true
|
||||||
|
if !isAdmin {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return token.Set(IsAdminClaim, isAdmin)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudienceString sets the "aud" claim with a value that is a string, and not an array
|
||||||
|
// This is permitted by RFC 7519, and it's done here for backwards-compatibility
|
||||||
|
func SetAudienceString(token jwt.Token, audience string) error {
|
||||||
|
return token.Set(jwt.AudienceKey, audience)
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,10 +2,13 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
|
"crypto/ed25519"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -20,16 +23,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestJwtService_Init(t *testing.T) {
|
func TestJwtService_Init(t *testing.T) {
|
||||||
|
mockConfig := &AppConfigService{
|
||||||
|
DbConfig: &model.AppConfig{
|
||||||
|
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("should generate new key when none exists", func(t *testing.T) {
|
t.Run("should generate new key when none exists", func(t *testing.T) {
|
||||||
// Create a temporary directory for the test
|
// Create a temporary directory for the test
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
// Create a mock AppConfigService
|
|
||||||
appConfigService := &AppConfigService{}
|
|
||||||
|
|
||||||
// Initialize the JWT service
|
// Initialize the JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(appConfigService, tempDir)
|
err := service.init(mockConfig, tempDir)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Verify the private key was set
|
// Verify the private key was set
|
||||||
@@ -62,7 +68,7 @@ func TestJwtService_Init(t *testing.T) {
|
|||||||
|
|
||||||
// First create a service to generate a key
|
// First create a service to generate a key
|
||||||
firstService := &JwtService{}
|
firstService := &JwtService{}
|
||||||
err := firstService.init(&AppConfigService{}, tempDir)
|
err := firstService.init(mockConfig, tempDir)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Get the key ID of the first service
|
// Get the key ID of the first service
|
||||||
@@ -71,7 +77,7 @@ func TestJwtService_Init(t *testing.T) {
|
|||||||
|
|
||||||
// Now create a new service that should load the existing key
|
// Now create a new service that should load the existing key
|
||||||
secondService := &JwtService{}
|
secondService := &JwtService{}
|
||||||
err = secondService.init(&AppConfigService{}, tempDir)
|
err = secondService.init(mockConfig, tempDir)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the loaded key has the same ID as the original
|
// Verify the loaded key has the same ID as the original
|
||||||
@@ -80,33 +86,72 @@ func TestJwtService_Init(t *testing.T) {
|
|||||||
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("should load existing JWK for EC keys", func(t *testing.T) {
|
t.Run("should load existing JWK for ECDSA keys", func(t *testing.T) {
|
||||||
// Create a temporary directory for the test
|
// Create a temporary directory for the test
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
// Create a new JWK and save it to disk
|
// Create a new JWK and save it to disk
|
||||||
origKeyID := createECKeyJWK(t, tempDir)
|
origKeyID := createECDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
// Now create a new service that should load the existing key
|
// Now create a new service that should load the existing key
|
||||||
svc := &JwtService{}
|
svc := &JwtService{}
|
||||||
err := svc.init(&AppConfigService{}, tempDir)
|
err := svc.init(mockConfig, tempDir)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Ensure loaded key has the right algorithm
|
||||||
|
alg, ok := svc.privateKey.Algorithm()
|
||||||
|
_ = assert.True(t, ok) &&
|
||||||
|
assert.Equal(t, jwa.ES256().String(), alg.String(), "Loaded key has the incorrect algorithm")
|
||||||
|
|
||||||
// Verify the loaded key has the same ID as the original
|
// Verify the loaded key has the same ID as the original
|
||||||
loadedKeyID, ok := svc.privateKey.KeyID()
|
loadedKeyID, ok := svc.privateKey.KeyID()
|
||||||
require.True(t, ok)
|
_ = assert.True(t, ok) &&
|
||||||
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should load existing JWK for EdDSA keys", func(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create a new JWK and save it to disk
|
||||||
|
origKeyID := createEdDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
|
// Now create a new service that should load the existing key
|
||||||
|
svc := &JwtService{}
|
||||||
|
err := svc.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Ensure loaded key has the right algorithm and curve
|
||||||
|
alg, ok := svc.privateKey.Algorithm()
|
||||||
|
_ = assert.True(t, ok) &&
|
||||||
|
assert.Equal(t, jwa.EdDSA().String(), alg.String(), "Loaded key has the incorrect algorithm")
|
||||||
|
|
||||||
|
var curve jwa.EllipticCurveAlgorithm
|
||||||
|
err = svc.privateKey.Get("crv", &curve)
|
||||||
|
_ = assert.NoError(t, err, "Failed to get 'crv' claim") &&
|
||||||
|
assert.Equal(t, jwa.Ed25519().String(), curve.String(), "Curve does not match expected value")
|
||||||
|
|
||||||
|
// Verify the loaded key has the same ID as the original
|
||||||
|
loadedKeyID, ok := svc.privateKey.KeyID()
|
||||||
|
_ = assert.True(t, ok) &&
|
||||||
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJwtService_GetPublicJWK(t *testing.T) {
|
func TestJwtService_GetPublicJWK(t *testing.T) {
|
||||||
|
mockConfig := &AppConfigService{
|
||||||
|
DbConfig: &model.AppConfig{
|
||||||
|
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("returns public key when private key is initialized", func(t *testing.T) {
|
t.Run("returns public key when private key is initialized", func(t *testing.T) {
|
||||||
// Create a temporary directory for the test
|
// Create a temporary directory for the test
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
// Create a JWT service with initialized key
|
// Create a JWT service with initialized key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(&AppConfigService{}, tempDir)
|
err := service.init(mockConfig, tempDir)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Get the JWK (public key)
|
// Get the JWK (public key)
|
||||||
@@ -136,11 +181,11 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
|
|||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
// Create an ECDSA key and save it as JWK
|
// Create an ECDSA key and save it as JWK
|
||||||
originalKeyID := createECKeyJWK(t, tempDir)
|
originalKeyID := createECDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
// Create a JWT service that loads the ECDSA key
|
// Create a JWT service that loads the ECDSA key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(&AppConfigService{}, tempDir)
|
err := service.init(mockConfig, tempDir)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Get the JWK (public key)
|
// Get the JWK (public key)
|
||||||
@@ -169,6 +214,44 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
|
|||||||
assert.Equal(t, "ES256", alg.String(), "Algorithm should be ES256")
|
assert.Equal(t, "ES256", alg.String(), "Algorithm should be ES256")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("returns public key when EdDSA private key is initialized", func(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create an EdDSA key and save it as JWK
|
||||||
|
originalKeyID := createEdDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
|
// Create a JWT service that loads the EdDSA key
|
||||||
|
service := &JwtService{}
|
||||||
|
err := service.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
|
// Get the JWK (public key)
|
||||||
|
publicKey, err := service.GetPublicJWK()
|
||||||
|
require.NoError(t, err, "GetPublicJWK should not return an error when private key is initialized")
|
||||||
|
|
||||||
|
// Verify the returned key is valid
|
||||||
|
require.NotNil(t, publicKey, "Public key should not be nil")
|
||||||
|
|
||||||
|
// Validate it's actually a public key
|
||||||
|
isPrivate, err := jwk.IsPrivateKey(publicKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, isPrivate, "Returned key should be a public key")
|
||||||
|
|
||||||
|
// Check that key has required properties
|
||||||
|
keyID, ok := publicKey.KeyID()
|
||||||
|
require.True(t, ok, "Public key should have a key ID")
|
||||||
|
assert.Equal(t, originalKeyID, keyID, "Key ID should match the original key ID")
|
||||||
|
|
||||||
|
// Check that the key type is OKP
|
||||||
|
assert.Equal(t, "OKP", publicKey.KeyType().String(), "Key type should be OKP")
|
||||||
|
|
||||||
|
// Check that the algorithm is EdDSA
|
||||||
|
alg, ok := publicKey.Algorithm()
|
||||||
|
require.True(t, ok, "Public key should have an algorithm")
|
||||||
|
assert.Equal(t, "EdDSA", alg.String(), "Algorithm should be EdDSA")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("returns error when private key is not initialized", func(t *testing.T) {
|
t.Run("returns error when private key is not initialized", func(t *testing.T) {
|
||||||
// Create a service with nil private key
|
// Create a service with nil private key
|
||||||
service := &JwtService{
|
service := &JwtService{
|
||||||
@@ -228,15 +311,22 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "Failed to verify generated token")
|
require.NoError(t, err, "Failed to verify generated token")
|
||||||
|
|
||||||
// Check the claims
|
// Check the claims
|
||||||
assert.Equal(t, user.ID, claims.Subject, "Token subject should match user ID")
|
subject, ok := claims.Subject()
|
||||||
assert.False(t, claims.IsAdmin, "IsAdmin should be false")
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
assert.Contains(t, claims.Audience, "https://test.example.com", "Audience should contain the app URL")
|
assert.Equal(t, user.ID, subject, "Token subject should match user ID")
|
||||||
|
isAdmin, err := GetIsAdmin(claims)
|
||||||
|
_ = assert.NoError(t, err, "Failed to get isAdmin claim") &&
|
||||||
|
assert.False(t, isAdmin, "isAdmin should be false")
|
||||||
|
audience, ok := claims.Audience()
|
||||||
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||||
|
assert.EqualValues(t, []string{"https://test.example.com"}, audience, "Audience should contain the app URL")
|
||||||
|
|
||||||
// Check token expiration time is approximately 60 minutes from now
|
// Check token expiration time is approximately 1 hour from now
|
||||||
expectedExp := time.Now().Add(60 * time.Minute)
|
expectedExp := time.Now().Add(1 * time.Hour)
|
||||||
tokenExp := claims.ExpiresAt.Time
|
expiration, ok := claims.Expiration()
|
||||||
timeDiff := expectedExp.Sub(tokenExp).Minutes()
|
assert.True(t, ok, "Expiration not found in token")
|
||||||
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 60 minutes")
|
timeDiff := expectedExp.Sub(expiration).Minutes()
|
||||||
|
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 1 hour")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("generates token for admin user", func(t *testing.T) {
|
t.Run("generates token for admin user", func(t *testing.T) {
|
||||||
@@ -263,8 +353,12 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "Failed to verify generated token")
|
require.NoError(t, err, "Failed to verify generated token")
|
||||||
|
|
||||||
// Check the IsAdmin claim is true
|
// Check the IsAdmin claim is true
|
||||||
assert.True(t, claims.IsAdmin, "IsAdmin should be true for admin users")
|
isAdmin, err := GetIsAdmin(claims)
|
||||||
assert.Equal(t, adminUser.ID, claims.Subject, "Token subject should match admin ID")
|
_ = assert.NoError(t, err, "Failed to get isAdmin claim") &&
|
||||||
|
assert.True(t, isAdmin, "isAdmin should be true")
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, adminUser.ID, subject, "Token subject should match user ID")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("uses session duration from config", func(t *testing.T) {
|
t.Run("uses session duration from config", func(t *testing.T) {
|
||||||
@@ -296,10 +390,173 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
// Check token expiration time is approximately 30 minutes from now
|
// Check token expiration time is approximately 30 minutes from now
|
||||||
expectedExp := time.Now().Add(30 * time.Minute)
|
expectedExp := time.Now().Add(30 * time.Minute)
|
||||||
tokenExp := claims.ExpiresAt.Time
|
expiration, ok := claims.Expiration()
|
||||||
timeDiff := expectedExp.Sub(tokenExp).Minutes()
|
assert.True(t, ok, "Expiration not found in token")
|
||||||
|
timeDiff := expectedExp.Sub(expiration).Minutes()
|
||||||
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 30 minutes")
|
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 30 minutes")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("works with Ed25519 keys", func(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create an Ed25519 key and save it as JWK
|
||||||
|
origKeyID := createEdDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
|
// Create a JWT service that loads the key
|
||||||
|
service := &JwtService{}
|
||||||
|
err := service.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
|
// Verify it loaded the right key
|
||||||
|
loadedKeyID, ok := service.privateKey.KeyID()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
|
|
||||||
|
// Create a test user
|
||||||
|
user := model.User{
|
||||||
|
Base: model.Base{
|
||||||
|
ID: "eddsauser123",
|
||||||
|
},
|
||||||
|
Email: "eddsauser@example.com",
|
||||||
|
IsAdmin: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a token
|
||||||
|
tokenString, err := service.GenerateAccessToken(user)
|
||||||
|
require.NoError(t, err, "Failed to generate access token with Ed25519 key")
|
||||||
|
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := service.VerifyAccessToken(tokenString)
|
||||||
|
require.NoError(t, err, "Failed to verify generated token with Ed25519 key")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, user.ID, subject, "Token subject should match user ID")
|
||||||
|
isAdmin, err := GetIsAdmin(claims)
|
||||||
|
_ = assert.NoError(t, err, "Failed to get isAdmin claim") &&
|
||||||
|
assert.True(t, isAdmin, "isAdmin should be true")
|
||||||
|
|
||||||
|
// Verify the key type is OKP
|
||||||
|
publicKey, err := service.GetPublicJWK()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "OKP", publicKey.KeyType().String(), "Key type should be OKP")
|
||||||
|
|
||||||
|
// Verify the algorithm is EdDSA
|
||||||
|
alg, ok := publicKey.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "EdDSA", alg.String(), "Algorithm should be EdDSA")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("works with P-256 keys", func(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create an ECDSA key and save it as JWK
|
||||||
|
origKeyID := createECDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
|
// Create a JWT service that loads the key
|
||||||
|
service := &JwtService{}
|
||||||
|
err := service.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
|
// Verify it loaded the right key
|
||||||
|
loadedKeyID, ok := service.privateKey.KeyID()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
|
|
||||||
|
// Create a test user
|
||||||
|
user := model.User{
|
||||||
|
Base: model.Base{
|
||||||
|
ID: "ecdsauser123",
|
||||||
|
},
|
||||||
|
Email: "ecdsauser@example.com",
|
||||||
|
IsAdmin: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a token
|
||||||
|
tokenString, err := service.GenerateAccessToken(user)
|
||||||
|
require.NoError(t, err, "Failed to generate access token with ECDSA key")
|
||||||
|
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := service.VerifyAccessToken(tokenString)
|
||||||
|
require.NoError(t, err, "Failed to verify generated token with ECDSA key")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, user.ID, subject, "Token subject should match user ID")
|
||||||
|
isAdmin, err := GetIsAdmin(claims)
|
||||||
|
_ = assert.NoError(t, err, "Failed to get isAdmin claim") &&
|
||||||
|
assert.True(t, isAdmin, "isAdmin should be true")
|
||||||
|
|
||||||
|
// Verify the key type is EC
|
||||||
|
publicKey, err := service.GetPublicJWK()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, jwa.EC().String(), publicKey.KeyType().String(), "Key type should be EC")
|
||||||
|
|
||||||
|
// Verify the algorithm is ES256
|
||||||
|
alg, ok := publicKey.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, jwa.ES256().String(), alg.String(), "Algorithm should be ES256")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("works with RSA-4096 keys", func(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create an RSA-4096 key and save it as JWK
|
||||||
|
origKeyID := createRSA4096KeyJWK(t, tempDir)
|
||||||
|
|
||||||
|
// Create a JWT service that loads the key
|
||||||
|
service := &JwtService{}
|
||||||
|
err := service.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
|
// Verify it loaded the right key
|
||||||
|
loadedKeyID, ok := service.privateKey.KeyID()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
|
|
||||||
|
// Create a test user
|
||||||
|
user := model.User{
|
||||||
|
Base: model.Base{
|
||||||
|
ID: "rsauser123",
|
||||||
|
},
|
||||||
|
Email: "rsauser@example.com",
|
||||||
|
IsAdmin: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a token
|
||||||
|
tokenString, err := service.GenerateAccessToken(user)
|
||||||
|
require.NoError(t, err, "Failed to generate access token with RSA key")
|
||||||
|
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := service.VerifyAccessToken(tokenString)
|
||||||
|
require.NoError(t, err, "Failed to verify generated token with RSA key")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, user.ID, subject, "Token subject should match user ID")
|
||||||
|
isAdmin, err := GetIsAdmin(claims)
|
||||||
|
_ = assert.NoError(t, err, "Failed to get isAdmin claim") &&
|
||||||
|
assert.True(t, isAdmin, "isAdmin should be true")
|
||||||
|
|
||||||
|
// Verify the key type is RSA
|
||||||
|
publicKey, err := service.GetPublicJWK()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, jwa.RSA().String(), publicKey.KeyType().String(), "Key type should be RSA")
|
||||||
|
|
||||||
|
// Verify the algorithm is RS256
|
||||||
|
alg, ok := publicKey.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, jwa.RS256().String(), alg.String(), "Algorithm should be RS256")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGenerateVerifyIdToken(t *testing.T) {
|
func TestGenerateVerifyIdToken(t *testing.T) {
|
||||||
@@ -340,21 +597,83 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||||
|
|
||||||
// Verify the token
|
// Verify the token
|
||||||
claims, err := service.VerifyIdToken(tokenString)
|
claims, err := service.VerifyIdToken(tokenString, false)
|
||||||
require.NoError(t, err, "Failed to verify generated ID token")
|
require.NoError(t, err, "Failed to verify generated ID token")
|
||||||
|
|
||||||
// Check the claims
|
// Check the claims
|
||||||
assert.Equal(t, "user123", claims.Subject, "Token subject should match user ID")
|
subject, ok := claims.Subject()
|
||||||
assert.Contains(t, claims.Audience, clientID, "Audience should contain the client ID")
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
assert.Equal(t, common.EnvConfig.AppURL, claims.Issuer, "Issuer should match app URL")
|
assert.Equal(t, "user123", subject, "Token subject should match user ID")
|
||||||
|
audience, ok := claims.Audience()
|
||||||
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||||
|
assert.EqualValues(t, []string{clientID}, audience, "Audience should contain the client ID")
|
||||||
|
issuer, ok := claims.Issuer()
|
||||||
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
|
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
|
|
||||||
// Check token expiration time is approximately 1 hour from now
|
// Check token expiration time is approximately 1 hour from now
|
||||||
expectedExp := time.Now().Add(1 * time.Hour)
|
expectedExp := time.Now().Add(1 * time.Hour)
|
||||||
tokenExp := claims.ExpiresAt.Time
|
expiration, ok := claims.Expiration()
|
||||||
timeDiff := expectedExp.Sub(tokenExp).Minutes()
|
assert.True(t, ok, "Expiration not found in token")
|
||||||
|
timeDiff := expectedExp.Sub(expiration).Minutes()
|
||||||
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 1 hour")
|
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 1 hour")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("can accept expired tokens if told so", func(t *testing.T) {
|
||||||
|
// Create a JWT service
|
||||||
|
service := &JwtService{}
|
||||||
|
err := service.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
|
// Create test claims
|
||||||
|
userClaims := map[string]interface{}{
|
||||||
|
"sub": "user123",
|
||||||
|
"name": "Test User",
|
||||||
|
"email": "user@example.com",
|
||||||
|
}
|
||||||
|
const clientID = "test-client-123"
|
||||||
|
|
||||||
|
// Create a token that's already expired
|
||||||
|
token, err := jwt.NewBuilder().
|
||||||
|
Subject(userClaims["sub"].(string)).
|
||||||
|
Issuer(common.EnvConfig.AppURL).
|
||||||
|
Audience([]string{clientID}).
|
||||||
|
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
||||||
|
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
||||||
|
Build()
|
||||||
|
require.NoError(t, err, "Failed to build token")
|
||||||
|
|
||||||
|
// Add custom claims
|
||||||
|
for k, v := range userClaims {
|
||||||
|
if k != "sub" { // Already set above
|
||||||
|
err = token.Set(k, v)
|
||||||
|
require.NoError(t, err, "Failed to set claim")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign the token
|
||||||
|
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
|
||||||
|
require.NoError(t, err, "Failed to sign token")
|
||||||
|
tokenString := string(signed)
|
||||||
|
|
||||||
|
// Verify the token without allowExpired flag - should fail
|
||||||
|
_, err = service.VerifyIdToken(tokenString, false)
|
||||||
|
require.Error(t, err, "Verification should fail with expired token when not allowing expired tokens")
|
||||||
|
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
|
||||||
|
|
||||||
|
// Verify the token with allowExpired flag - should succeed
|
||||||
|
claims, err := service.VerifyIdToken(tokenString, true)
|
||||||
|
require.NoError(t, err, "Verification should succeed with expired token when allowing expired tokens")
|
||||||
|
|
||||||
|
// Validate the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, userClaims["sub"], subject, "Token subject should match user ID")
|
||||||
|
issuer, ok := claims.Issuer()
|
||||||
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
|
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("generates and verifies ID token with nonce", func(t *testing.T) {
|
t.Run("generates and verifies ID token with nonce", func(t *testing.T) {
|
||||||
// Create a JWT service
|
// Create a JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
@@ -403,9 +722,168 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
common.EnvConfig.AppURL = "https://wrong-issuer.com"
|
common.EnvConfig.AppURL = "https://wrong-issuer.com"
|
||||||
|
|
||||||
// Verify should fail due to issuer mismatch
|
// Verify should fail due to issuer mismatch
|
||||||
_, err = service.VerifyIdToken(tokenString)
|
_, err = service.VerifyIdToken(tokenString, false)
|
||||||
require.Error(t, err, "Verification should fail with incorrect issuer")
|
require.Error(t, err, "Verification should fail with incorrect issuer")
|
||||||
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
assert.Contains(t, err.Error(), `"iss" not satisfied`, "Error message should indicate token verification failure")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("works with Ed25519 keys", func(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create an Ed25519 key and save it as JWK
|
||||||
|
origKeyID := createEdDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
|
// Create a JWT service that loads the key
|
||||||
|
service := &JwtService{}
|
||||||
|
err := service.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
|
// Verify it loaded the right key
|
||||||
|
loadedKeyID, ok := service.privateKey.KeyID()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
|
|
||||||
|
// Create test claims
|
||||||
|
userClaims := map[string]interface{}{
|
||||||
|
"sub": "eddsauser456",
|
||||||
|
"name": "EdDSA User",
|
||||||
|
"email": "eddsauser@example.com",
|
||||||
|
}
|
||||||
|
const clientID = "eddsa-client-123"
|
||||||
|
|
||||||
|
// Generate a token
|
||||||
|
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
|
||||||
|
require.NoError(t, err, "Failed to generate ID token with key")
|
||||||
|
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := service.VerifyIdToken(tokenString, false)
|
||||||
|
require.NoError(t, err, "Failed to verify generated ID token with key")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, "eddsauser456", subject, "Token subject should match user ID")
|
||||||
|
issuer, ok := claims.Issuer()
|
||||||
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
|
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
|
|
||||||
|
// Verify the key type is OKP
|
||||||
|
publicKey, err := service.GetPublicJWK()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, jwa.OKP().String(), publicKey.KeyType().String(), "Key type should be OKP")
|
||||||
|
|
||||||
|
// Verify the algorithm is EdDSA
|
||||||
|
alg, ok := publicKey.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, jwa.EdDSA().String(), alg.String(), "Algorithm should be EdDSA")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("works with P-256 keys", func(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create an ECDSA key and save it as JWK
|
||||||
|
origKeyID := createECDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
|
// Create a JWT service that loads the key
|
||||||
|
service := &JwtService{}
|
||||||
|
err := service.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
|
// Verify it loaded the right key
|
||||||
|
loadedKeyID, ok := service.privateKey.KeyID()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
|
|
||||||
|
// Create test claims
|
||||||
|
userClaims := map[string]interface{}{
|
||||||
|
"sub": "ecdsauser456",
|
||||||
|
"name": "ECDSA User",
|
||||||
|
"email": "ecdsauser@example.com",
|
||||||
|
}
|
||||||
|
const clientID = "ecdsa-client-123"
|
||||||
|
|
||||||
|
// Generate a token
|
||||||
|
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
|
||||||
|
require.NoError(t, err, "Failed to generate ID token with key")
|
||||||
|
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := service.VerifyIdToken(tokenString, false)
|
||||||
|
require.NoError(t, err, "Failed to verify generated ID token with key")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, "ecdsauser456", subject, "Token subject should match user ID")
|
||||||
|
issuer, ok := claims.Issuer()
|
||||||
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
|
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
|
|
||||||
|
// Verify the key type is EC
|
||||||
|
publicKey, err := service.GetPublicJWK()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, jwa.EC().String(), publicKey.KeyType().String(), "Key type should be EC")
|
||||||
|
|
||||||
|
// Verify the algorithm is ES256
|
||||||
|
alg, ok := publicKey.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, jwa.ES256().String(), alg.String(), "Algorithm should be ES256")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("works with RSA-4096 keys", func(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create an RSA-4096 key and save it as JWK
|
||||||
|
origKeyID := createRSA4096KeyJWK(t, tempDir)
|
||||||
|
|
||||||
|
// Create a JWT service that loads the key
|
||||||
|
service := &JwtService{}
|
||||||
|
err := service.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
|
// Verify it loaded the right key
|
||||||
|
loadedKeyID, ok := service.privateKey.KeyID()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
|
|
||||||
|
// Create test claims
|
||||||
|
userClaims := map[string]interface{}{
|
||||||
|
"sub": "rsauser456",
|
||||||
|
"name": "RSA User",
|
||||||
|
"email": "rsauser@example.com",
|
||||||
|
}
|
||||||
|
const clientID = "rsa-client-123"
|
||||||
|
|
||||||
|
// Generate a token
|
||||||
|
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
|
||||||
|
require.NoError(t, err, "Failed to generate ID token with key")
|
||||||
|
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := service.VerifyIdToken(tokenString, false)
|
||||||
|
require.NoError(t, err, "Failed to verify generated ID token with key")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, "rsauser456", subject, "Token subject should match user ID")
|
||||||
|
issuer, ok := claims.Issuer()
|
||||||
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
|
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
|
|
||||||
|
// Verify the key type is RSA
|
||||||
|
publicKey, err := service.GetPublicJWK()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, jwa.RSA().String(), publicKey.KeyType().String(), "Key type should be RSA")
|
||||||
|
|
||||||
|
// Verify the algorithm is RS256
|
||||||
|
alg, ok := publicKey.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, jwa.RS256().String(), alg.String(), "Algorithm should be RS256")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -452,14 +930,21 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "Failed to verify generated OAuth access token")
|
require.NoError(t, err, "Failed to verify generated OAuth access token")
|
||||||
|
|
||||||
// Check the claims
|
// Check the claims
|
||||||
assert.Equal(t, user.ID, claims.Subject, "Token subject should match user ID")
|
subject, ok := claims.Subject()
|
||||||
assert.Contains(t, claims.Audience, clientID, "Audience should contain the client ID")
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
assert.Equal(t, common.EnvConfig.AppURL, claims.Issuer, "Issuer should match app URL")
|
assert.Equal(t, user.ID, subject, "Token subject should match user ID")
|
||||||
|
audience, ok := claims.Audience()
|
||||||
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||||
|
assert.EqualValues(t, []string{clientID}, audience, "Audience should contain the client ID")
|
||||||
|
issuer, ok := claims.Issuer()
|
||||||
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
|
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
|
|
||||||
// Check token expiration time is approximately 1 hour from now
|
// Check token expiration time is approximately 1 hour from now
|
||||||
expectedExp := time.Now().Add(1 * time.Hour)
|
expectedExp := time.Now().Add(1 * time.Hour)
|
||||||
tokenExp := claims.ExpiresAt.Time
|
expiration, ok := claims.Expiration()
|
||||||
timeDiff := expectedExp.Sub(tokenExp).Minutes()
|
assert.True(t, ok, "Expiration not found in token")
|
||||||
|
timeDiff := expectedExp.Sub(expiration).Minutes()
|
||||||
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 1 hour")
|
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 1 hour")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -493,7 +978,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
|||||||
// Verify should fail due to expiration
|
// Verify should fail due to expiration
|
||||||
_, err = service.VerifyOauthAccessToken(string(signed))
|
_, err = service.VerifyOauthAccessToken(string(signed))
|
||||||
require.Error(t, err, "Verification should fail with expired token")
|
require.Error(t, err, "Verification should fail with expired token")
|
||||||
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("fails verification with invalid signature", func(t *testing.T) {
|
t.Run("fails verification with invalid signature", func(t *testing.T) {
|
||||||
@@ -521,18 +1006,175 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
|||||||
// Verify with the second service should fail due to different keys
|
// Verify with the second service should fail due to different keys
|
||||||
_, err = service2.VerifyOauthAccessToken(tokenString)
|
_, err = service2.VerifyOauthAccessToken(tokenString)
|
||||||
require.Error(t, err, "Verification should fail with invalid signature")
|
require.Error(t, err, "Verification should fail with invalid signature")
|
||||||
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("works with Ed25519 keys", func(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create an Ed25519 key and save it as JWK
|
||||||
|
origKeyID := createEdDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
|
// Create a JWT service that loads the key
|
||||||
|
service := &JwtService{}
|
||||||
|
err := service.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
|
// Verify it loaded the right key
|
||||||
|
loadedKeyID, ok := service.privateKey.KeyID()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
|
|
||||||
|
// Create a test user
|
||||||
|
user := model.User{
|
||||||
|
Base: model.Base{
|
||||||
|
ID: "eddsauser789",
|
||||||
|
},
|
||||||
|
Email: "eddsaoauth@example.com",
|
||||||
|
}
|
||||||
|
const clientID = "eddsa-oauth-client"
|
||||||
|
|
||||||
|
// Generate a token
|
||||||
|
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
|
||||||
|
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||||
|
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := service.VerifyOauthAccessToken(tokenString)
|
||||||
|
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, user.ID, subject, "Token subject should match user ID")
|
||||||
|
audience, ok := claims.Audience()
|
||||||
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||||
|
assert.EqualValues(t, []string{clientID}, audience, "Audience should contain the client ID")
|
||||||
|
|
||||||
|
// Verify the key type is OKP
|
||||||
|
publicKey, err := service.GetPublicJWK()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, jwa.OKP().String(), publicKey.KeyType().String(), "Key type should be OKP")
|
||||||
|
|
||||||
|
// Verify the algorithm is EdDSA
|
||||||
|
alg, ok := publicKey.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, jwa.EdDSA().String(), alg.String(), "Algorithm should be EdDSA")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("works with ECDSA keys", func(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create an ECDSA key and save it as JWK
|
||||||
|
origKeyID := createECDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
|
// Create a JWT service that loads the key
|
||||||
|
service := &JwtService{}
|
||||||
|
err := service.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
|
// Verify it loaded the right key
|
||||||
|
loadedKeyID, ok := service.privateKey.KeyID()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
|
|
||||||
|
// Create a test user
|
||||||
|
user := model.User{
|
||||||
|
Base: model.Base{
|
||||||
|
ID: "ecdsauser789",
|
||||||
|
},
|
||||||
|
Email: "ecdsaoauth@example.com",
|
||||||
|
}
|
||||||
|
const clientID = "ecdsa-oauth-client"
|
||||||
|
|
||||||
|
// Generate a token
|
||||||
|
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
|
||||||
|
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||||
|
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := service.VerifyOauthAccessToken(tokenString)
|
||||||
|
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, user.ID, subject, "Token subject should match user ID")
|
||||||
|
audience, ok := claims.Audience()
|
||||||
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||||
|
assert.EqualValues(t, []string{clientID}, audience, "Audience should contain the client ID")
|
||||||
|
|
||||||
|
// Verify the key type is EC
|
||||||
|
publicKey, err := service.GetPublicJWK()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, jwa.EC().String(), publicKey.KeyType().String(), "Key type should be EC")
|
||||||
|
|
||||||
|
// Verify the algorithm is ES256
|
||||||
|
alg, ok := publicKey.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, jwa.ES256().String(), alg.String(), "Algorithm should be ES256")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("works with RSA-4096 keys", func(t *testing.T) {
|
||||||
|
// Create a temporary directory for the test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create an RSA-4096 key and save it as JWK
|
||||||
|
origKeyID := createRSA4096KeyJWK(t, tempDir)
|
||||||
|
|
||||||
|
// Create a JWT service that loads the key
|
||||||
|
service := &JwtService{}
|
||||||
|
err := service.init(mockConfig, tempDir)
|
||||||
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
|
// Verify it loaded the right key
|
||||||
|
loadedKeyID, ok := service.privateKey.KeyID()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||||
|
|
||||||
|
// Create a test user
|
||||||
|
user := model.User{
|
||||||
|
Base: model.Base{
|
||||||
|
ID: "rsauser789",
|
||||||
|
},
|
||||||
|
Email: "rsaoauth@example.com",
|
||||||
|
}
|
||||||
|
const clientID = "rsa-oauth-client"
|
||||||
|
|
||||||
|
// Generate a token
|
||||||
|
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
|
||||||
|
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||||
|
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := service.VerifyOauthAccessToken(tokenString)
|
||||||
|
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, user.ID, subject, "Token subject should match user ID")
|
||||||
|
audience, ok := claims.Audience()
|
||||||
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||||
|
assert.EqualValues(t, []string{clientID}, audience, "Audience should contain the client ID")
|
||||||
|
|
||||||
|
// Verify the key type is RSA
|
||||||
|
publicKey, err := service.GetPublicJWK()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, jwa.RSA().String(), publicKey.KeyType().String(), "Key type should be RSA")
|
||||||
|
|
||||||
|
// Verify the algorithm is RS256
|
||||||
|
alg, ok := publicKey.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, jwa.RS256().String(), alg.String(), "Algorithm should be RS256")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func createECKeyJWK(t *testing.T, path string) string {
|
func importKey(t *testing.T, privateKeyRaw any, path string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
// Generate a new P-256 ECDSA key
|
|
||||||
privateKeyRaw, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
require.NoError(t, err, "Failed to generate ECDSA key")
|
|
||||||
|
|
||||||
// Import as JWK and save to disk
|
|
||||||
privateKey, err := importRawKey(privateKeyRaw)
|
privateKey, err := importRawKey(privateKeyRaw)
|
||||||
require.NoError(t, err, "Failed to import private key")
|
require.NoError(t, err, "Failed to import private key")
|
||||||
|
|
||||||
@@ -544,3 +1186,47 @@ func createECKeyJWK(t *testing.T, path string) string {
|
|||||||
|
|
||||||
return kid
|
return kid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Because generating a RSA-406 key isn't immediate, we pre-compute one
|
||||||
|
var (
|
||||||
|
rsaKeyPrecomputed *rsa.PrivateKey
|
||||||
|
rsaKeyPrecomputeOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
func createRSA4096KeyJWK(t *testing.T, path string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
rsaKeyPrecomputeOnce.Do(func() {
|
||||||
|
var err error
|
||||||
|
rsaKeyPrecomputed, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||||
|
if err != nil {
|
||||||
|
panic("failed to precompute RSA key: " + err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Import as JWK and save to disk
|
||||||
|
return importKey(t, rsaKeyPrecomputed, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createECDSAKeyJWK(t *testing.T, path string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Generate a new P-256 ECDSA key
|
||||||
|
privateKeyRaw, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
require.NoError(t, err, "Failed to generate ECDSA key")
|
||||||
|
|
||||||
|
// Import as JWK and save to disk
|
||||||
|
return importKey(t, privateKeyRaw, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create an Ed25519 key and save it as JWK
|
||||||
|
func createEdDSAKeyJWK(t *testing.T, path string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Generate a new Ed25519 key pair
|
||||||
|
_, privateKeyRaw, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err, "Failed to generate Ed25519 key")
|
||||||
|
|
||||||
|
// Import as JWK and save to disk
|
||||||
|
return importKey(t, privateKeyRaw, path)
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,12 +32,12 @@ func NewLdapService(db *gorm.DB, appConfigService *AppConfigService, userService
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *LdapService) createClient() (*ldap.Conn, error) {
|
func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||||
if s.appConfigService.DbConfig.LdapEnabled.Value != "true" {
|
if !s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
|
||||||
return nil, fmt.Errorf("LDAP is not enabled")
|
return nil, fmt.Errorf("LDAP is not enabled")
|
||||||
}
|
}
|
||||||
// Setup LDAP connection
|
// Setup LDAP connection
|
||||||
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
|
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
|
||||||
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.Value == "true"
|
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.IsTrue()
|
||||||
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
|
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
|
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
|
||||||
|
|||||||
@@ -461,7 +461,7 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
|
|||||||
|
|
||||||
if strings.Contains(scope, "email") {
|
if strings.Contains(scope, "email") {
|
||||||
claims["email"] = user.Email
|
claims["email"] = user.Email
|
||||||
claims["email_verified"] = s.appConfigService.DbConfig.EmailsVerified.Value == "true"
|
claims["email_verified"] = s.appConfigService.DbConfig.EmailsVerified.IsTrue()
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(scope, "groups") {
|
if strings.Contains(scope, "groups") {
|
||||||
@@ -547,21 +547,24 @@ func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the ID token hint is provided, verify the ID token
|
// If the ID token hint is provided, verify the ID token
|
||||||
claims, err := s.jwtService.VerifyIdToken(input.IdTokenHint)
|
// Here we also accept expired ID tokens, which are fine per spec
|
||||||
|
token, err := s.jwtService.VerifyIdToken(input.IdTokenHint, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", &common.TokenInvalidError{}
|
return "", &common.TokenInvalidError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the client ID is provided check if the client ID in the ID token matches the client ID in the request
|
// If the client ID is provided check if the client ID in the ID token matches the client ID in the request
|
||||||
if input.ClientId != "" && claims.Audience[0] != input.ClientId {
|
clientID, ok := token.Audience()
|
||||||
|
if !ok || len(clientID) == 0 {
|
||||||
|
return "", &common.TokenInvalidError{}
|
||||||
|
}
|
||||||
|
if input.ClientId != "" && clientID[0] != input.ClientId {
|
||||||
return "", &common.OidcClientIdNotMatchingError{}
|
return "", &common.OidcClientIdNotMatchingError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
clientId := claims.Audience[0]
|
|
||||||
|
|
||||||
// Check if the user has authorized the client before
|
// Check if the user has authorized the client before
|
||||||
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
||||||
if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientId, userID).Error; err != nil {
|
if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).Error; err != nil {
|
||||||
return "", &common.OidcMissingAuthorizationError{}
|
return "", &common.OidcMissingAuthorizationError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func (s *UserGroupService) Delete(id string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Disallow deleting the group if it is an LDAP group and LDAP is enabled
|
// Disallow deleting the group if it is an LDAP group and LDAP is enabled
|
||||||
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" {
|
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
|
||||||
return &common.LdapUserGroupUpdateError{}
|
return &common.LdapUserGroupUpdateError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Disallow updating the group if it is an LDAP group and LDAP is enabled
|
// Disallow updating the group if it is an LDAP group and LDAP is enabled
|
||||||
if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" {
|
if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
|
||||||
return model.UserGroup{}, &common.LdapUserGroupUpdateError{}
|
return model.UserGroup{}, &common.LdapUserGroupUpdateError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"gorm.io/gorm"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PaginationResponse struct {
|
type PaginationResponse struct {
|
||||||
@@ -30,7 +32,7 @@ func PaginateAndSort(sortedPaginationRequest SortedPaginationRequest, query *gor
|
|||||||
capitalizedSortColumn := CapitalizeFirstLetter(sort.Column)
|
capitalizedSortColumn := CapitalizeFirstLetter(sort.Column)
|
||||||
|
|
||||||
sortField, sortFieldFound := reflect.TypeOf(result).Elem().Elem().FieldByName(capitalizedSortColumn)
|
sortField, sortFieldFound := reflect.TypeOf(result).Elem().Elem().FieldByName(capitalizedSortColumn)
|
||||||
isSortable := sortField.Tag.Get("sortable") == "true"
|
isSortable, _ := strconv.ParseBool(sortField.Tag.Get("sortable"))
|
||||||
isValidSortOrder := sort.Direction == "asc" || sort.Direction == "desc"
|
isValidSortOrder := sort.Direction == "asc" || sort.Direction == "desc"
|
||||||
|
|
||||||
if sortFieldFound && isSortable && isValidSortOrder {
|
if sortFieldFound && isSortable && isValidSortOrder {
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ test('End session without id token hint shows confirmation page', async ({ page
|
|||||||
|
|
||||||
test('End session with id token hint redirects to callback URL', async ({ page }) => {
|
test('End session with id token hint redirects to callback URL', async ({ page }) => {
|
||||||
const client = oidcClients.nextcloud;
|
const client = oidcClients.nextcloud;
|
||||||
|
// Note: this token has expired, but it should be accepted by the logout endpoint anyways, per spec
|
||||||
const idToken =
|
const idToken =
|
||||||
'eyJhbGciOiJSUzI1NiIsImtpZCI6Ijh1SER3M002cmY4IiwidHlwIjoiSldUIn0.eyJhdWQiOiIzNjU0YTc0Ni0zNWQ0LTQzMjEtYWM2MS0wYmRjZmYyYjQwNTUiLCJlbWFpbCI6InRpbS5jb29rQHRlc3QuY29tIiwiZW1haWxfdmVyaWZpZWQiOnRydWUsImV4cCI6MTY5MDAwMDAwMSwiZmFtaWx5X25hbWUiOiJUaW0iLCJnaXZlbl9uYW1lIjoiQ29vayIsImlhdCI6MTY5MDAwMDAwMCwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdCIsIm5hbWUiOiJUaW0gQ29vayIsIm5vbmNlIjoib1cxQTFPNzhHUTE1RDczT3NIRXg3V1FLajdacXZITFp1XzM3bWRYSXFBUSIsInN1YiI6IjRiODlkYzItNjJmYi00NmJmLTlmNWYtYzM0ZjRlYWZlOTNlIn0.ruYCyjA2BNjROpmLGPNHrhgUNLnpJMEuncvjDYVuv1dAZwvOPfG-Rn-OseAgJDJbV7wJ0qf6ZmBkGWiifwc_B9h--fgd4Vby9fefj0MiHbSDgQyaU5UmpvJU8OlvM-TueD6ICJL0NeT3DwoW5xpIWaHtt3JqJIdP__Q-lTONL2Zokq50kWm0IO-bIw2QrQviSfHNpv8A5rk1RTzpXCPXYNB-eJbm3oBqYQWzerD9HaNrSvrKA7mKG8Te1mI9aMirPpG9FvcAU-I3lY8ky1hJZDu42jHpVEUdWPAmUZPZafoX8iYtlPfkoklDnHj_cdg4aZBGN5bfjM6xf1Oe_rLDWg';
|
'eyJhbGciOiJSUzI1NiIsImtpZCI6Ijh1SER3M002cmY4IiwidHlwIjoiSldUIn0.eyJhdWQiOiIzNjU0YTc0Ni0zNWQ0LTQzMjEtYWM2MS0wYmRjZmYyYjQwNTUiLCJlbWFpbCI6InRpbS5jb29rQHRlc3QuY29tIiwiZW1haWxfdmVyaWZpZWQiOnRydWUsImV4cCI6MTY5MDAwMDAwMSwiZmFtaWx5X25hbWUiOiJUaW0iLCJnaXZlbl9uYW1lIjoiQ29vayIsImlhdCI6MTY5MDAwMDAwMCwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdCIsIm5hbWUiOiJUaW0gQ29vayIsIm5vbmNlIjoib1cxQTFPNzhHUTE1RDczT3NIRXg3V1FLajdacXZITFp1XzM3bWRYSXFBUSIsInN1YiI6IjRiODlkYzItNjJmYi00NmJmLTlmNWYtYzM0ZjRlYWZlOTNlIn0.ruYCyjA2BNjROpmLGPNHrhgUNLnpJMEuncvjDYVuv1dAZwvOPfG-Rn-OseAgJDJbV7wJ0qf6ZmBkGWiifwc_B9h--fgd4Vby9fefj0MiHbSDgQyaU5UmpvJU8OlvM-TueD6ICJL0NeT3DwoW5xpIWaHtt3JqJIdP__Q-lTONL2Zokq50kWm0IO-bIw2QrQviSfHNpv8A5rk1RTzpXCPXYNB-eJbm3oBqYQWzerD9HaNrSvrKA7mKG8Te1mI9aMirPpG9FvcAU-I3lY8ky1hJZDu42jHpVEUdWPAmUZPZafoX8iYtlPfkoklDnHj_cdg4aZBGN5bfjM6xf1Oe_rLDWg';
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user