diff --git a/backend/internal/controller/user_controller.go b/backend/internal/controller/user_controller.go index bc9bdf00..66a57595 100644 --- a/backend/internal/controller/user_controller.go +++ b/backend/internal/controller/user_controller.go @@ -27,9 +27,12 @@ func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt group.GET("/users/:id", jwtAuthMiddleware.Add(true), uc.getUserHandler) group.POST("/users", jwtAuthMiddleware.Add(true), uc.createUserHandler) group.PUT("/users/:id", jwtAuthMiddleware.Add(true), uc.updateUserHandler) + group.GET("/users/:id/groups", jwtAuthMiddleware.Add(true), uc.getUserGroupsHandler) group.PUT("/users/me", jwtAuthMiddleware.Add(false), uc.updateCurrentUserHandler) group.DELETE("/users/:id", jwtAuthMiddleware.Add(true), uc.deleteUserHandler) + group.PUT("/users/:id/user-groups", jwtAuthMiddleware.Add(true), uc.updateUserGroups) + group.GET("/users/:id/profile-picture.png", uc.getUserProfilePictureHandler) group.GET("/users/me/profile-picture.png", jwtAuthMiddleware.Add(false), uc.getCurrentUserProfilePictureHandler) group.PUT("/users/:id/profile-picture", jwtAuthMiddleware.Add(true), uc.updateUserProfilePictureHandler) @@ -46,6 +49,23 @@ type UserController struct { appConfigService *service.AppConfigService } +func (uc *UserController) getUserGroupsHandler(c *gin.Context) { + userID := c.Param("id") + groups, err := uc.userService.GetUserGroups(userID) + if err != nil { + c.Error(err) + return + } + + var groupsDto []dto.UserGroupDtoWithUsers + if err := dto.MapStructList(groups, &groupsDto); err != nil { + c.Error(err) + return + } + + c.JSON(http.StatusOK, groupsDto) +} + func (uc *UserController) listUsersHandler(c *gin.Context) { searchTerm := c.Query("search") var sortedPaginationRequest utils.SortedPaginationRequest @@ -315,3 +335,25 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { c.JSON(http.StatusOK, userDto) } + +func (uc *UserController) updateUserGroups(c *gin.Context) { + var input dto.UserUpdateUserGroupDto + if err := c.ShouldBindJSON(&input); err != nil { + c.Error(err) + return + } + + user, err := uc.userService.UpdateUserGroups(c.Param("id"), input.UserGroupIds) + if err != nil { + c.Error(err) + return + } + + var userDto dto.UserDto + if err := dto.MapStruct(user, &userDto); err != nil { + c.Error(err) + return + } + + c.JSON(http.StatusOK, userDto) +} diff --git a/backend/internal/controller/user_group_controller.go b/backend/internal/controller/user_group_controller.go index 1e51e9f5..1c0dcbe5 100644 --- a/backend/internal/controller/user_group_controller.go +++ b/backend/internal/controller/user_group_controller.go @@ -139,7 +139,7 @@ func (ugc *UserGroupController) updateUsers(c *gin.Context) { return } - group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input) + group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input.UserIDs) if err != nil { c.Error(err) return diff --git a/backend/internal/dto/user_dto.go b/backend/internal/dto/user_dto.go index 930766dc..e60cde1a 100644 --- a/backend/internal/dto/user_dto.go +++ b/backend/internal/dto/user_dto.go @@ -10,6 +10,7 @@ type UserDto struct { LastName string `json:"lastName"` IsAdmin bool `json:"isAdmin"` CustomClaims []CustomClaimDto `json:"customClaims"` + UserGroups []UserGroupDto `json:"userGroups"` LdapID *string `json:"ldapId"` } @@ -31,3 +32,7 @@ type OneTimeAccessEmailDto struct { Email string `json:"email" binding:"required,email"` RedirectPath string `json:"redirectPath"` } + +type UserUpdateUserGroupDto struct { + UserGroupIds []string `json:"userGroupIds" binding:"required"` +} diff --git a/backend/internal/dto/user_group_dto.go b/backend/internal/dto/user_group_dto.go index 13857cfb..e05c03a3 100644 --- a/backend/internal/dto/user_group_dto.go +++ b/backend/internal/dto/user_group_dto.go @@ -4,6 +4,15 @@ import ( datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" ) +type UserGroupDto struct { + ID string `json:"id"` + FriendlyName string `json:"friendlyName"` + Name string `json:"name"` + CustomClaims []CustomClaimDto `json:"customClaims"` + LdapID *string `json:"ldapId"` + CreatedAt datatype.DateTime `json:"createdAt"` +} + type UserGroupDtoWithUsers struct { ID string `json:"id"` FriendlyName string `json:"friendlyName"` diff --git a/backend/internal/service/ldap_service.go b/backend/internal/service/ldap_service.go index 963964d1..b92a02bc 100644 --- a/backend/internal/service/ldap_service.go +++ b/backend/internal/service/ldap_service.go @@ -132,22 +132,18 @@ func (s *LdapService) SyncGroups() error { LdapID: value.GetAttributeValue(uniqueIdentifierAttribute), } - usersToAddDto := dto.UserGroupUpdateUsersDto{ - UserIDs: membersUserId, - } - if databaseGroup.ID == "" { newGroup, err := s.groupService.Create(syncGroup) if err != nil { log.Printf("Error syncing group %s: %s", syncGroup.Name, err) } else { - if _, err = s.groupService.UpdateUsers(newGroup.ID, usersToAddDto); err != nil { + if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil { log.Printf("Error syncing group %s: %s", syncGroup.Name, err) } } } else { _, err = s.groupService.Update(databaseGroup.ID, syncGroup, true) - _, err = s.groupService.UpdateUsers(databaseGroup.ID, usersToAddDto) + _, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId) if err != nil { log.Printf("Error syncing group %s: %s", syncGroup.Name, err) return err diff --git a/backend/internal/service/user_group_service.go b/backend/internal/service/user_group_service.go index ebbf3e35..2bfe8e32 100644 --- a/backend/internal/service/user_group_service.go +++ b/backend/internal/service/user_group_service.go @@ -103,16 +103,16 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow return group, nil } -func (s *UserGroupService) UpdateUsers(id string, input dto.UserGroupUpdateUsersDto) (group model.UserGroup, err error) { +func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model.UserGroup, err error) { group, err = s.Get(id) if err != nil { return model.UserGroup{}, err } - // Fetch the users based on UserIDs in input + // Fetch the users based on the userIds var users []model.User - if len(input.UserIDs) > 0 { - if err := s.db.Where("id IN (?)", input.UserIDs).Find(&users).Error; err != nil { + if len(userIds) > 0 { + if err := s.db.Where("id IN (?)", userIds).Find(&users).Error; err != nil { return model.UserGroup{}, err } } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 077ab48d..ab044af6 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -3,8 +3,6 @@ package service import ( "errors" "fmt" - "github.com/google/uuid" - "github.com/pocket-id/pocket-id/backend/internal/utils/image" "io" "log" "net/url" @@ -12,6 +10,9 @@ import ( "strings" "time" + "github.com/google/uuid" + profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image" + "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/model" @@ -48,7 +49,7 @@ func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils func (s *UserService) GetUser(userID string) (model.User, error) { var user model.User - err := s.db.Preload("CustomClaims").Where("id = ?", userID).First(&user).Error + err := s.db.Preload("UserGroups").Preload("CustomClaims").Where("id = ?", userID).First(&user).Error return user, err } @@ -83,6 +84,14 @@ func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error) return defaultPicture, int64(defaultPicture.Len()), nil } +func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) { + var user model.User + if err := s.db.Preload("UserGroups").Where("id = ?", userID).First(&user).Error; err != nil { + return nil, err + } + return user.UserGroups, nil +} + func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error { // Validate the user ID to prevent directory traversal if err := uuid.Validate(userID); err != nil { @@ -269,6 +278,33 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAg return oneTimeAccessToken.User, accessToken, nil } +func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user model.User, err error) { + user, err = s.GetUser(id) + if err != nil { + return model.User{}, err + } + + // Fetch the groups based on userGroupIds + var groups []model.UserGroup + if len(userGroupIds) > 0 { + if err := s.db.Where("id IN (?)", userGroupIds).Find(&groups).Error; err != nil { + return model.User{}, err + } + } + + // Replace the current groups with the new set of groups + if err := s.db.Model(&user).Association("UserGroups").Replace(groups); err != nil { + return model.User{}, err + } + + // Save the updated user + if err := s.db.Save(&user).Error; err != nil { + return model.User{}, err + } + + return user, nil +} + func (s *UserService) SetupInitialAdmin() (model.User, string, error) { var userCount int64 if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil { diff --git a/frontend/src/routes/settings/admin/oidc-clients/user-group-selection.svelte b/frontend/src/lib/components/user-group-selection.svelte similarity index 94% rename from frontend/src/routes/settings/admin/oidc-clients/user-group-selection.svelte rename to frontend/src/lib/components/user-group-selection.svelte index f4d37be8..0a68a88c 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/user-group-selection.svelte +++ b/frontend/src/lib/components/user-group-selection.svelte @@ -2,7 +2,6 @@ import AdvancedTable from '$lib/components/advanced-table.svelte'; import * as Table from '$lib/components/ui/table'; import UserGroupService from '$lib/services/user-group-service'; - import type { OidcClient } from '$lib/types/oidc.type'; import type { Paginated } from '$lib/types/pagination.type'; import type { UserGroup } from '$lib/types/user-group.type'; diff --git a/frontend/src/lib/services/user-service.ts b/frontend/src/lib/services/user-service.ts index d946340a..735c565a 100644 --- a/frontend/src/lib/services/user-service.ts +++ b/frontend/src/lib/services/user-service.ts @@ -1,4 +1,5 @@ import type { Paginated, SearchPaginationSortRequest } from '$lib/types/pagination.type'; +import type { UserGroup } from '$lib/types/user-group.type'; import type { User, UserCreate } from '$lib/types/user.type'; import APIService from './api-service'; @@ -25,6 +26,11 @@ export default class UserService extends APIService { return res.data as User; } + async getUserGroups(userId: string) { + const res = await this.api.get(`/users/${userId}/groups`); + return res.data as UserGroup[]; + } + async update(id: string, user: UserCreate) { const res = await this.api.put(`/users/${id}`, user); return res.data as User; @@ -69,4 +75,9 @@ export default class UserService extends APIService { async requestOneTimeAccessEmail(email: string, redirectPath?: string) { await this.api.post('/one-time-access-email', { email, redirectPath }); } + + async updateUserGroups(id: string, userGroupIds: string[]) { + const res = await this.api.put(`/users/${id}/user-groups`, { userGroupIds }); + return res.data as User; + } } diff --git a/frontend/src/lib/types/user.type.ts b/frontend/src/lib/types/user.type.ts index 5366e52c..05d2137e 100644 --- a/frontend/src/lib/types/user.type.ts +++ b/frontend/src/lib/types/user.type.ts @@ -1,4 +1,5 @@ import type { CustomClaim } from './custom-claim.type'; +import type { UserGroup } from './user-group.type'; export type User = { id: string; @@ -7,6 +8,7 @@ export type User = { firstName: string; lastName: string; isAdmin: boolean; + userGroups: UserGroup[]; customClaims: CustomClaim[]; ldapId?: string; }; diff --git a/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte b/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte index 19c97848..001c3434 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte +++ b/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte @@ -1,11 +1,13 @@