mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-09 23:02:56 +03:00
261 lines
6.4 KiB
Go
261 lines
6.4 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type CustomClaimService struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewCustomClaimService(db *gorm.DB) *CustomClaimService {
|
|
return &CustomClaimService{db: db}
|
|
}
|
|
|
|
// isReservedClaim checks if a claim key is reserved e.g. email, preferred_username
|
|
func isReservedClaim(key string) bool {
|
|
switch key {
|
|
case "given_name",
|
|
"family_name",
|
|
"name",
|
|
"email",
|
|
"preferred_username",
|
|
"display_name",
|
|
"groups",
|
|
TokenTypeClaim,
|
|
"sub",
|
|
"iss",
|
|
"aud",
|
|
"exp",
|
|
"iat",
|
|
"auth_time",
|
|
"nonce",
|
|
"acr",
|
|
"amr",
|
|
"azp",
|
|
"nbf",
|
|
"jti":
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// idType is the name of the column that identifies the owner of a custom
// claim. UserID selects claims owned by a single user; UserGroupID selects
// claims owned by a user group.
type idType string

// Owner columns supported by the custom claim queries.
const (
	UserID      idType = "user_id"
	UserGroupID idType = "user_group_id"
)
|
|
|
|
// UpdateCustomClaimsForUser updates the custom claims for a user
|
|
func (s *CustomClaimService) UpdateCustomClaimsForUser(ctx context.Context, userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
|
tx := s.db.Begin()
|
|
defer func() {
|
|
tx.Rollback()
|
|
}()
|
|
|
|
updatedClaims, err := s.updateCustomClaimsInternal(ctx, UserID, userID, claims, tx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return updatedClaims, nil
|
|
}
|
|
|
|
// UpdateCustomClaimsForUserGroup updates the custom claims for a user group
|
|
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(ctx context.Context, userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
|
tx := s.db.Begin()
|
|
defer func() {
|
|
tx.Rollback()
|
|
}()
|
|
|
|
updatedClaims, err := s.updateCustomClaimsInternal(ctx, UserGroupID, userGroupID, claims, tx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return updatedClaims, nil
|
|
}
|
|
|
|
// updateCustomClaimsInternal updates the custom claims for a user or user group within a transaction
|
|
func (s *CustomClaimService) updateCustomClaimsInternal(ctx context.Context, idType idType, value string, claims []dto.CustomClaimCreateDto, tx *gorm.DB) ([]model.CustomClaim, error) {
|
|
// Check for duplicate keys in the claims slice
|
|
seenKeys := make(map[string]struct{})
|
|
for _, claim := range claims {
|
|
if _, ok := seenKeys[claim.Key]; ok {
|
|
return nil, &common.DuplicateClaimError{Key: claim.Key}
|
|
}
|
|
seenKeys[claim.Key] = struct{}{}
|
|
}
|
|
|
|
var existingClaims []model.CustomClaim
|
|
err := tx.
|
|
WithContext(ctx).
|
|
Where(string(idType), value).
|
|
Find(&existingClaims).
|
|
Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Delete claims that are not in the new list
|
|
for _, existingClaim := range existingClaims {
|
|
found := false
|
|
for _, claim := range claims {
|
|
if claim.Key == existingClaim.Key {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
err = tx.
|
|
WithContext(ctx).
|
|
Delete(&existingClaim).
|
|
Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add or update claims
|
|
for _, claim := range claims {
|
|
if isReservedClaim(claim.Key) {
|
|
return nil, &common.ReservedClaimError{Key: claim.Key}
|
|
}
|
|
customClaim := model.CustomClaim{
|
|
Key: claim.Key,
|
|
Value: claim.Value,
|
|
}
|
|
|
|
switch idType {
|
|
case UserID:
|
|
customClaim.UserID = &value
|
|
case UserGroupID:
|
|
customClaim.UserGroupID = &value
|
|
}
|
|
|
|
// Update the claim if it already exists or create a new one
|
|
err = tx.
|
|
WithContext(ctx).
|
|
Where(string(idType)+" = ? AND key = ?", value, claim.Key).
|
|
Assign(&customClaim).
|
|
FirstOrCreate(&model.CustomClaim{}).
|
|
Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Get the updated claims
|
|
var updatedClaims []model.CustomClaim
|
|
err = tx.
|
|
WithContext(ctx).
|
|
Where(string(idType)+" = ?", value).
|
|
Find(&updatedClaims).
|
|
Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return updatedClaims, nil
|
|
}
|
|
|
|
func (s *CustomClaimService) GetCustomClaimsForUser(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
|
|
var customClaims []model.CustomClaim
|
|
err := tx.
|
|
WithContext(ctx).
|
|
Where("user_id = ?", userID).
|
|
Find(&customClaims).
|
|
Error
|
|
return customClaims, err
|
|
}
|
|
|
|
func (s *CustomClaimService) GetCustomClaimsForUserGroup(ctx context.Context, userGroupID string, tx *gorm.DB) ([]model.CustomClaim, error) {
|
|
var customClaims []model.CustomClaim
|
|
err := tx.
|
|
WithContext(ctx).
|
|
Where("user_group_id = ?", userGroupID).
|
|
Find(&customClaims).
|
|
Error
|
|
return customClaims, err
|
|
}
|
|
|
|
// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of,
|
|
// prioritizing the user's claims over user group claims with the same key.
|
|
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
|
|
// Get the custom claims of the user
|
|
customClaims, err := s.GetCustomClaimsForUser(ctx, userID, tx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Store user's claims in a map to prioritize and prevent duplicates
|
|
claimsMap := make(map[string]model.CustomClaim)
|
|
for _, claim := range customClaims {
|
|
claimsMap[claim.Key] = claim
|
|
}
|
|
|
|
// Get all user groups of the user
|
|
var userGroupsOfUser []model.UserGroup
|
|
err = tx.
|
|
WithContext(ctx).
|
|
Preload("CustomClaims").
|
|
Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id").
|
|
Where("user_groups_users.user_id = ?", userID).
|
|
Find(&userGroupsOfUser).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Add only non-duplicate custom claims from user groups
|
|
for _, userGroup := range userGroupsOfUser {
|
|
for _, groupClaim := range userGroup.CustomClaims {
|
|
// Only add claim if it does not exist in the user's claims
|
|
if _, exists := claimsMap[groupClaim.Key]; !exists {
|
|
claimsMap[groupClaim.Key] = groupClaim
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert the claimsMap back to a slice
|
|
finalClaims := make([]model.CustomClaim, 0, len(claimsMap))
|
|
for _, claim := range claimsMap {
|
|
finalClaims = append(finalClaims, claim)
|
|
}
|
|
|
|
return finalClaims, nil
|
|
}
|
|
|
|
// GetSuggestions returns a list of custom claim keys that have been used before
|
|
func (s *CustomClaimService) GetSuggestions(ctx context.Context) ([]string, error) {
|
|
var customClaimsKeys []string
|
|
|
|
err := s.db.
|
|
WithContext(ctx).
|
|
Model(&model.CustomClaim{}).
|
|
Group("key").
|
|
Order("COUNT(*) DESC").
|
|
Pluck("key", &customClaimsKeys).Error
|
|
|
|
return customClaimsKeys, err
|
|
}
|