mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-08 06:12:56 +03:00
198 lines
5.7 KiB
Go
198 lines
5.7 KiB
Go
package service
|
|
|
|
import (
|
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
|
"github.com/stonith404/pocket-id/backend/internal/dto"
|
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// reservedClaims is the set of claim keys that must not be used as custom
// claim keys. It is consulted through isReservedClaim; only key membership
// matters, so the values are empty structs.
var reservedClaims = map[string]struct{}{
	"given_name":         struct{}{},
	"family_name":        struct{}{},
	"name":               struct{}{},
	"email":              struct{}{},
	"preferred_username": struct{}{},
	"groups":             struct{}{},

	"sub":       struct{}{},
	"iss":       struct{}{},
	"aud":       struct{}{},
	"exp":       struct{}{},
	"iat":       struct{}{},
	"auth_time": struct{}{},
	"nonce":     struct{}{},
	"acr":       struct{}{},
	"amr":       struct{}{},
	"azp":       struct{}{},
	"nbf":       struct{}{},
	"jti":       struct{}{},
}
|
|
|
|
type CustomClaimService struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewCustomClaimService(db *gorm.DB) *CustomClaimService {
|
|
return &CustomClaimService{db: db}
|
|
}
|
|
|
|
// isReservedClaim checks if a claim key is reserved e.g. email, preferred_username
|
|
func isReservedClaim(key string) bool {
|
|
_, ok := reservedClaims[key]
|
|
return ok
|
|
}
|
|
|
|
// idType names the column that identifies the owner of a custom claim:
// either a user or a user group. Its string value is used directly as the
// column name in queries.
type idType string

const (
	// UserID selects claims owned by a user.
	UserID = idType("user_id")
	// UserGroupID selects claims owned by a user group.
	UserGroupID = idType("user_group_id")
)
|
|
|
|
// UpdateCustomClaimsForUser updates the custom claims for a user
|
|
func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
|
return s.updateCustomClaims(UserID, userID, claims)
|
|
}
|
|
|
|
// UpdateCustomClaimsForUserGroup updates the custom claims for a user group
|
|
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
|
return s.updateCustomClaims(UserGroupID, userGroupID, claims)
|
|
}
|
|
|
|
// updateCustomClaims updates the custom claims for a user or user group
|
|
func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
|
// Check for duplicate keys in the claims slice
|
|
seenKeys := make(map[string]bool)
|
|
for _, claim := range claims {
|
|
if seenKeys[claim.Key] {
|
|
return nil, &common.DuplicateClaimError{Key: claim.Key}
|
|
}
|
|
seenKeys[claim.Key] = true
|
|
}
|
|
|
|
var existingClaims []model.CustomClaim
|
|
err := s.db.Where(string(idType), value).Find(&existingClaims).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Delete claims that are not in the new list
|
|
for _, existingClaim := range existingClaims {
|
|
found := false
|
|
for _, claim := range claims {
|
|
if claim.Key == existingClaim.Key {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
err = s.db.Delete(&existingClaim).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add or update claims
|
|
for _, claim := range claims {
|
|
if isReservedClaim(claim.Key) {
|
|
return nil, &common.ReservedClaimError{Key: claim.Key}
|
|
}
|
|
customClaim := model.CustomClaim{
|
|
Key: claim.Key,
|
|
Value: claim.Value,
|
|
}
|
|
|
|
if idType == UserID {
|
|
customClaim.UserID = &value
|
|
} else if idType == UserGroupID {
|
|
customClaim.UserGroupID = &value
|
|
}
|
|
|
|
// Update the claim if it already exists or create a new one
|
|
err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Get the updated claims
|
|
var updatedClaims []model.CustomClaim
|
|
err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return updatedClaims, nil
|
|
}
|
|
|
|
func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.CustomClaim, error) {
|
|
var customClaims []model.CustomClaim
|
|
err := s.db.Where("user_id = ?", userID).Find(&customClaims).Error
|
|
return customClaims, err
|
|
}
|
|
|
|
func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) {
|
|
var customClaims []model.CustomClaim
|
|
err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error
|
|
return customClaims, err
|
|
}
|
|
|
|
// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of,
|
|
// prioritizing the user's claims over user group claims with the same key.
|
|
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) {
|
|
// Get the custom claims of the user
|
|
customClaims, err := s.GetCustomClaimsForUser(userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Store user's claims in a map to prioritize and prevent duplicates
|
|
claimsMap := make(map[string]model.CustomClaim)
|
|
for _, claim := range customClaims {
|
|
claimsMap[claim.Key] = claim
|
|
}
|
|
|
|
// Get all user groups of the user
|
|
var userGroupsOfUser []model.UserGroup
|
|
err = s.db.Preload("CustomClaims").
|
|
Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id").
|
|
Where("user_groups_users.user_id = ?", userID).
|
|
Find(&userGroupsOfUser).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Add only non-duplicate custom claims from user groups
|
|
for _, userGroup := range userGroupsOfUser {
|
|
for _, groupClaim := range userGroup.CustomClaims {
|
|
// Only add claim if it does not exist in the user's claims
|
|
if _, exists := claimsMap[groupClaim.Key]; !exists {
|
|
claimsMap[groupClaim.Key] = groupClaim
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert the claimsMap back to a slice
|
|
finalClaims := make([]model.CustomClaim, 0, len(claimsMap))
|
|
for _, claim := range claimsMap {
|
|
finalClaims = append(finalClaims, claim)
|
|
}
|
|
|
|
return finalClaims, nil
|
|
}
|
|
|
|
// GetSuggestions returns a list of custom claim keys that have been used before
|
|
func (s *CustomClaimService) GetSuggestions() ([]string, error) {
|
|
var customClaimsKeys []string
|
|
|
|
err := s.db.Model(&model.CustomClaim{}).
|
|
Group("key").
|
|
Order("COUNT(*) DESC").
|
|
Pluck("key", &customClaimsKeys).Error
|
|
|
|
return customClaimsKeys, err
|
|
}
|