mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-09 14:53:00 +03:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
939601b6a4 | ||
|
|
b9daa5d757 | ||
|
|
8304065652 | ||
|
|
7bfc3f43a5 | ||
|
|
c056089c60 | ||
|
|
3350398abc | ||
|
|
0b0a6781ff | ||
|
|
735dc70d5f |
12
.github/workflows/e2e-tests.yml
vendored
12
.github/workflows/e2e-tests.yml
vendored
@@ -15,12 +15,13 @@ jobs:
|
||||
node-version: lts/*
|
||||
cache: 'npm'
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
|
||||
- name: Create dummy GeoLite2 City database
|
||||
run: touch ./backend/GeoLite2-City.mmdb
|
||||
|
||||
|
||||
- name: Build Docker Image
|
||||
run: docker build -t stonith404/pocket-id .
|
||||
|
||||
- name: Run Docker Container
|
||||
run: docker run -d --name pocket-id -p 80:80 --env-file .env.test stonith404/pocket-id
|
||||
|
||||
@@ -36,13 +37,10 @@ jobs:
|
||||
working-directory: ./frontend
|
||||
run: npx playwright test
|
||||
|
||||
- name: Get container logs
|
||||
if: always()
|
||||
run: docker logs pocket-id
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
path: frontend/tests/.output
|
||||
path: frontend/tests/.report
|
||||
include-hidden-files: true
|
||||
retention-days: 15
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -34,5 +34,6 @@ vite.config.ts.timestamp-*
|
||||
# Application specific
|
||||
data
|
||||
/frontend/tests/.auth
|
||||
/frontend/tests/.report
|
||||
pocket-id-backend
|
||||
/backend/GeoLite2-City.mmdb
|
||||
@@ -1,3 +1,12 @@
|
||||
## [](https://github.com/stonith404/pocket-id/compare/v0.11.0...v) (2024-10-28)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* add option to disable self-account editing ([8304065](https://github.com/stonith404/pocket-id/commit/83040656525cf7b6c8f2acf416c5f8f3288f3d48))
|
||||
* add validation to custom claim input ([7bfc3f4](https://github.com/stonith404/pocket-id/commit/7bfc3f43a591287c038187ed5e782de6b9dd738b))
|
||||
* custom claims ([#53](https://github.com/stonith404/pocket-id/issues/53)) ([c056089](https://github.com/stonith404/pocket-id/commit/c056089c6043a825aaaaecf0c57454892a108f1d))
|
||||
|
||||
## [](https://github.com/stonith404/pocket-id/compare/v0.10.0...v) (2024-10-25)
|
||||
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ func newDatabase() (db *gorm.DB) {
|
||||
log.Fatalf("failed to connect to database: %v", err)
|
||||
}
|
||||
sqlDb, err := db.DB()
|
||||
sqlDb.SetMaxOpenConns(1)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get sql.DB: %v", err)
|
||||
}
|
||||
|
||||
@@ -39,11 +39,13 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
|
||||
jwtService := service.NewJwtService(appConfigService)
|
||||
webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService)
|
||||
userService := service.NewUserService(db, jwtService)
|
||||
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService)
|
||||
customClaimService := service.NewCustomClaimService(db)
|
||||
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService)
|
||||
testService := service.NewTestService(db, appConfigService)
|
||||
userGroupService := service.NewUserGroupService(db)
|
||||
|
||||
r.Use(middleware.NewCorsMiddleware().Add())
|
||||
r.Use(middleware.NewErrorHandlerMiddleware().Add())
|
||||
r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60))
|
||||
r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false))
|
||||
|
||||
@@ -55,10 +57,11 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
|
||||
apiGroup := r.Group("/api")
|
||||
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService)
|
||||
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService)
|
||||
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService)
|
||||
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService, appConfigService)
|
||||
controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService)
|
||||
controller.NewAuditLogController(apiGroup, auditLogService, jwtAuthMiddleware)
|
||||
controller.NewUserGroupController(apiGroup, jwtAuthMiddleware, userGroupService)
|
||||
controller.NewCustomClaimController(apiGroup, jwtAuthMiddleware, customClaimService)
|
||||
|
||||
// Add test controller in non-production environments
|
||||
if common.EnvConfig.AppEnv != "production" {
|
||||
|
||||
@@ -1,19 +1,148 @@
|
||||
package common
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrUsernameTaken = errors.New("username is already taken")
|
||||
ErrEmailTaken = errors.New("email is already taken")
|
||||
ErrSetupAlreadyCompleted = errors.New("setup already completed")
|
||||
ErrTokenInvalidOrExpired = errors.New("token is invalid or expired")
|
||||
ErrOidcMissingAuthorization = errors.New("missing authorization")
|
||||
ErrOidcGrantTypeNotSupported = errors.New("grant type not supported")
|
||||
ErrOidcMissingClientCredentials = errors.New("client id or secret not provided")
|
||||
ErrOidcClientSecretInvalid = errors.New("invalid client secret")
|
||||
ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code")
|
||||
ErrOidcInvalidCallbackURL = errors.New("invalid callback URL")
|
||||
ErrFileTypeNotSupported = errors.New("file type not supported")
|
||||
ErrInvalidCredentials = errors.New("no user found with provided credentials")
|
||||
ErrNameAlreadyInUse = errors.New("name is already in use")
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type AppError interface {
|
||||
error
|
||||
HttpStatusCode() int
|
||||
}
|
||||
|
||||
// Custom error types for various conditions
|
||||
|
||||
type AlreadyInUseError struct {
|
||||
Property string
|
||||
}
|
||||
|
||||
func (e *AlreadyInUseError) Error() string {
|
||||
return fmt.Sprintf("%s is already in use", e.Property)
|
||||
}
|
||||
func (e *AlreadyInUseError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type SetupAlreadyCompletedError struct{}
|
||||
|
||||
func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" }
|
||||
func (e *SetupAlreadyCompletedError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type TokenInvalidOrExpiredError struct{}
|
||||
|
||||
func (e *TokenInvalidOrExpiredError) Error() string { return "token is invalid or expired" }
|
||||
func (e *TokenInvalidOrExpiredError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type OidcMissingAuthorizationError struct{}
|
||||
|
||||
func (e *OidcMissingAuthorizationError) Error() string { return "missing authorization" }
|
||||
func (e *OidcMissingAuthorizationError) HttpStatusCode() int { return http.StatusForbidden }
|
||||
|
||||
type OidcGrantTypeNotSupportedError struct{}
|
||||
|
||||
func (e *OidcGrantTypeNotSupportedError) Error() string { return "grant type not supported" }
|
||||
func (e *OidcGrantTypeNotSupportedError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type OidcMissingClientCredentialsError struct{}
|
||||
|
||||
func (e *OidcMissingClientCredentialsError) Error() string { return "client id or secret not provided" }
|
||||
func (e *OidcMissingClientCredentialsError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type OidcClientSecretInvalidError struct{}
|
||||
|
||||
func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" }
|
||||
func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type OidcInvalidAuthorizationCodeError struct{}
|
||||
|
||||
func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" }
|
||||
func (e *OidcInvalidAuthorizationCodeError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type OidcInvalidCallbackURLError struct{}
|
||||
|
||||
func (e *OidcInvalidCallbackURLError) Error() string { return "invalid callback URL" }
|
||||
func (e *OidcInvalidCallbackURLError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type FileTypeNotSupportedError struct{}
|
||||
|
||||
func (e *FileTypeNotSupportedError) Error() string { return "file type not supported" }
|
||||
func (e *FileTypeNotSupportedError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type InvalidCredentialsError struct{}
|
||||
|
||||
func (e *InvalidCredentialsError) Error() string { return "no user found with provided credentials" }
|
||||
func (e *InvalidCredentialsError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type FileTooLargeError struct {
|
||||
MaxSize string
|
||||
}
|
||||
|
||||
func (e *FileTooLargeError) Error() string {
|
||||
return fmt.Sprintf("The file can't be larger than %s", e.MaxSize)
|
||||
}
|
||||
func (e *FileTooLargeError) HttpStatusCode() int { return http.StatusRequestEntityTooLarge }
|
||||
|
||||
type NotSignedInError struct{}
|
||||
|
||||
func (e *NotSignedInError) Error() string { return "You are not signed in" }
|
||||
func (e *NotSignedInError) HttpStatusCode() int { return http.StatusUnauthorized }
|
||||
|
||||
type MissingPermissionError struct{}
|
||||
|
||||
func (e *MissingPermissionError) Error() string {
|
||||
return "You don't have permission to perform this action"
|
||||
}
|
||||
func (e *MissingPermissionError) HttpStatusCode() int { return http.StatusForbidden }
|
||||
|
||||
type TooManyRequestsError struct{}
|
||||
|
||||
func (e *TooManyRequestsError) Error() string {
|
||||
return "Too many requests. Please wait a while before trying again."
|
||||
}
|
||||
func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyRequests }
|
||||
|
||||
type ClientIdOrSecretNotProvidedError struct{}
|
||||
|
||||
func (e *ClientIdOrSecretNotProvidedError) Error() string {
|
||||
return "Client id and secret not provided"
|
||||
}
|
||||
func (e *ClientIdOrSecretNotProvidedError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type WrongFileTypeError struct {
|
||||
ExpectedFileType string
|
||||
}
|
||||
|
||||
func (e *WrongFileTypeError) Error() string {
|
||||
return fmt.Sprintf("File must be of type %s", e.ExpectedFileType)
|
||||
}
|
||||
func (e *WrongFileTypeError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type MissingSessionIdError struct{}
|
||||
|
||||
func (e *MissingSessionIdError) Error() string {
|
||||
return "Missing session id"
|
||||
}
|
||||
func (e *MissingSessionIdError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type ReservedClaimError struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e *ReservedClaimError) Error() string {
|
||||
return fmt.Sprintf("Claim %s is reserved and can't be used", e.Key)
|
||||
}
|
||||
func (e *ReservedClaimError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type DuplicateClaimError struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e *DuplicateClaimError) Error() string {
|
||||
return fmt.Sprintf("Claim %s is already defined", e.Key)
|
||||
}
|
||||
func (e *DuplicateClaimError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type AccountEditNotAllowedError struct{}
|
||||
|
||||
func (e *AccountEditNotAllowedError) Error() string {
|
||||
return "You are not allowed to edit your account"
|
||||
}
|
||||
func (e *AccountEditNotAllowedError) HttpStatusCode() int { return http.StatusForbidden }
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
@@ -39,13 +38,13 @@ type AppConfigController struct {
|
||||
func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
|
||||
configuration, err := acc.appConfigService.ListAppConfig(false)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var configVariablesDto []dto.PublicAppConfigVariableDto
|
||||
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,13 +54,13 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
|
||||
func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
|
||||
configuration, err := acc.appConfigService.ListAppConfig(true)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var configVariablesDto []dto.AppConfigVariableDto
|
||||
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -71,19 +70,19 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
|
||||
func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
|
||||
var input dto.AppConfigUpdateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(input)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var configVariablesDto []dto.AppConfigVariableDto
|
||||
if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -136,13 +135,13 @@ func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
|
||||
func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
fileType := utils.GetFileExtension(file.Filename)
|
||||
if fileType != "ico" {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, "File must be of type .ico")
|
||||
c.Error(&common.WrongFileTypeError{ExpectedFileType: ".ico"})
|
||||
return
|
||||
}
|
||||
acc.updateImage(c, "favicon", "ico")
|
||||
@@ -164,17 +163,13 @@ func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType
|
||||
func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrFileTypeNotSupported) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
func NewAuditLogController(group *gin.RouterGroup, auditLogService *service.AuditLogService, jwtAuthMiddleware *middleware.JwtAuthMiddleware) {
|
||||
@@ -31,7 +30,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
|
||||
// Fetch audit logs for the user
|
||||
logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(userID, page, pageSize)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -39,7 +38,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
|
||||
var logsDtos []dto.AuditLogDto
|
||||
err = dto.MapStructList(logs, &logsDtos)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
78
backend/internal/controller/custom_claim_controller.go
Normal file
78
backend/internal/controller/custom_claim_controller.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func NewCustomClaimController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, customClaimService *service.CustomClaimService) {
|
||||
wkc := &CustomClaimController{customClaimService: customClaimService}
|
||||
group.GET("/custom-claims/suggestions", jwtAuthMiddleware.Add(true), wkc.getSuggestionsHandler)
|
||||
group.PUT("/custom-claims/user/:userId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserHandler)
|
||||
group.PUT("/custom-claims/user-group/:userGroupId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserGroupHandler)
|
||||
}
|
||||
|
||||
type CustomClaimController struct {
|
||||
customClaimService *service.CustomClaimService
|
||||
}
|
||||
|
||||
func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
|
||||
claims, err := ccc.customClaimService.GetSuggestions()
|
||||
if err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, claims)
|
||||
}
|
||||
|
||||
func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Context) {
|
||||
var input []dto.CustomClaimCreateDto
|
||||
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.Param("userId")
|
||||
claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(userId, input)
|
||||
if err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var customClaimsDto []dto.CustomClaimDto
|
||||
if err := dto.MapStructList(claims, &customClaimsDto); err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, customClaimsDto)
|
||||
}
|
||||
|
||||
func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.Context) {
|
||||
var input []dto.CustomClaimCreateDto
|
||||
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.Param("userGroupId")
|
||||
claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userId, input)
|
||||
if err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var customClaimsDto []dto.CustomClaimDto
|
||||
if err := dto.MapStructList(claims, &customClaimsDto); err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, customClaimsDto)
|
||||
}
|
||||
@@ -1,13 +1,11 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -42,19 +40,13 @@ type OidcController struct {
|
||||
func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
var input dto.AuthorizeOidcClientRequestDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrOidcMissingAuthorization) {
|
||||
utils.CustomControllerError(c, http.StatusForbidden, err.Error())
|
||||
} else if errors.Is(err, common.ErrOidcInvalidCallbackURL) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -69,17 +61,13 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
|
||||
var input dto.AuthorizeOidcClientRequestDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
code, callbackURL, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrOidcInvalidCallbackURL) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -95,7 +83,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
||||
var input dto.OidcIdTokenDto
|
||||
|
||||
if err := c.ShouldBind(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -107,21 +95,14 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
||||
var ok bool
|
||||
clientID, clientSecret, ok = c.Request.BasicAuth()
|
||||
if !ok {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, "Client id and secret not provided")
|
||||
c.Error(&common.ClientIdOrSecretNotProvidedError{})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrOidcGrantTypeNotSupported) ||
|
||||
errors.Is(err, common.ErrOidcMissingClientCredentials) ||
|
||||
errors.Is(err, common.ErrOidcClientSecretInvalid) ||
|
||||
errors.Is(err, common.ErrOidcInvalidAuthorizationCode) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -132,14 +113,14 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
||||
token := strings.Split(c.GetHeader("Authorization"), " ")[1]
|
||||
jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token)
|
||||
if err != nil {
|
||||
utils.CustomControllerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error())
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
userID := jwtClaims.Subject
|
||||
clientId := jwtClaims.Audience[0]
|
||||
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -150,7 +131,7 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||
clientId := c.Param("id")
|
||||
client, err := oc.oidcService.GetClient(clientId)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -171,7 +152,7 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
}
|
||||
|
||||
func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
||||
@@ -181,13 +162,13 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
||||
|
||||
clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var clientsDto []dto.OidcClientDto
|
||||
if err := dto.MapStructList(clients, &clientsDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -200,19 +181,19 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
||||
func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||
var input dto.OidcClientCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
client, err := oc.oidcService.CreateClient(input, c.GetString("userID"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var clientDto dto.OidcClientDto
|
||||
if err := dto.MapStruct(client, &clientDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -222,7 +203,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
|
||||
err := oc.oidcService.DeleteClient(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -232,19 +213,19 @@ func (oc *OidcController) deleteClientHandler(c *gin.Context) {
|
||||
func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||
var input dto.OidcClientCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
client, err := oc.oidcService.UpdateClient(c.Param("id"), input)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var clientDto dto.OidcClientDto
|
||||
if err := dto.MapStruct(client, &clientDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -254,7 +235,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
||||
secret, err := oc.oidcService.CreateClientSecret(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -264,7 +245,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
||||
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
||||
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -275,17 +256,13 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
||||
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrFileTypeNotSupported) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -295,7 +272,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
||||
err := oc.oidcService.DeleteClientLogo(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package controller
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -19,17 +18,22 @@ type TestController struct {
|
||||
|
||||
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
||||
if err := tc.TestService.ResetDatabase(); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.TestService.ResetApplicationImages(); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.TestService.SeedDatabase(); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.TestService.ResetAppConfig(); err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"golang.org/x/time/rate"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, userService *service.UserService) {
|
||||
func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, userService *service.UserService, appConfigService *service.AppConfigService) {
|
||||
uc := UserController{
|
||||
UserService: userService,
|
||||
UserService: userService,
|
||||
AppConfigService: appConfigService,
|
||||
}
|
||||
|
||||
group.GET("/users", jwtAuthMiddleware.Add(true), uc.listUsersHandler)
|
||||
@@ -33,7 +32,8 @@ func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt
|
||||
}
|
||||
|
||||
type UserController struct {
|
||||
UserService *service.UserService
|
||||
UserService *service.UserService
|
||||
AppConfigService *service.AppConfigService
|
||||
}
|
||||
|
||||
func (uc *UserController) listUsersHandler(c *gin.Context) {
|
||||
@@ -43,13 +43,13 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
|
||||
|
||||
users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var usersDto []dto.UserDto
|
||||
if err := dto.MapStructList(users, &usersDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -62,13 +62,13 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
|
||||
func (uc *UserController) getUserHandler(c *gin.Context) {
|
||||
user, err := uc.UserService.GetUser(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -78,13 +78,13 @@ func (uc *UserController) getUserHandler(c *gin.Context) {
|
||||
func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
|
||||
user, err := uc.UserService.GetUser(c.GetString("userID"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
|
||||
|
||||
func (uc *UserController) deleteUserHandler(c *gin.Context) {
|
||||
if err := uc.UserService.DeleteUser(c.Param("id")); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -103,23 +103,19 @@ func (uc *UserController) deleteUserHandler(c *gin.Context) {
|
||||
func (uc *UserController) createUserHandler(c *gin.Context) {
|
||||
var input dto.UserCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := uc.UserService.CreateUser(input)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
|
||||
utils.CustomControllerError(c, http.StatusConflict, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -131,19 +127,23 @@ func (uc *UserController) updateUserHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
|
||||
if uc.AppConfigService.DbConfig.AllowOwnAccountEdit.Value != "true" {
|
||||
c.Error(&common.AccountEditNotAllowedError{})
|
||||
return
|
||||
}
|
||||
uc.updateUser(c, true)
|
||||
}
|
||||
|
||||
func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
var input dto.OneTimeAccessTokenCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -153,17 +153,13 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token"))
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrTokenInvalidOrExpired) {
|
||||
utils.CustomControllerError(c, http.StatusUnauthorized, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -174,17 +170,13 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
|
||||
user, token, err := uc.UserService.SetupInitialAdmin()
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrSetupAlreadyCompleted) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -195,7 +187,7 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
|
||||
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
|
||||
var input dto.UserCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -208,17 +200,13 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
|
||||
|
||||
user, err := uc.UserService.UpdateUser(userID, input, updateOwnUser)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
|
||||
utils.CustomControllerError(c, http.StatusConflict, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
func NewUserGroupController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, userGroupService *service.UserGroupService) {
|
||||
@@ -37,7 +34,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
|
||||
groups, pagination, err := ugc.UserGroupService.List(searchTerm, page, pageSize)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -45,12 +42,12 @@ func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
for i, group := range groups {
|
||||
var groupDto dto.UserGroupDtoWithUserCount
|
||||
if err := dto.MapStruct(group, &groupDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(group.ID)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
groupsDto[i] = groupDto
|
||||
@@ -65,13 +62,13 @@ func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
func (ugc *UserGroupController) get(c *gin.Context) {
|
||||
group, err := ugc.UserGroupService.Get(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var groupDto dto.UserGroupDtoWithUsers
|
||||
if err := dto.MapStruct(group, &groupDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -81,23 +78,19 @@ func (ugc *UserGroupController) get(c *gin.Context) {
|
||||
func (ugc *UserGroupController) create(c *gin.Context) {
|
||||
var input dto.UserGroupCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := ugc.UserGroupService.Create(input)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrNameAlreadyInUse) {
|
||||
utils.CustomControllerError(c, http.StatusConflict, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var groupDto dto.UserGroupDtoWithUsers
|
||||
if err := dto.MapStruct(group, &groupDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -107,23 +100,19 @@ func (ugc *UserGroupController) create(c *gin.Context) {
|
||||
func (ugc *UserGroupController) update(c *gin.Context) {
|
||||
var input dto.UserGroupCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := ugc.UserGroupService.Update(c.Param("id"), input)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrNameAlreadyInUse) {
|
||||
utils.CustomControllerError(c, http.StatusConflict, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var groupDto dto.UserGroupDtoWithUsers
|
||||
if err := dto.MapStruct(group, &groupDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -132,7 +121,7 @@ func (ugc *UserGroupController) update(c *gin.Context) {
|
||||
|
||||
func (ugc *UserGroupController) delete(c *gin.Context) {
|
||||
if err := ugc.UserGroupService.Delete(c.Param("id")); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -142,19 +131,19 @@ func (ugc *UserGroupController) delete(c *gin.Context) {
|
||||
func (ugc *UserGroupController) updateUsers(c *gin.Context) {
|
||||
var input dto.UserGroupUpdateUsersDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var groupDto dto.UserGroupDtoWithUsers
|
||||
if err := dto.MapStruct(group, &groupDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
@@ -38,7 +36,7 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
options, err := wc.webAuthnService.BeginRegistration(userID)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -49,20 +47,20 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
|
||||
func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
|
||||
sessionID, err := c.Cookie("session_id")
|
||||
if err != nil {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing")
|
||||
c.Error(&common.MissingSessionIdError{})
|
||||
return
|
||||
}
|
||||
|
||||
userID := c.GetString("userID")
|
||||
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var credentialDto dto.WebauthnCredentialDto
|
||||
if err := dto.MapStruct(credential, &credentialDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -72,7 +70,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
|
||||
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
|
||||
options, err := wc.webAuthnService.BeginLogin()
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -83,13 +81,13 @@ func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
|
||||
func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
|
||||
sessionID, err := c.Cookie("session_id")
|
||||
if err != nil {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing")
|
||||
c.Error(&common.MissingSessionIdError{})
|
||||
return
|
||||
}
|
||||
|
||||
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -97,17 +95,13 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
|
||||
|
||||
user, token, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent())
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrInvalidCredentials) {
|
||||
utils.CustomControllerError(c, http.StatusUnauthorized, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -119,13 +113,13 @@ func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
credentials, err := wc.webAuthnService.ListCredentials(userID)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var credentialDtos []dto.WebauthnCredentialDto
|
||||
if err := dto.MapStructList(credentials, &credentialDtos); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -138,7 +132,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
|
||||
|
||||
err := wc.webAuthnService.DeleteCredential(userID, credentialID)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -151,19 +145,19 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
|
||||
|
||||
var input dto.WebauthnCredentialUpdateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var credentialDto dto.WebauthnCredentialDto
|
||||
if err := dto.MapStruct(credential, &credentialDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -21,7 +20,7 @@ type WellKnownController struct {
|
||||
func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
|
||||
jwk, err := wkc.jwtService.GetJWK()
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -12,13 +12,14 @@ type AppConfigVariableDto struct {
|
||||
}
|
||||
|
||||
type AppConfigUpdateDto struct {
|
||||
AppName string `json:"appName" binding:"required,min=1,max=30"`
|
||||
SessionDuration string `json:"sessionDuration" binding:"required"`
|
||||
EmailsVerified string `json:"emailsVerified" binding:"required"`
|
||||
EmailEnabled string `json:"emailEnabled" binding:"required"`
|
||||
SmtHost string `json:"smtpHost"`
|
||||
SmtpPort string `json:"smtpPort"`
|
||||
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
|
||||
SmtpUser string `json:"smtpUser"`
|
||||
SmtpPassword string `json:"smtpPassword"`
|
||||
AppName string `json:"appName" binding:"required,min=1,max=30"`
|
||||
SessionDuration string `json:"sessionDuration" binding:"required"`
|
||||
EmailsVerified string `json:"emailsVerified" binding:"required"`
|
||||
AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"`
|
||||
EmailEnabled string `json:"emailEnabled" binding:"required"`
|
||||
SmtHost string `json:"smtpHost"`
|
||||
SmtpPort string `json:"smtpPort"`
|
||||
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
|
||||
SmtpUser string `json:"smtpUser"`
|
||||
SmtpPassword string `json:"smtpPassword"`
|
||||
}
|
||||
|
||||
11
backend/internal/dto/custom_claim_dto.go
Normal file
11
backend/internal/dto/custom_claim_dto.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package dto
|
||||
|
||||
type CustomClaimDto struct {
|
||||
Key string `json:"key"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type CustomClaimCreateDto struct {
|
||||
Key string `json:"key" binding:"required,claimKey"`
|
||||
Value string `json:"value" binding:"required"`
|
||||
}
|
||||
@@ -3,12 +3,13 @@ package dto
|
||||
import "time"
|
||||
|
||||
type UserDto struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email" `
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email" `
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
CustomClaims []CustomClaimDto `json:"customClaims"`
|
||||
}
|
||||
|
||||
type UserCreateDto struct {
|
||||
|
||||
@@ -3,19 +3,21 @@ package dto
|
||||
import "time"
|
||||
|
||||
type UserGroupDtoWithUsers struct {
|
||||
ID string `json:"id"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
Name string `json:"name"`
|
||||
Users []UserDto `json:"users"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
ID string `json:"id"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
Name string `json:"name"`
|
||||
CustomClaims []CustomClaimDto `json:"customClaims"`
|
||||
Users []UserDto `json:"users"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
type UserGroupDtoWithUserCount struct {
|
||||
ID string `json:"id"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
Name string `json:"name"`
|
||||
UserCount int64 `json:"userCount"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
ID string `json:"id"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
Name string `json:"name"`
|
||||
CustomClaims []CustomClaimDto `json:"customClaims"`
|
||||
UserCount int64 `json:"userCount"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
type UserGroupCreateDto struct {
|
||||
|
||||
@@ -29,8 +29,15 @@ var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
|
||||
}
|
||||
|
||||
var validateUserGroupName validator.Func = func(fl validator.FieldLevel) bool {
|
||||
// [a-z0-9_] : The group name can only contain lowercase letters, numbers, and underscores
|
||||
regex := "^[a-z0-9_]+$"
|
||||
// The string can only contain lowercase letters, numbers, and underscores
|
||||
regex := "^[a-z0-9_]*$"
|
||||
matched, _ := regexp.MatchString(regex, fl.Field().String())
|
||||
return matched
|
||||
}
|
||||
|
||||
var validateClaimKey validator.Func = func(fl validator.FieldLevel) bool {
|
||||
// The string can only contain letters and numbers
|
||||
regex := "^[A-Za-z0-9]*$"
|
||||
matched, _ := regexp.MatchString(regex, fl.Field().String())
|
||||
return matched
|
||||
}
|
||||
@@ -52,4 +59,10 @@ func init() {
|
||||
log.Fatalf("Failed to register custom validation: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
||||
if err := v.RegisterValidation("claimKey", validateClaimKey); err != nil {
|
||||
log.Fatalf("Failed to register custom validation: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,37 +1,67 @@
|
||||
package utils
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
type ErrorHandlerMiddleware struct{}
|
||||
|
||||
func ControllerError(c *gin.Context, err error) {
|
||||
// Check for record not found errors
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
CustomControllerError(c, http.StatusNotFound, "Record not found")
|
||||
return
|
||||
func NewErrorHandlerMiddleware() *ErrorHandlerMiddleware {
|
||||
return &ErrorHandlerMiddleware{}
|
||||
}
|
||||
|
||||
func (m *ErrorHandlerMiddleware) Add() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
for _, err := range c.Errors {
|
||||
|
||||
// Check for record not found errors
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
errorResponse(c, http.StatusNotFound, "Record not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Check for validation errors
|
||||
var validationErrors validator.ValidationErrors
|
||||
if errors.As(err, &validationErrors) {
|
||||
message := handleValidationError(validationErrors)
|
||||
errorResponse(c, http.StatusBadRequest, message)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for slice validation errors
|
||||
var sliceValidationErrors binding.SliceValidationError
|
||||
if errors.As(err, &sliceValidationErrors) {
|
||||
if errors.As(sliceValidationErrors[0], &validationErrors) {
|
||||
message := handleValidationError(validationErrors)
|
||||
errorResponse(c, http.StatusBadRequest, message)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var appErr common.AppError
|
||||
if errors.As(err, &appErr) {
|
||||
errorResponse(c, appErr.HttpStatusCode(), appErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for validation errors
|
||||
var validationErrors validator.ValidationErrors
|
||||
if errors.As(err, &validationErrors) {
|
||||
message := handleValidationError(validationErrors)
|
||||
CustomControllerError(c, http.StatusBadRequest, message)
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
log.Println(err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
|
||||
func errorResponse(c *gin.Context, statusCode int, message string) {
|
||||
// Capitalize the first letter of the message
|
||||
message = strings.ToUpper(message[:1]) + message[1:]
|
||||
c.JSON(statusCode, gin.H{"error": message})
|
||||
}
|
||||
|
||||
func handleValidationError(validationErrors validator.ValidationErrors) string {
|
||||
@@ -67,9 +97,3 @@ func handleValidationError(validationErrors validator.ValidationErrors) string {
|
||||
|
||||
return combinedErrors
|
||||
}
|
||||
|
||||
func CustomControllerError(c *gin.Context, statusCode int, message string) {
|
||||
// Capitalize the first letter of the message
|
||||
message = strings.ToUpper(message[:1]) + message[1:]
|
||||
c.JSON(statusCode, gin.H{"error": message})
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package middleware
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -17,8 +17,8 @@ func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
|
||||
if err := c.Request.ParseMultipartForm(maxSize); err != nil {
|
||||
utils.CustomControllerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize)))
|
||||
c.Abort()
|
||||
err = &common.FileTooLargeError{MaxSize: formatFileSize(maxSize)}
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
|
||||
@@ -2,9 +2,8 @@ package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -29,8 +28,7 @@ func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
|
||||
c.Next()
|
||||
return
|
||||
} else {
|
||||
utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in")
|
||||
c.Abort()
|
||||
c.Error(&common.NotSignedInError{})
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -40,14 +38,14 @@ func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
|
||||
c.Next()
|
||||
return
|
||||
} else if err != nil {
|
||||
utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in")
|
||||
c.Error(&common.NotSignedInError{})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the user is an admin
|
||||
if adminOnly && !claims.IsAdmin {
|
||||
utils.CustomControllerError(c, http.StatusForbidden, "You don't have permission to access this resource")
|
||||
c.Error(&common.MissingPermissionError{})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@ package middleware
|
||||
|
||||
import (
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -33,8 +31,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
|
||||
|
||||
limiter := getLimiter(ip, limit, burst)
|
||||
if !limiter.Allow() {
|
||||
utils.CustomControllerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.")
|
||||
c.Abort()
|
||||
c.Error(&common.TooManyRequestsError{})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
package model
|
||||
|
||||
type AppConfigVariable struct {
|
||||
Key string `gorm:"primaryKey;not null"`
|
||||
Type string
|
||||
IsPublic bool
|
||||
IsInternal bool
|
||||
Value string
|
||||
Key string `gorm:"primaryKey;not null"`
|
||||
Type string
|
||||
IsPublic bool
|
||||
IsInternal bool
|
||||
Value string
|
||||
DefaultValue string
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
AppName AppConfigVariable
|
||||
SessionDuration AppConfigVariable
|
||||
EmailsVerified AppConfigVariable
|
||||
AllowOwnAccountEdit AppConfigVariable
|
||||
|
||||
BackgroundImageType AppConfigVariable
|
||||
LogoLightImageType AppConfigVariable
|
||||
LogoDarkImageType AppConfigVariable
|
||||
SessionDuration AppConfigVariable
|
||||
EmailsVerified AppConfigVariable
|
||||
|
||||
EmailEnabled AppConfigVariable
|
||||
SmtpHost AppConfigVariable
|
||||
|
||||
11
backend/internal/model/custom_claim.go
Normal file
11
backend/internal/model/custom_claim.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package model
|
||||
|
||||
type CustomClaim struct {
|
||||
Base
|
||||
|
||||
Key string
|
||||
Value string
|
||||
|
||||
UserID *string
|
||||
UserGroupID *string
|
||||
}
|
||||
@@ -15,8 +15,9 @@ type User struct {
|
||||
LastName string
|
||||
IsAdmin bool
|
||||
|
||||
UserGroups []UserGroup `gorm:"many2many:user_groups_users;"`
|
||||
Credentials []WebauthnCredential
|
||||
CustomClaims []CustomClaim
|
||||
UserGroups []UserGroup `gorm:"many2many:user_groups_users;"`
|
||||
Credentials []WebauthnCredential
|
||||
}
|
||||
|
||||
func (u User) WebAuthnID() []byte { return []byte(u.ID) }
|
||||
|
||||
@@ -5,4 +5,5 @@ type UserGroup struct {
|
||||
FriendlyName string
|
||||
Name string `gorm:"unique"`
|
||||
Users []User `gorm:"many2many:user_groups_users;"`
|
||||
CustomClaims []CustomClaim
|
||||
}
|
||||
|
||||
@@ -31,43 +31,49 @@ func NewAppConfigService(db *gorm.DB) *AppConfigService {
|
||||
|
||||
var defaultDbConfig = model.AppConfig{
|
||||
AppName: model.AppConfigVariable{
|
||||
Key: "appName",
|
||||
Type: "string",
|
||||
IsPublic: true,
|
||||
Value: "Pocket ID",
|
||||
Key: "appName",
|
||||
Type: "string",
|
||||
IsPublic: true,
|
||||
DefaultValue: "Pocket ID",
|
||||
},
|
||||
SessionDuration: model.AppConfigVariable{
|
||||
Key: "sessionDuration",
|
||||
Type: "number",
|
||||
Value: "60",
|
||||
Key: "sessionDuration",
|
||||
Type: "number",
|
||||
DefaultValue: "60",
|
||||
},
|
||||
EmailsVerified: model.AppConfigVariable{
|
||||
Key: "emailsVerified",
|
||||
Type: "bool",
|
||||
Value: "false",
|
||||
Key: "emailsVerified",
|
||||
Type: "bool",
|
||||
DefaultValue: "false",
|
||||
},
|
||||
AllowOwnAccountEdit: model.AppConfigVariable{
|
||||
Key: "allowOwnAccountEdit",
|
||||
Type: "bool",
|
||||
IsPublic: true,
|
||||
DefaultValue: "true",
|
||||
},
|
||||
BackgroundImageType: model.AppConfigVariable{
|
||||
Key: "backgroundImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
Value: "jpg",
|
||||
Key: "backgroundImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
DefaultValue: "jpg",
|
||||
},
|
||||
LogoLightImageType: model.AppConfigVariable{
|
||||
Key: "logoLightImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
Value: "svg",
|
||||
Key: "logoLightImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
DefaultValue: "svg",
|
||||
},
|
||||
LogoDarkImageType: model.AppConfigVariable{
|
||||
Key: "logoDarkImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
Value: "svg",
|
||||
Key: "logoDarkImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
DefaultValue: "svg",
|
||||
},
|
||||
EmailEnabled: model.AppConfigVariable{
|
||||
Key: "emailEnabled",
|
||||
Type: "bool",
|
||||
Value: "false",
|
||||
Key: "emailEnabled",
|
||||
Type: "bool",
|
||||
DefaultValue: "false",
|
||||
},
|
||||
SmtpHost: model.AppConfigVariable{
|
||||
Key: "smtpHost",
|
||||
@@ -120,7 +126,7 @@ func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]mode
|
||||
|
||||
tx.Commit()
|
||||
|
||||
if err := s.loadDbConfigFromDb(); err != nil {
|
||||
if err := s.LoadDbConfigFromDb(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -134,7 +140,7 @@ func (s *AppConfigService) UpdateImageType(imageName string, fileType string) er
|
||||
return err
|
||||
}
|
||||
|
||||
return s.loadDbConfigFromDb()
|
||||
return s.LoadDbConfigFromDb()
|
||||
}
|
||||
|
||||
func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariable, error) {
|
||||
@@ -151,6 +157,13 @@ func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariabl
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set the value to the default value if it is empty
|
||||
for i := range configuration {
|
||||
if configuration[i].Value == "" && configuration[i].DefaultValue != "" {
|
||||
configuration[i].Value = configuration[i].DefaultValue
|
||||
}
|
||||
}
|
||||
|
||||
return configuration, nil
|
||||
}
|
||||
|
||||
@@ -158,7 +171,7 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image
|
||||
fileType := utils.GetFileExtension(uploadedFile.Filename)
|
||||
mimeType := utils.GetImageMimeType(fileType)
|
||||
if mimeType == "" {
|
||||
return common.ErrFileTypeNotSupported
|
||||
return &common.FileTypeNotSupportedError{}
|
||||
}
|
||||
|
||||
// Delete the old image if it has a different file type
|
||||
@@ -206,10 +219,11 @@ func (s *AppConfigService) InitDbConfig() error {
|
||||
}
|
||||
|
||||
// Update existing configuration if it differs from the default
|
||||
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal {
|
||||
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal || storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue {
|
||||
storedConfigVar.Type = defaultConfigVar.Type
|
||||
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
|
||||
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
|
||||
storedConfigVar.DefaultValue = defaultConfigVar.DefaultValue
|
||||
if err := s.db.Save(&storedConfigVar).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -229,10 +243,11 @@ func (s *AppConfigService) InitDbConfig() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
return s.loadDbConfigFromDb()
|
||||
return s.LoadDbConfigFromDb()
|
||||
}
|
||||
|
||||
func (s *AppConfigService) loadDbConfigFromDb() error {
|
||||
// LoadDbConfigFromDb loads the configuration values from the database into the DbConfig struct.
|
||||
func (s *AppConfigService) LoadDbConfigFromDb() error {
|
||||
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
|
||||
|
||||
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
|
||||
@@ -243,6 +258,10 @@ func (s *AppConfigService) loadDbConfigFromDb() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if storedConfigVar.Value == "" && storedConfigVar.DefaultValue != "" {
|
||||
storedConfigVar.Value = storedConfigVar.DefaultValue
|
||||
}
|
||||
|
||||
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
|
||||
}
|
||||
|
||||
|
||||
197
backend/internal/service/custom_claim_service.go
Normal file
197
backend/internal/service/custom_claim_service.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Reserved claims
|
||||
var reservedClaims = map[string]struct{}{
|
||||
"given_name": {},
|
||||
"family_name": {},
|
||||
"name": {},
|
||||
"email": {},
|
||||
"preferred_username": {},
|
||||
"groups": {},
|
||||
"sub": {},
|
||||
"iss": {},
|
||||
"aud": {},
|
||||
"exp": {},
|
||||
"iat": {},
|
||||
"auth_time": {},
|
||||
"nonce": {},
|
||||
"acr": {},
|
||||
"amr": {},
|
||||
"azp": {},
|
||||
"nbf": {},
|
||||
"jti": {},
|
||||
}
|
||||
|
||||
type CustomClaimService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewCustomClaimService(db *gorm.DB) *CustomClaimService {
|
||||
return &CustomClaimService{db: db}
|
||||
}
|
||||
|
||||
// isReservedClaim checks if a claim key is reserved e.g. email, preferred_username
|
||||
func isReservedClaim(key string) bool {
|
||||
_, ok := reservedClaims[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// idType is the type of the id used to identify the user or user group
|
||||
type idType string
|
||||
|
||||
const (
|
||||
UserID idType = "user_id"
|
||||
UserGroupID idType = "user_group_id"
|
||||
)
|
||||
|
||||
// UpdateCustomClaimsForUser updates the custom claims for a user
|
||||
func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
return s.updateCustomClaims(UserID, userID, claims)
|
||||
}
|
||||
|
||||
// UpdateCustomClaimsForUserGroup updates the custom claims for a user group
|
||||
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
return s.updateCustomClaims(UserGroupID, userGroupID, claims)
|
||||
}
|
||||
|
||||
// updateCustomClaims updates the custom claims for a user or user group
|
||||
func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
// Check for duplicate keys in the claims slice
|
||||
seenKeys := make(map[string]bool)
|
||||
for _, claim := range claims {
|
||||
if seenKeys[claim.Key] {
|
||||
return nil, &common.DuplicateClaimError{Key: claim.Key}
|
||||
}
|
||||
seenKeys[claim.Key] = true
|
||||
}
|
||||
|
||||
var existingClaims []model.CustomClaim
|
||||
err := s.db.Where(string(idType), value).Find(&existingClaims).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Delete claims that are not in the new list
|
||||
for _, existingClaim := range existingClaims {
|
||||
found := false
|
||||
for _, claim := range claims {
|
||||
if claim.Key == existingClaim.Key {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err = s.db.Delete(&existingClaim).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add or update claims
|
||||
for _, claim := range claims {
|
||||
if isReservedClaim(claim.Key) {
|
||||
return nil, &common.ReservedClaimError{Key: claim.Key}
|
||||
}
|
||||
customClaim := model.CustomClaim{
|
||||
Key: claim.Key,
|
||||
Value: claim.Value,
|
||||
}
|
||||
|
||||
if idType == UserID {
|
||||
customClaim.UserID = &value
|
||||
} else if idType == UserGroupID {
|
||||
customClaim.UserGroupID = &value
|
||||
}
|
||||
|
||||
// Update the claim if it already exists or create a new one
|
||||
err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get the updated claims
|
||||
var updatedClaims []model.CustomClaim
|
||||
err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return updatedClaims, nil
|
||||
}
|
||||
|
||||
func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.CustomClaim, error) {
|
||||
var customClaims []model.CustomClaim
|
||||
err := s.db.Where("user_id = ?", userID).Find(&customClaims).Error
|
||||
return customClaims, err
|
||||
}
|
||||
|
||||
func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) {
|
||||
var customClaims []model.CustomClaim
|
||||
err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error
|
||||
return customClaims, err
|
||||
}
|
||||
|
||||
// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of,
|
||||
// prioritizing the user's claims over user group claims with the same key.
|
||||
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) {
|
||||
// Get the custom claims of the user
|
||||
customClaims, err := s.GetCustomClaimsForUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store user's claims in a map to prioritize and prevent duplicates
|
||||
claimsMap := make(map[string]model.CustomClaim)
|
||||
for _, claim := range customClaims {
|
||||
claimsMap[claim.Key] = claim
|
||||
}
|
||||
|
||||
// Get all user groups of the user
|
||||
var userGroupsOfUser []model.UserGroup
|
||||
err = s.db.Preload("CustomClaims").
|
||||
Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id").
|
||||
Where("user_groups_users.user_id = ?", userID).
|
||||
Find(&userGroupsOfUser).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add only non-duplicate custom claims from user groups
|
||||
for _, userGroup := range userGroupsOfUser {
|
||||
for _, groupClaim := range userGroup.CustomClaims {
|
||||
// Only add claim if it does not exist in the user's claims
|
||||
if _, exists := claimsMap[groupClaim.Key]; !exists {
|
||||
claimsMap[groupClaim.Key] = groupClaim
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert the claimsMap back to a slice
|
||||
finalClaims := make([]model.CustomClaim, 0, len(claimsMap))
|
||||
for _, claim := range claimsMap {
|
||||
finalClaims = append(finalClaims, claim)
|
||||
}
|
||||
|
||||
return finalClaims, nil
|
||||
}
|
||||
|
||||
// GetSuggestions returns a list of custom claim keys that have been used before
|
||||
func (s *CustomClaimService) GetSuggestions() ([]string, error) {
|
||||
var customClaimsKeys []string
|
||||
|
||||
err := s.db.Model(&model.CustomClaim{}).
|
||||
Group("key").
|
||||
Order("COUNT(*) DESC").
|
||||
Pluck("key", &customClaimsKeys).Error
|
||||
|
||||
return customClaimsKeys, err
|
||||
}
|
||||
@@ -18,18 +18,20 @@ import (
|
||||
)
|
||||
|
||||
type OidcService struct {
|
||||
db *gorm.DB
|
||||
jwtService *JwtService
|
||||
appConfigService *AppConfigService
|
||||
auditLogService *AuditLogService
|
||||
db *gorm.DB
|
||||
jwtService *JwtService
|
||||
appConfigService *AppConfigService
|
||||
auditLogService *AuditLogService
|
||||
customClaimService *CustomClaimService
|
||||
}
|
||||
|
||||
func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService) *OidcService {
|
||||
func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService, customClaimService *CustomClaimService) *OidcService {
|
||||
return &OidcService{
|
||||
db: db,
|
||||
jwtService: jwtService,
|
||||
appConfigService: appConfigService,
|
||||
auditLogService: auditLogService,
|
||||
db: db,
|
||||
jwtService: jwtService,
|
||||
appConfigService: appConfigService,
|
||||
auditLogService: auditLogService,
|
||||
customClaimService: customClaimService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +40,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
|
||||
s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID)
|
||||
|
||||
if userAuthorizedOIDCClient.Scope != input.Scope {
|
||||
return "", "", common.ErrOidcMissingAuthorization
|
||||
return "", "", &common.OidcMissingAuthorizationError{}
|
||||
}
|
||||
|
||||
callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
|
||||
@@ -93,11 +95,11 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
|
||||
|
||||
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) {
|
||||
if grantType != "authorization_code" {
|
||||
return "", "", common.ErrOidcGrantTypeNotSupported
|
||||
return "", "", &common.OidcGrantTypeNotSupportedError{}
|
||||
}
|
||||
|
||||
if clientID == "" || clientSecret == "" {
|
||||
return "", "", common.ErrOidcMissingClientCredentials
|
||||
return "", "", &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
var client model.OidcClient
|
||||
@@ -107,17 +109,17 @@ func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret strin
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
if err != nil {
|
||||
return "", "", common.ErrOidcClientSecretInvalid
|
||||
return "", "", &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
|
||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
|
||||
if err != nil {
|
||||
return "", "", common.ErrOidcInvalidAuthorizationCode
|
||||
return "", "", &common.OidcInvalidAuthorizationCodeError{}
|
||||
}
|
||||
|
||||
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
|
||||
return "", "", common.ErrOidcInvalidAuthorizationCode
|
||||
return "", "", &common.OidcInvalidAuthorizationCodeError{}
|
||||
}
|
||||
|
||||
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
|
||||
@@ -249,7 +251,7 @@ func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
|
||||
func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error {
|
||||
fileType := utils.GetFileExtension(file.Filename)
|
||||
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
|
||||
return common.ErrFileTypeNotSupported
|
||||
return &common.FileTypeNotSupportedError{}
|
||||
}
|
||||
|
||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType)
|
||||
@@ -334,9 +336,20 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
|
||||
}
|
||||
|
||||
if strings.Contains(scope, "profile") {
|
||||
// Add profile claims
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
// Add custom claims
|
||||
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, customClaim := range customClaims {
|
||||
claims[customClaim.Key] = customClaim.Value
|
||||
}
|
||||
}
|
||||
if strings.Contains(scope, "email") {
|
||||
claims["email"] = user.Email
|
||||
@@ -375,5 +388,5 @@ func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackU
|
||||
return inputCallbackURL, nil
|
||||
}
|
||||
|
||||
return "", common.ErrOidcInvalidCallbackURL
|
||||
return "", &common.OidcInvalidCallbackURLError{}
|
||||
}
|
||||
|
||||
@@ -138,8 +138,8 @@ func (s *TestService) SeedDatabase() error {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKey1, err := getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
|
||||
publicKey2, err := getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA==")
|
||||
publicKey1, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
|
||||
publicKey2, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA==")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -187,17 +187,16 @@ func (s *TestService) ResetDatabase() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete all rows from all tables
|
||||
for _, table := range tables {
|
||||
if err := tx.Exec("DELETE FROM " + table).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.appConfigService.InitDbConfig()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -215,8 +214,23 @@ func (s *TestService) ResetApplicationImages() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TestService) ResetAppConfig() error {
|
||||
// Reseed the config variables
|
||||
if err := s.appConfigService.InitDbConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Reset all app config variables to their default values
|
||||
if err := s.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&model.AppConfigVariable{}).Update("value", "").Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Reload the app config from the database after resetting the values
|
||||
return s.appConfigService.LoadDbConfigFromDb()
|
||||
}
|
||||
|
||||
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key
|
||||
func getCborPublicKey(base64PublicKey string) ([]byte, error) {
|
||||
func (s *TestService) getCborPublicKey(base64PublicKey string) ([]byte, error) {
|
||||
decodedKey, err := base64.StdEncoding.DecodeString(base64PublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode base64 key: %w", err)
|
||||
|
||||
@@ -18,7 +18,7 @@ func NewUserGroupService(db *gorm.DB) *UserGroupService {
|
||||
}
|
||||
|
||||
func (s *UserGroupService) List(name string, page int, pageSize int) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
|
||||
query := s.db.Model(&model.UserGroup{})
|
||||
query := s.db.Preload("CustomClaims").Model(&model.UserGroup{})
|
||||
|
||||
if name != "" {
|
||||
query = query.Where("name LIKE ?", "%"+name+"%")
|
||||
@@ -29,7 +29,7 @@ func (s *UserGroupService) List(name string, page int, pageSize int) (groups []m
|
||||
}
|
||||
|
||||
func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) {
|
||||
err = s.db.Where("id = ?", id).Preload("Users").First(&group).Error
|
||||
err = s.db.Where("id = ?", id).Preload("CustomClaims").Preload("Users").First(&group).Error
|
||||
return group, err
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
|
||||
|
||||
if err := s.db.Preload("Users").Create(&group).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return model.UserGroup{}, common.ErrNameAlreadyInUse
|
||||
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
|
||||
}
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto) (grou
|
||||
|
||||
if err := s.db.Preload("Users").Save(&group).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return model.UserGroup{}, common.ErrNameAlreadyInUse
|
||||
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
|
||||
}
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func (s *UserService) ListUsers(searchTerm string, page int, pageSize int) ([]mo
|
||||
|
||||
func (s *UserService) GetUser(userID string) (model.User, error) {
|
||||
var user model.User
|
||||
err := s.db.Where("id = ?", userID).First(&user).Error
|
||||
err := s.db.Preload("CustomClaims").Where("id = ?", userID).First(&user).Error
|
||||
return user, err
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string) (model.User, stri
|
||||
var oneTimeAccessToken model.OneTimeAccessToken
|
||||
if err := s.db.Where("token = ? AND expires_at > ?", token, time.Now().Unix()).Preload("User").First(&oneTimeAccessToken).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return model.User{}, "", common.ErrTokenInvalidOrExpired
|
||||
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
|
||||
}
|
||||
return model.User{}, "", err
|
||||
}
|
||||
@@ -133,7 +133,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
if userCount > 1 {
|
||||
return model.User{}, "", common.ErrSetupAlreadyCompleted
|
||||
return model.User{}, "", &common.SetupAlreadyCompletedError{}
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
@@ -149,7 +149,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
|
||||
}
|
||||
|
||||
if len(user.Credentials) > 0 {
|
||||
return model.User{}, "", common.ErrSetupAlreadyCompleted
|
||||
return model.User{}, "", &common.SetupAlreadyCompletedError{}
|
||||
}
|
||||
|
||||
token, err := s.jwtService.GenerateAccessToken(user)
|
||||
@@ -163,11 +163,11 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
|
||||
func (s *UserService) checkDuplicatedFields(user model.User) error {
|
||||
var existingUser model.User
|
||||
if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil {
|
||||
return common.ErrEmailTaken
|
||||
return &common.AlreadyInUseError{Property: "email"}
|
||||
}
|
||||
|
||||
if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil {
|
||||
return common.ErrUsernameTaken
|
||||
return &common.AlreadyInUseError{Property: "username"}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE app_config_variables DROP COLUMN default_value;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE app_config_variables ADD COLUMN default_value TEXT;
|
||||
1
backend/migrations/20241028064959_custom_claims.down.sql
Normal file
1
backend/migrations/20241028064959_custom_claims.down.sql
Normal file
@@ -0,0 +1 @@
|
||||
DROP TABLE custom_claims;
|
||||
15
backend/migrations/20241028064959_custom_claims.up.sql
Normal file
15
backend/migrations/20241028064959_custom_claims.up.sql
Normal file
@@ -0,0 +1,15 @@
|
||||
CREATE TABLE custom_claims
|
||||
(
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
created_at DATETIME,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
|
||||
user_id TEXT,
|
||||
user_group_id TEXT,
|
||||
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (user_group_id) REFERENCES user_groups (id) ON DELETE CASCADE,
|
||||
|
||||
CONSTRAINT custom_claims_unique UNIQUE (key, user_id, user_group_id),
|
||||
CHECK (user_id IS NOT NULL OR user_group_id IS NOT NULL)
|
||||
);
|
||||
4
frontend/package-lock.json
generated
4
frontend/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "pocket-id-frontend",
|
||||
"version": "0.9.0",
|
||||
"version": "0.10.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "pocket-id-frontend",
|
||||
"version": "0.9.0",
|
||||
"version": "0.10.0",
|
||||
"dependencies": {
|
||||
"@simplewebauthn/browser": "^10.0.0",
|
||||
"axios": "^1.7.7",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "pocket-id-frontend",
|
||||
"version": "0.11.0",
|
||||
"version": "0.12.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "vite dev --port 3000",
|
||||
|
||||
@@ -12,8 +12,8 @@ export default defineConfig({
|
||||
retries: process.env.CI ? 1 : 0,
|
||||
workers: 1,
|
||||
reporter: process.env.CI
|
||||
? [['html'], ['github']]
|
||||
: [['line'], ['html', { open: 'never', outputFolder: 'tests/.output' }]],
|
||||
? [['html', { outputFolder: 'tests/.report' }], ['github']]
|
||||
: [['line'], ['html', { open: 'never', outputFolder: 'tests/.report' }]],
|
||||
use: {
|
||||
baseURL: 'http://localhost',
|
||||
video: 'retain-on-failure',
|
||||
|
||||
116
frontend/src/lib/components/auto-complete-input.svelte
Normal file
116
frontend/src/lib/components/auto-complete-input.svelte
Normal file
@@ -0,0 +1,116 @@
|
||||
<script lang="ts">
|
||||
import Input from '$lib/components/ui/input/input.svelte';
|
||||
import * as Popover from '$lib/components/ui/popover/index.js';
|
||||
|
||||
let {
|
||||
value = $bindable(''),
|
||||
placeholder,
|
||||
suggestionLimit = 5,
|
||||
suggestions
|
||||
}: {
|
||||
value: string;
|
||||
placeholder: string;
|
||||
suggestionLimit?: number;
|
||||
suggestions: string[];
|
||||
} = $props();
|
||||
|
||||
let filteredSuggestions: string[] = $state(suggestions.slice(0, suggestionLimit));
|
||||
let selectedIndex = $state(-1);
|
||||
let keyError: string | undefined = $state();
|
||||
|
||||
let isInputFocused = $state(false);
|
||||
|
||||
function handleSuggestionClick(suggestion: (typeof suggestions)[0]) {
|
||||
value = suggestion;
|
||||
filteredSuggestions = [];
|
||||
}
|
||||
|
||||
function handleOnInput() {
|
||||
if (value.length > 0 && !/^[A-Za-z0-9]*$/.test(value)) {
|
||||
keyError = 'Only alphanumeric characters are allowed';
|
||||
return;
|
||||
} else {
|
||||
keyError = undefined;
|
||||
}
|
||||
|
||||
filteredSuggestions = suggestions
|
||||
.filter((s) => s.includes(value.toLowerCase()))
|
||||
.slice(0, suggestionLimit);
|
||||
}
|
||||
|
||||
function handleKeydown(e: KeyboardEvent) {
|
||||
if (!isOpen) return;
|
||||
switch (e.key) {
|
||||
case 'ArrowDown':
|
||||
selectedIndex = Math.min(selectedIndex + 1, filteredSuggestions.length - 1);
|
||||
break;
|
||||
case 'ArrowUp':
|
||||
selectedIndex = Math.max(selectedIndex - 1, -1);
|
||||
break;
|
||||
case 'Enter':
|
||||
if (selectedIndex >= 0) {
|
||||
handleSuggestionClick(filteredSuggestions[selectedIndex]);
|
||||
}
|
||||
break;
|
||||
case 'Escape':
|
||||
isInputFocused = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let isOpen = $derived(filteredSuggestions.length > 0 && isInputFocused);
|
||||
|
||||
$effect(() => {
|
||||
// Reset selection when suggestions change
|
||||
if (filteredSuggestions) {
|
||||
selectedIndex = -1;
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="grid w-full"
|
||||
role="combobox"
|
||||
onkeydown={handleKeydown}
|
||||
aria-controls="suggestion-list"
|
||||
aria-expanded={isOpen}
|
||||
tabindex="-1"
|
||||
>
|
||||
<Input
|
||||
{placeholder}
|
||||
bind:value
|
||||
oninput={handleOnInput}
|
||||
onfocus={() => (isInputFocused = true)}
|
||||
onblur={() => (isInputFocused = false)}
|
||||
/>
|
||||
{#if keyError}
|
||||
<p class="mt-1 text-sm text-red-500">{keyError}</p>
|
||||
{/if}
|
||||
<Popover.Root
|
||||
open={isOpen}
|
||||
disableFocusTrap
|
||||
openFocus={() => {}}
|
||||
closeOnOutsideClick={false}
|
||||
closeOnEscape={false}
|
||||
>
|
||||
<Popover.Trigger tabindex={-1} class="h-0 w-full" aria-hidden />
|
||||
<Popover.Content class="p-0" sideOffset={5} sameWidth>
|
||||
{#each filteredSuggestions as suggestion, index}
|
||||
<div
|
||||
role="button"
|
||||
tabindex="0"
|
||||
onmousedown={() => handleSuggestionClick(suggestion)}
|
||||
onkeydown={(e) => {
|
||||
if (e.key === 'Enter') handleSuggestionClick(suggestion);
|
||||
}}
|
||||
class="hover:bg-accent hover:text-accent-foreground relative flex w-full cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 {selectedIndex ===
|
||||
index
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
>
|
||||
{suggestion}
|
||||
</div>
|
||||
{/each}
|
||||
</Popover.Content>
|
||||
</Popover.Root>
|
||||
</div>
|
||||
75
frontend/src/lib/components/custom-claims-input.svelte
Normal file
75
frontend/src/lib/components/custom-claims-input.svelte
Normal file
@@ -0,0 +1,75 @@
|
||||
<script lang="ts">
|
||||
import FormInput from '$lib/components/form-input.svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import CustomClaimService from '$lib/services/custom-claim-service';
|
||||
import type { CustomClaim } from '$lib/types/custom-claim.type';
|
||||
import { LucideMinus, LucidePlus } from 'lucide-svelte';
|
||||
import { onMount, type Snippet } from 'svelte';
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
import AutoCompleteInput from './auto-complete-input.svelte';
|
||||
|
||||
let {
|
||||
customClaims = $bindable(),
|
||||
error = $bindable(null),
|
||||
...restProps
|
||||
}: HTMLAttributes<HTMLDivElement> & {
|
||||
customClaims: CustomClaim[];
|
||||
error?: string | null;
|
||||
children?: Snippet;
|
||||
} = $props();
|
||||
|
||||
const limit = 20;
|
||||
|
||||
const customClaimService = new CustomClaimService();
|
||||
|
||||
let suggestions: string[] = $state([]);
|
||||
let filteredSuggestions: string[] = $derived(
|
||||
suggestions.filter(
|
||||
(suggestion) => !customClaims.some((customClaim) => customClaim.key === suggestion)
|
||||
)
|
||||
);
|
||||
|
||||
onMount(() => {
|
||||
customClaimService.getSuggestions().then((data) => (suggestions = data));
|
||||
});
|
||||
</script>
|
||||
|
||||
<div {...restProps}>
|
||||
<FormInput>
|
||||
<div class="flex flex-col gap-y-2">
|
||||
{#each customClaims as _, i}
|
||||
<div class="flex gap-x-2">
|
||||
<AutoCompleteInput
|
||||
placeholder="Key"
|
||||
suggestions={filteredSuggestions}
|
||||
bind:value={customClaims[i].key}
|
||||
/>
|
||||
<Input placeholder="Value" bind:value={customClaims[i].value} />
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
aria-label="Remove custom claim"
|
||||
on:click={() => (customClaims = customClaims.filter((_, index) => index !== i))}
|
||||
>
|
||||
<LucideMinus class="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</FormInput>
|
||||
{#if error}
|
||||
<p class="mt-1 text-sm text-red-500">{error}</p>
|
||||
{/if}
|
||||
{#if customClaims.length < limit}
|
||||
<Button
|
||||
class="mt-2"
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
on:click={() => (customClaims = [...customClaims, { key: '', value: '' }])}
|
||||
>
|
||||
<LucidePlus class="mr-1 h-4 w-4" />
|
||||
{customClaims.length === 0 ? 'Add custom claim' : 'Add another'}
|
||||
</Button>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -16,7 +16,7 @@
|
||||
...restProps
|
||||
}: HTMLAttributes<HTMLDivElement> & {
|
||||
input?: FormInput<string | boolean | number>;
|
||||
label: string;
|
||||
label?: string;
|
||||
description?: string;
|
||||
disabled?: boolean;
|
||||
type?: 'text' | 'password' | 'email' | 'number' | 'checkbox';
|
||||
@@ -24,15 +24,17 @@
|
||||
children?: Snippet;
|
||||
} = $props();
|
||||
|
||||
const id = label.toLowerCase().replace(/ /g, '-');
|
||||
const id = label?.toLowerCase().replace(/ /g, '-');
|
||||
</script>
|
||||
|
||||
<div {...restProps}>
|
||||
<Label class="mb-0" for={id}>{label}</Label>
|
||||
{#if label}
|
||||
<Label class="mb-0" for={id}>{label}</Label>
|
||||
{/if}
|
||||
{#if description}
|
||||
<p class="text-muted-foreground mt-1 text-xs">{description}</p>
|
||||
{/if}
|
||||
<div class="mt-2">
|
||||
<div class={label || description ? 'mt-2' : ''}>
|
||||
{#if children}
|
||||
{@render children()}
|
||||
{:else if input}
|
||||
|
||||
17
frontend/src/lib/components/ui/popover/index.ts
Normal file
17
frontend/src/lib/components/ui/popover/index.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
import { Popover as PopoverPrimitive } from "bits-ui";
|
||||
import Content from "./popover-content.svelte";
|
||||
const Root = PopoverPrimitive.Root;
|
||||
const Trigger = PopoverPrimitive.Trigger;
|
||||
const Close = PopoverPrimitive.Close;
|
||||
|
||||
export {
|
||||
Root,
|
||||
Content,
|
||||
Trigger,
|
||||
Close,
|
||||
//
|
||||
Root as Popover,
|
||||
Content as PopoverContent,
|
||||
Trigger as PopoverTrigger,
|
||||
Close as PopoverClose,
|
||||
};
|
||||
@@ -0,0 +1,22 @@
|
||||
<script lang="ts">
|
||||
import { Popover as PopoverPrimitive } from "bits-ui";
|
||||
import { cn, flyAndScale } from "$lib/utils/style.js";
|
||||
|
||||
type $$Props = PopoverPrimitive.ContentProps;
|
||||
let className: $$Props["class"] = undefined;
|
||||
export let transition: $$Props["transition"] = flyAndScale;
|
||||
export let transitionConfig: $$Props["transitionConfig"] = undefined;
|
||||
export { className as class };
|
||||
</script>
|
||||
|
||||
<PopoverPrimitive.Content
|
||||
{transition}
|
||||
{transitionConfig}
|
||||
class={cn(
|
||||
"bg-popover text-popover-foreground z-50 w-72 rounded-md border p-4 shadow-md outline-none",
|
||||
className
|
||||
)}
|
||||
{...$$restProps}
|
||||
>
|
||||
<slot />
|
||||
</PopoverPrimitive.Content>
|
||||
@@ -73,8 +73,8 @@ export default class AppConfigService extends APIService {
|
||||
return true;
|
||||
} else if (value === 'false') {
|
||||
return false;
|
||||
} else if (!isNaN(Number(value))) {
|
||||
return Number(value);
|
||||
} else if (!isNaN(parseFloat(value))) {
|
||||
return parseFloat(value);
|
||||
} else {
|
||||
return value;
|
||||
}
|
||||
|
||||
19
frontend/src/lib/services/custom-claim-service.ts
Normal file
19
frontend/src/lib/services/custom-claim-service.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import type { CustomClaim } from '$lib/types/custom-claim.type';
|
||||
import APIService from './api-service';
|
||||
|
||||
export default class CustomClaimService extends APIService {
|
||||
async getSuggestions() {
|
||||
const res = await this.api.get('/custom-claims/suggestions');
|
||||
return res.data as string[];
|
||||
}
|
||||
|
||||
async updateUserCustomClaims(userId: string, claims: CustomClaim[]) {
|
||||
const res = await this.api.put(`/custom-claims/user/${userId}`, claims);
|
||||
return res.data as CustomClaim[];
|
||||
}
|
||||
|
||||
async updateUserGroupCustomClaims(userGroupId: string, claims: CustomClaim[]) {
|
||||
const res = await this.api.put(`/custom-claims/user-group/${userGroupId}`, claims);
|
||||
return res.data as CustomClaim[];
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
export type AppConfig = {
|
||||
appName: string;
|
||||
allowOwnAccountEdit: boolean;
|
||||
};
|
||||
|
||||
export type AllAppConfig = AppConfig & {
|
||||
|
||||
4
frontend/src/lib/types/custom-claim.type.ts
Normal file
4
frontend/src/lib/types/custom-claim.type.ts
Normal file
@@ -0,0 +1,4 @@
|
||||
export type CustomClaim = {
|
||||
key: string;
|
||||
value: string;
|
||||
};
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { CustomClaim } from './custom-claim.type';
|
||||
import type { User } from './user.type';
|
||||
|
||||
export type UserGroup = {
|
||||
@@ -5,6 +6,7 @@ export type UserGroup = {
|
||||
friendlyName: string;
|
||||
name: string;
|
||||
createdAt: string;
|
||||
customClaims: CustomClaim[];
|
||||
};
|
||||
|
||||
export type UserGroupWithUsers = UserGroup & {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import type { CustomClaim } from './custom-claim.type';
|
||||
|
||||
export type User = {
|
||||
id: string;
|
||||
username: string;
|
||||
@@ -5,6 +7,7 @@ export type User = {
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
isAdmin: boolean;
|
||||
customClaims: CustomClaim[];
|
||||
};
|
||||
|
||||
export type UserCreate = Omit<User, 'id'>;
|
||||
export type UserCreate = Omit<User, 'id' | 'customClaims'>;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import * as Card from '$lib/components/ui/card';
|
||||
import UserService from '$lib/services/user-service';
|
||||
import WebAuthnService from '$lib/services/webauthn-service';
|
||||
import appConfigStore from '$lib/stores/application-configuration-store';
|
||||
import type { Passkey } from '$lib/types/passkey.type';
|
||||
import type { UserCreate } from '$lib/types/user.type';
|
||||
import { axiosErrorToast, getWebauthnErrorMessage } from '$lib/utils/error-util';
|
||||
@@ -51,14 +52,16 @@
|
||||
<title>Account Settings</title>
|
||||
</svelte:head>
|
||||
|
||||
<Card.Root>
|
||||
<Card.Header>
|
||||
<Card.Title>Account Details</Card.Title>
|
||||
</Card.Header>
|
||||
<Card.Content>
|
||||
<AccountForm {account} callback={updateAccount} />
|
||||
</Card.Content>
|
||||
</Card.Root>
|
||||
{#if $appConfigStore.allowOwnAccountEdit}
|
||||
<Card.Root>
|
||||
<Card.Header>
|
||||
<Card.Title>Account Details</Card.Title>
|
||||
</Card.Header>
|
||||
<Card.Content>
|
||||
<AccountForm {account} callback={updateAccount} />
|
||||
</Card.Content>
|
||||
</Card.Root>
|
||||
{/if}
|
||||
|
||||
<Card.Root>
|
||||
<Card.Header>
|
||||
|
||||
@@ -21,13 +21,15 @@
|
||||
const updatedAppConfig = {
|
||||
appName: appConfig.appName,
|
||||
sessionDuration: appConfig.sessionDuration,
|
||||
emailsVerified: appConfig.emailsVerified
|
||||
emailsVerified: appConfig.emailsVerified,
|
||||
allowOwnAccountEdit: appConfig.allowOwnAccountEdit
|
||||
};
|
||||
|
||||
const formSchema = z.object({
|
||||
appName: z.string().min(2).max(30),
|
||||
sessionDuration: z.number().min(1).max(43200),
|
||||
emailsVerified: z.boolean()
|
||||
emailsVerified: z.boolean(),
|
||||
allowOwnAccountEdit: z.boolean()
|
||||
});
|
||||
|
||||
const { inputs, ...form } = createForm<typeof formSchema>(formSchema, updatedAppConfig);
|
||||
@@ -49,6 +51,17 @@
|
||||
description="The duration of a session in minutes before the user has to sign in again."
|
||||
bind:input={$inputs.sessionDuration}
|
||||
/>
|
||||
<div class="items-top mt-5 flex space-x-2">
|
||||
<Checkbox id="admin-privileges" bind:checked={$inputs.allowOwnAccountEdit.value} />
|
||||
<div class="grid gap-1.5 leading-none">
|
||||
<Label for="admin-privileges" class="mb-0 text-sm font-medium leading-none">
|
||||
Enable Self-Account Editing
|
||||
</Label>
|
||||
<p class="text-muted-foreground text-[0.8rem]">
|
||||
Whether the user should be able to edit their own account details.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="items-top mt-5 flex space-x-2">
|
||||
<Checkbox id="admin-privileges" bind:checked={$inputs.emailsVerified.value} />
|
||||
<div class="grid gap-1.5 leading-none">
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
<script lang="ts">
|
||||
import CustomClaimsInput from '$lib/components/custom-claims-input.svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Card from '$lib/components/ui/card';
|
||||
import CustomClaimService from '$lib/services/custom-claim-service';
|
||||
import UserGroupService from '$lib/services/user-group-service';
|
||||
import UserService from '$lib/services/user-service';
|
||||
import type { UserGroupCreate } from '$lib/types/user-group.type';
|
||||
@@ -18,6 +20,7 @@
|
||||
|
||||
const userGroupService = new UserGroupService();
|
||||
const userService = new UserService();
|
||||
const customClaimService = new CustomClaimService();
|
||||
|
||||
async function updateUserGroup(updatedUserGroup: UserGroupCreate) {
|
||||
let success = true;
|
||||
@@ -40,6 +43,15 @@
|
||||
axiosErrorToast(e);
|
||||
});
|
||||
}
|
||||
|
||||
async function updateCustomClaims() {
|
||||
await customClaimService
|
||||
.updateUserGroupCustomClaims(userGroup.id, userGroup.customClaims)
|
||||
.then(() => toast.success('Custom claims updated successfully'))
|
||||
.catch((e) => {
|
||||
axiosErrorToast(e);
|
||||
});
|
||||
}
|
||||
</script>
|
||||
|
||||
<svelte:head>
|
||||
@@ -53,7 +65,7 @@
|
||||
</div>
|
||||
<Card.Root>
|
||||
<Card.Header>
|
||||
<Card.Title>Meta data</Card.Title>
|
||||
<Card.Title>General</Card.Title>
|
||||
</Card.Header>
|
||||
|
||||
<Card.Content>
|
||||
@@ -76,3 +88,20 @@
|
||||
</div>
|
||||
</Card.Content>
|
||||
</Card.Root>
|
||||
|
||||
<Card.Root>
|
||||
<Card.Header>
|
||||
<Card.Title>Custom Claims</Card.Title>
|
||||
<Card.Description>
|
||||
Custom claims are key-value pairs that can be used to store additional information about a
|
||||
user. These claims will be included in the ID token if the scope "profile" is requested.
|
||||
Custom claims defined on the user will be prioritized if there are conflicts.
|
||||
</Card.Description>
|
||||
</Card.Header>
|
||||
<Card.Content>
|
||||
<CustomClaimsInput bind:customClaims={userGroup.customClaims} />
|
||||
<div class="mt-5 flex justify-end">
|
||||
<Button onclick={updateCustomClaims} type="submit">Save</Button>
|
||||
</div>
|
||||
</Card.Content>
|
||||
</Card.Root>
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
<script lang="ts">
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Card from '$lib/components/ui/card';
|
||||
import CustomClaimService from '$lib/services/custom-claim-service';
|
||||
import UserService from '$lib/services/user-service';
|
||||
import type { UserCreate } from '$lib/types/user.type';
|
||||
import { axiosErrorToast } from '$lib/utils/error-util';
|
||||
import { LucideChevronLeft } from 'lucide-svelte';
|
||||
import { toast } from 'svelte-sonner';
|
||||
import CustomClaimsInput from '../../../../../lib/components/custom-claims-input.svelte';
|
||||
import UserForm from '../user-form.svelte';
|
||||
|
||||
let { data } = $props();
|
||||
let user = $state(data);
|
||||
|
||||
const userService = new UserService();
|
||||
const customClaimService = new CustomClaimService();
|
||||
|
||||
async function updateUser(updatedUser: UserCreate) {
|
||||
let success = true;
|
||||
@@ -24,6 +28,15 @@
|
||||
|
||||
return success;
|
||||
}
|
||||
|
||||
async function updateCustomClaims() {
|
||||
await customClaimService
|
||||
.updateUserCustomClaims(user.id, user.customClaims)
|
||||
.then(() => toast.success('Custom claims updated successfully'))
|
||||
.catch((e) => {
|
||||
axiosErrorToast(e);
|
||||
});
|
||||
}
|
||||
</script>
|
||||
|
||||
<svelte:head>
|
||||
@@ -37,10 +50,25 @@
|
||||
</div>
|
||||
<Card.Root>
|
||||
<Card.Header>
|
||||
<Card.Title>{user.firstName} {user.lastName}</Card.Title>
|
||||
<Card.Title>General</Card.Title>
|
||||
</Card.Header>
|
||||
|
||||
<Card.Content>
|
||||
<UserForm existingUser={user} callback={updateUser} />
|
||||
</Card.Content>
|
||||
</Card.Root>
|
||||
|
||||
<Card.Root>
|
||||
<Card.Header>
|
||||
<Card.Title>Custom Claims</Card.Title>
|
||||
<Card.Description>
|
||||
Custom claims are key-value pairs that can be used to store additional information about a
|
||||
user. These claims will be included in the ID token if the scope "profile" is requested.
|
||||
</Card.Description>
|
||||
</Card.Header>
|
||||
<Card.Content>
|
||||
<CustomClaimsInput bind:customClaims={user.customClaims} />
|
||||
<div class="mt-5 flex justify-end">
|
||||
<Button onclick={updateCustomClaims} type="submit">Save</Button>
|
||||
</div>
|
||||
</Card.Content>
|
||||
</Card.Root>
|
||||
|
||||
@@ -24,7 +24,7 @@ test('Update account details fails with already taken email', async ({ page }) =
|
||||
|
||||
await page.getByRole('button', { name: 'Save' }).click();
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('Email is already taken');
|
||||
await expect(page.getByRole('status')).toHaveText('Email is already in use');
|
||||
});
|
||||
|
||||
test('Update account details fails with already taken username', async ({ page }) => {
|
||||
@@ -34,7 +34,7 @@ test('Update account details fails with already taken username', async ({ page }
|
||||
|
||||
await page.getByRole('button', { name: 'Save' }).click();
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('Username is already taken');
|
||||
await expect(page.getByRole('status')).toHaveText('Username is already in use');
|
||||
});
|
||||
|
||||
test('Add passkey to an account', async ({ page }) => {
|
||||
|
||||
@@ -4,7 +4,7 @@ import { cleanupBackend } from './utils/cleanup.util';
|
||||
|
||||
test.beforeEach(cleanupBackend);
|
||||
|
||||
test('Create user group', async ({ page, baseURL }) => {
|
||||
test('Create user group', async ({ page }) => {
|
||||
await page.goto('/settings/admin/user-groups');
|
||||
const group = userGroups.humanResources;
|
||||
|
||||
@@ -15,8 +15,7 @@ test('Create user group', async ({ page, baseURL }) => {
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('User group created successfully');
|
||||
|
||||
const expectedRoute = new RegExp(`${baseURL}/settings/admin/user-groups/[a-f0-9-]+`);
|
||||
expect(page.url()).toMatch(expectedRoute);
|
||||
await page.waitForURL('/settings/admin/user-groups/*');
|
||||
|
||||
await expect(page.getByLabel('Friendly Name')).toHaveValue(group.friendlyName);
|
||||
await expect(page.getByLabel('Name', { exact: true })).toHaveValue(group.name);
|
||||
@@ -74,3 +73,39 @@ test('Delete user group', async ({ page }) => {
|
||||
await expect(page.getByRole('status')).toHaveText('User group deleted successfully');
|
||||
await expect(page.getByRole('row', { name: group.name })).not.toBeVisible();
|
||||
});
|
||||
|
||||
test('Update user group custom claims', async ({ page }) => {
|
||||
await page.goto(`/settings/admin/user-groups/${userGroups.designers.id}`);
|
||||
|
||||
// Add two custom claims
|
||||
await page.getByRole('button', { name: 'Add custom claim' }).click();
|
||||
|
||||
await page.getByPlaceholder('Key').fill('customClaim1');
|
||||
await page.getByPlaceholder('Value').fill('customClaim1_value');
|
||||
|
||||
await page.getByRole('button', { name: 'Add another' }).click();
|
||||
await page.getByPlaceholder('Key').nth(1).fill('customClaim2');
|
||||
await page.getByPlaceholder('Value').nth(1).fill('customClaim2_value');
|
||||
|
||||
await page.getByRole('button', { name: 'Save' }).nth(2).click();
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('Custom claims updated successfully');
|
||||
|
||||
await page.reload();
|
||||
|
||||
// Check if custom claims are saved
|
||||
await expect(page.getByPlaceholder('Key').first()).toHaveValue('customClaim1');
|
||||
await expect(page.getByPlaceholder('Value').first()).toHaveValue('customClaim1_value');
|
||||
await expect(page.getByPlaceholder('Key').nth(1)).toHaveValue('customClaim2');
|
||||
await expect(page.getByPlaceholder('Value').nth(1)).toHaveValue('customClaim2_value');
|
||||
|
||||
// Remove one custom claim
|
||||
await page.getByLabel('Remove custom claim').first().click();
|
||||
await page.getByRole('button', { name: 'Save' }).nth(2).click();
|
||||
|
||||
await page.reload();
|
||||
|
||||
// Check if custom claim is removed
|
||||
await expect(page.getByPlaceholder('Key').first()).toHaveValue('customClaim2');
|
||||
await expect(page.getByPlaceholder('Value').first()).toHaveValue('customClaim2_value');
|
||||
});
|
||||
|
||||
@@ -32,7 +32,7 @@ test('Create user fails with already taken email', async ({ page }) => {
|
||||
await page.getByLabel('Username').fill(user.username);
|
||||
await page.getByRole('button', { name: 'Save' }).click();
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('Email is already taken');
|
||||
await expect(page.getByRole('status')).toHaveText('Email is already in use');
|
||||
});
|
||||
|
||||
test('Create user fails with already taken username', async ({ page }) => {
|
||||
@@ -47,7 +47,7 @@ test('Create user fails with already taken username', async ({ page }) => {
|
||||
await page.getByLabel('Username').fill(users.tim.username);
|
||||
await page.getByRole('button', { name: 'Save' }).click();
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('Username is already taken');
|
||||
await expect(page.getByRole('status')).toHaveText('Username is already in use');
|
||||
});
|
||||
|
||||
test('Create one time access token', async ({ page }) => {
|
||||
@@ -95,7 +95,7 @@ test('Update user', async ({ page }) => {
|
||||
await page.getByLabel('Last name').fill('Apple');
|
||||
await page.getByLabel('Email').fill('crack.apple@test.com');
|
||||
await page.getByLabel('Username').fill('crack');
|
||||
await page.getByRole('button', { name: 'Save' }).click();
|
||||
await page.getByRole('button', { name: 'Save' }).first().click();
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('User updated successfully');
|
||||
});
|
||||
@@ -112,9 +112,9 @@ test('Update user fails with already taken email', async ({ page }) => {
|
||||
await page.getByRole('menuitem', { name: 'Edit' }).click();
|
||||
|
||||
await page.getByLabel('Email').fill(users.tim.email);
|
||||
await page.getByRole('button', { name: 'Save' }).click();
|
||||
await page.getByRole('button', { name: 'Save' }).first().click();
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('Email is already taken');
|
||||
await expect(page.getByRole('status')).toHaveText('Email is already in use');
|
||||
});
|
||||
|
||||
test('Update user fails with already taken username', async ({ page }) => {
|
||||
@@ -129,7 +129,43 @@ test('Update user fails with already taken username', async ({ page }) => {
|
||||
await page.getByRole('menuitem', { name: 'Edit' }).click();
|
||||
|
||||
await page.getByLabel('Username').fill(users.tim.username);
|
||||
await page.getByRole('button', { name: 'Save' }).click();
|
||||
await page.getByRole('button', { name: 'Save' }).first().click();
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('Username is already taken');
|
||||
await expect(page.getByRole('status')).toHaveText('Username is already in use');
|
||||
});
|
||||
|
||||
test('Update user custom claims', async ({ page }) => {
|
||||
await page.goto(`/settings/admin/users/${users.craig.id}`);
|
||||
|
||||
// Add two custom claims
|
||||
await page.getByRole('button', { name: 'Add custom claim' }).click();
|
||||
|
||||
await page.getByPlaceholder('Key').fill('customClaim1');
|
||||
await page.getByPlaceholder('Value').fill('customClaim1_value');
|
||||
|
||||
await page.getByRole('button', { name: 'Add another' }).click();
|
||||
await page.getByPlaceholder('Key').nth(1).fill('customClaim2');
|
||||
await page.getByPlaceholder('Value').nth(1).fill('customClaim2_value');
|
||||
|
||||
await page.getByRole('button', { name: 'Save' }).nth(1).click();
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('Custom claims updated successfully');
|
||||
|
||||
await page.reload();
|
||||
|
||||
// Check if custom claims are saved
|
||||
await expect(page.getByPlaceholder('Key').first()).toHaveValue('customClaim1');
|
||||
await expect(page.getByPlaceholder('Value').first()).toHaveValue('customClaim1_value');
|
||||
await expect(page.getByPlaceholder('Key').nth(1)).toHaveValue('customClaim2');
|
||||
await expect(page.getByPlaceholder('Value').nth(1)).toHaveValue('customClaim2_value');
|
||||
|
||||
// Remove one custom claim
|
||||
await page.getByLabel('Remove custom claim').first().click();
|
||||
await page.getByRole('button', { name: 'Save' }).nth(1).click();
|
||||
|
||||
await page.reload();
|
||||
|
||||
// Check if custom claim is removed
|
||||
await expect(page.getByPlaceholder('Key').first()).toHaveValue('customClaim2');
|
||||
await expect(page.getByPlaceholder('Value').first()).toHaveValue('customClaim2_value');
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user