mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-21 01:11:33 +03:00
feat: add various improvements to the table component (#961)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
@@ -45,15 +45,11 @@ func NewApiKeyController(group *gin.RouterGroup, authMiddleware *middleware.Auth
|
||||
// @Success 200 {object} dto.Paginated[dto.ApiKeyDto]
|
||||
// @Router /api/api-keys [get]
|
||||
func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
|
||||
listRequestOptions := utils.ParseListRequestOptions(ctx)
|
||||
|
||||
userID := ctx.GetString("userID")
|
||||
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := ctx.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(ctx.Request.Context(), userID, sortedPaginationRequest)
|
||||
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(ctx.Request.Context(), userID, listRequestOptions)
|
||||
if err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
|
||||
@@ -41,18 +41,12 @@ type AuditLogController struct {
|
||||
// @Success 200 {object} dto.Paginated[dto.AuditLogDto]
|
||||
// @Router /api/audit-logs [get]
|
||||
func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
|
||||
err := c.ShouldBindQuery(&sortedPaginationRequest)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
listRequestOptions := utils.ParseListRequestOptions(c)
|
||||
|
||||
userID := c.GetString("userID")
|
||||
|
||||
// Fetch audit logs for the user
|
||||
logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(c.Request.Context(), userID, sortedPaginationRequest)
|
||||
logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(c.Request.Context(), userID, listRequestOptions)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -86,26 +80,12 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
|
||||
// @Param pagination[limit] query int false "Number of items per page" default(20)
|
||||
// @Param sort[column] query string false "Column to sort by"
|
||||
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Param filters[userId] query string false "Filter by user ID"
|
||||
// @Param filters[event] query string false "Filter by event type"
|
||||
// @Param filters[clientName] query string false "Filter by client name"
|
||||
// @Param filters[location] query string false "Filter by location type (external or internal)"
|
||||
// @Success 200 {object} dto.Paginated[dto.AuditLogDto]
|
||||
// @Router /api/audit-logs/all [get]
|
||||
func (alc *AuditLogController) listAllAuditLogsHandler(c *gin.Context) {
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
listRequestOptions := utils.ParseListRequestOptions(c)
|
||||
|
||||
var filters dto.AuditLogFilterDto
|
||||
if err := c.ShouldBindQuery(&filters); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
logs, pagination, err := alc.auditLogService.ListAllAuditLogs(c.Request.Context(), sortedPaginationRequest, filters)
|
||||
logs, pagination, err := alc.auditLogService.ListAllAuditLogs(c.Request.Context(), listRequestOptions)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -403,13 +403,9 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||
// @Router /api/oidc/clients [get]
|
||||
func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
||||
searchTerm := c.Query("search")
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
listRequestOptions := utils.ParseListRequestOptions(c)
|
||||
|
||||
clients, pagination, err := oc.oidcService.ListClients(c.Request.Context(), searchTerm, sortedPaginationRequest)
|
||||
clients, pagination, err := oc.oidcService.ListClients(c.Request.Context(), searchTerm, listRequestOptions)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -685,12 +681,9 @@ func (oc *OidcController) listAuthorizedClientsHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (oc *OidcController) listAuthorizedClients(c *gin.Context, userID string) {
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
authorizedClients, pagination, err := oc.oidcService.ListAuthorizedClients(c.Request.Context(), userID, sortedPaginationRequest)
|
||||
listRequestOptions := utils.ParseListRequestOptions(c)
|
||||
|
||||
authorizedClients, pagination, err := oc.oidcService.ListAuthorizedClients(c.Request.Context(), userID, listRequestOptions)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -741,15 +734,11 @@ func (oc *OidcController) revokeOwnClientAuthorizationHandler(c *gin.Context) {
|
||||
// @Success 200 {object} dto.Paginated[dto.AccessibleOidcClientDto]
|
||||
// @Router /api/oidc/users/me/clients [get]
|
||||
func (oc *OidcController) listOwnAccessibleClientsHandler(c *gin.Context) {
|
||||
listRequestOptions := utils.ParseListRequestOptions(c)
|
||||
|
||||
userID := c.GetString("userID")
|
||||
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
clients, pagination, err := oc.oidcService.ListAccessibleOidcClients(c.Request.Context(), userID, sortedPaginationRequest)
|
||||
clients, pagination, err := oc.oidcService.ListAccessibleOidcClients(c.Request.Context(), userID, listRequestOptions)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -104,13 +104,9 @@ func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
|
||||
// @Router /api/users [get]
|
||||
func (uc *UserController) listUsersHandler(c *gin.Context) {
|
||||
searchTerm := c.Query("search")
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
listRequestOptions := utils.ParseListRequestOptions(c)
|
||||
|
||||
users, pagination, err := uc.userService.ListUsers(c.Request.Context(), searchTerm, sortedPaginationRequest)
|
||||
users, pagination, err := uc.userService.ListUsers(c.Request.Context(), searchTerm, listRequestOptions)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -574,13 +570,9 @@ func (uc *UserController) createSignupTokenHandler(c *gin.Context) {
|
||||
// @Success 200 {object} dto.Paginated[dto.SignupTokenDto]
|
||||
// @Router /api/signup-tokens [get]
|
||||
func (uc *UserController) listSignupTokensHandler(c *gin.Context) {
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
listRequestOptions := utils.ParseListRequestOptions(c)
|
||||
|
||||
tokens, pagination, err := uc.userService.ListSignupTokens(c.Request.Context(), sortedPaginationRequest)
|
||||
tokens, pagination, err := uc.userService.ListSignupTokens(c.Request.Context(), listRequestOptions)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -47,16 +47,10 @@ type UserGroupController struct {
|
||||
// @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount]
|
||||
// @Router /api/user-groups [get]
|
||||
func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
searchTerm := c.Query("search")
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
listRequestOptions := utils.ParseListRequestOptions(c)
|
||||
|
||||
groups, pagination, err := ugc.UserGroupService.List(ctx, searchTerm, sortedPaginationRequest)
|
||||
groups, pagination, err := ugc.UserGroupService.List(c, searchTerm, listRequestOptions)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -70,7 +64,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(ctx, group.ID)
|
||||
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(c.Request.Context(), group.ID)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -17,10 +17,3 @@ type AuditLogDto struct {
|
||||
Username string `json:"username"`
|
||||
Data map[string]string `json:"data"`
|
||||
}
|
||||
|
||||
type AuditLogFilterDto struct {
|
||||
UserID string `form:"filters[userId]"`
|
||||
Event string `form:"filters[event]"`
|
||||
ClientName string `form:"filters[clientName]"`
|
||||
Location string `form:"filters[location]"`
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
type AuditLog struct {
|
||||
Base
|
||||
|
||||
Event AuditLogEvent `sortable:"true"`
|
||||
Event AuditLogEvent `sortable:"true" filterable:"true"`
|
||||
IpAddress *string `sortable:"true"`
|
||||
Country string `sortable:"true"`
|
||||
City string `sortable:"true"`
|
||||
@@ -17,7 +17,7 @@ type AuditLog struct {
|
||||
Username string `gorm:"-"`
|
||||
Data AuditLogData
|
||||
|
||||
UserID string
|
||||
UserID string `filterable:"true"`
|
||||
User User
|
||||
}
|
||||
|
||||
|
||||
@@ -53,8 +53,8 @@ type OidcClient struct {
|
||||
LogoutCallbackURLs UrlList
|
||||
ImageType *string
|
||||
IsPublic bool
|
||||
PkceEnabled bool
|
||||
RequiresReauthentication bool
|
||||
PkceEnabled bool `filterable:"true"`
|
||||
RequiresReauthentication bool `filterable:"true"`
|
||||
Credentials OidcClientCredentials
|
||||
LaunchURL *string
|
||||
|
||||
|
||||
@@ -18,10 +18,10 @@ type User struct {
|
||||
FirstName string `sortable:"true"`
|
||||
LastName string `sortable:"true"`
|
||||
DisplayName string `sortable:"true"`
|
||||
IsAdmin bool `sortable:"true"`
|
||||
IsAdmin bool `sortable:"true" filterable:"true"`
|
||||
Locale *string
|
||||
LdapID *string
|
||||
Disabled bool `sortable:"true"`
|
||||
Disabled bool `sortable:"true" filterable:"true"`
|
||||
|
||||
CustomClaims []CustomClaim
|
||||
UserGroups []UserGroup `gorm:"many2many:user_groups_users;"`
|
||||
|
||||
@@ -25,14 +25,14 @@ func NewApiKeyService(db *gorm.DB, emailService *EmailService) *ApiKeyService {
|
||||
return &ApiKeyService{db: db, emailService: emailService}
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) ListApiKeys(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
|
||||
func (s *ApiKeyService) ListApiKeys(ctx context.Context, userID string, listRequestOptions utils.ListRequestOptions) ([]model.ApiKey, utils.PaginationResponse, error) {
|
||||
query := s.db.
|
||||
WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Model(&model.ApiKey{})
|
||||
|
||||
var apiKeys []model.ApiKey
|
||||
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &apiKeys)
|
||||
pagination, err := utils.PaginateFilterAndSort(listRequestOptions, query, &apiKeys)
|
||||
if err != nil {
|
||||
return nil, utils.PaginationResponse{}, err
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"log/slog"
|
||||
|
||||
userAgentParser "github.com/mileusna/useragent"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
|
||||
@@ -136,14 +135,14 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddres
|
||||
}
|
||||
|
||||
// ListAuditLogsForUser retrieves all audit logs for a given user ID
|
||||
func (s *AuditLogService) ListAuditLogsForUser(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) {
|
||||
func (s *AuditLogService) ListAuditLogsForUser(ctx context.Context, userID string, listRequestOptions utils.ListRequestOptions) ([]model.AuditLog, utils.PaginationResponse, error) {
|
||||
var logs []model.AuditLog
|
||||
query := s.db.
|
||||
WithContext(ctx).
|
||||
Model(&model.AuditLog{}).
|
||||
Where("user_id = ?", userID)
|
||||
|
||||
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs)
|
||||
pagination, err := utils.PaginateFilterAndSort(listRequestOptions, query, &logs)
|
||||
return logs, pagination, err
|
||||
}
|
||||
|
||||
@@ -152,7 +151,7 @@ func (s *AuditLogService) DeviceStringFromUserAgent(userAgent string) string {
|
||||
return ua.Name + " on " + ua.OS + " " + ua.OSVersion
|
||||
}
|
||||
|
||||
func (s *AuditLogService) ListAllAuditLogs(ctx context.Context, sortedPaginationRequest utils.SortedPaginationRequest, filters dto.AuditLogFilterDto) ([]model.AuditLog, utils.PaginationResponse, error) {
|
||||
func (s *AuditLogService) ListAllAuditLogs(ctx context.Context, listRequestOptions utils.ListRequestOptions) ([]model.AuditLog, utils.PaginationResponse, error) {
|
||||
var logs []model.AuditLog
|
||||
|
||||
query := s.db.
|
||||
@@ -160,33 +159,36 @@ func (s *AuditLogService) ListAllAuditLogs(ctx context.Context, sortedPagination
|
||||
Preload("User").
|
||||
Model(&model.AuditLog{})
|
||||
|
||||
if filters.UserID != "" {
|
||||
query = query.Where("user_id = ?", filters.UserID)
|
||||
}
|
||||
if filters.Event != "" {
|
||||
query = query.Where("event = ?", filters.Event)
|
||||
}
|
||||
if filters.ClientName != "" {
|
||||
if clientName, ok := listRequestOptions.Filters["clientName"]; ok {
|
||||
dialect := s.db.Name()
|
||||
switch dialect {
|
||||
case "sqlite":
|
||||
query = query.Where("json_extract(data, '$.clientName') = ?", filters.ClientName)
|
||||
query = query.Where("json_extract(data, '$.clientName') IN ?", clientName)
|
||||
case "postgres":
|
||||
query = query.Where("data->>'clientName' = ?", filters.ClientName)
|
||||
query = query.Where("data->>'clientName' IN ?", clientName)
|
||||
default:
|
||||
return nil, utils.PaginationResponse{}, fmt.Errorf("unsupported database dialect: %s", dialect)
|
||||
}
|
||||
}
|
||||
if filters.Location != "" {
|
||||
switch filters.Location {
|
||||
case "external":
|
||||
query = query.Where("country != 'Internal Network'")
|
||||
case "internal":
|
||||
query = query.Where("country = 'Internal Network'")
|
||||
|
||||
if locations, ok := listRequestOptions.Filters["location"]; ok {
|
||||
mapped := make([]string, 0, len(locations))
|
||||
for _, v := range locations {
|
||||
if s, ok := v.(string); ok {
|
||||
switch s {
|
||||
case "internal":
|
||||
mapped = append(mapped, "Internal Network")
|
||||
case "external":
|
||||
mapped = append(mapped, "External Network")
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(mapped) > 0 {
|
||||
query = query.Where("country IN ?", mapped)
|
||||
}
|
||||
}
|
||||
|
||||
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs)
|
||||
pagination, err := utils.PaginateFilterAndSort(listRequestOptions, query, &logs)
|
||||
if err != nil {
|
||||
return nil, pagination, err
|
||||
}
|
||||
|
||||
@@ -692,7 +692,7 @@ func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) ListClients(ctx context.Context, name string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.OidcClient, utils.PaginationResponse, error) {
|
||||
func (s *OidcService) ListClients(ctx context.Context, name string, listRequestOptions utils.ListRequestOptions) ([]model.OidcClient, utils.PaginationResponse, error) {
|
||||
var clients []model.OidcClient
|
||||
|
||||
query := s.db.
|
||||
@@ -705,17 +705,17 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
|
||||
}
|
||||
|
||||
// As allowedUserGroupsCount is not a column, we need to manually sort it
|
||||
if sortedPaginationRequest.Sort.Column == "allowedUserGroupsCount" && utils.IsValidSortDirection(sortedPaginationRequest.Sort.Direction) {
|
||||
if listRequestOptions.Sort.Column == "allowedUserGroupsCount" && utils.IsValidSortDirection(listRequestOptions.Sort.Direction) {
|
||||
query = query.Select("oidc_clients.*, COUNT(oidc_clients_allowed_user_groups.oidc_client_id)").
|
||||
Joins("LEFT JOIN oidc_clients_allowed_user_groups ON oidc_clients.id = oidc_clients_allowed_user_groups.oidc_client_id").
|
||||
Group("oidc_clients.id").
|
||||
Order("COUNT(oidc_clients_allowed_user_groups.oidc_client_id) " + sortedPaginationRequest.Sort.Direction)
|
||||
Order("COUNT(oidc_clients_allowed_user_groups.oidc_client_id) " + listRequestOptions.Sort.Direction)
|
||||
|
||||
response, err := utils.Paginate(sortedPaginationRequest.Pagination.Page, sortedPaginationRequest.Pagination.Limit, query, &clients)
|
||||
response, err := utils.Paginate(listRequestOptions.Pagination.Page, listRequestOptions.Pagination.Limit, query, &clients)
|
||||
return clients, response, err
|
||||
}
|
||||
|
||||
response, err := utils.PaginateAndSort(sortedPaginationRequest, query, &clients)
|
||||
response, err := utils.PaginateFilterAndSort(listRequestOptions, query, &clients)
|
||||
return clients, response, err
|
||||
}
|
||||
|
||||
@@ -1350,7 +1350,7 @@ func (s *OidcService) GetAllowedGroupsCountOfClient(ctx context.Context, id stri
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) ListAuthorizedClients(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.UserAuthorizedOidcClient, utils.PaginationResponse, error) {
|
||||
func (s *OidcService) ListAuthorizedClients(ctx context.Context, userID string, listRequestOptions utils.ListRequestOptions) ([]model.UserAuthorizedOidcClient, utils.PaginationResponse, error) {
|
||||
|
||||
query := s.db.
|
||||
WithContext(ctx).
|
||||
@@ -1359,7 +1359,7 @@ func (s *OidcService) ListAuthorizedClients(ctx context.Context, userID string,
|
||||
Where("user_id = ?", userID)
|
||||
|
||||
var authorizedClients []model.UserAuthorizedOidcClient
|
||||
response, err := utils.PaginateAndSort(sortedPaginationRequest, query, &authorizedClients)
|
||||
response, err := utils.PaginateFilterAndSort(listRequestOptions, query, &authorizedClients)
|
||||
|
||||
return authorizedClients, response, err
|
||||
}
|
||||
@@ -1392,7 +1392,7 @@ func (s *OidcService) RevokeAuthorizedClient(ctx context.Context, userID string,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]dto.AccessibleOidcClientDto, utils.PaginationResponse, error) {
|
||||
func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID string, listRequestOptions utils.ListRequestOptions) ([]dto.AccessibleOidcClientDto, utils.PaginationResponse, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
@@ -1439,13 +1439,13 @@ func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID stri
|
||||
|
||||
// Handle custom sorting for lastUsedAt column
|
||||
var response utils.PaginationResponse
|
||||
if sortedPaginationRequest.Sort.Column == "lastUsedAt" && utils.IsValidSortDirection(sortedPaginationRequest.Sort.Direction) {
|
||||
if listRequestOptions.Sort.Column == "lastUsedAt" && utils.IsValidSortDirection(listRequestOptions.Sort.Direction) {
|
||||
query = query.
|
||||
Joins("LEFT JOIN user_authorized_oidc_clients ON oidc_clients.id = user_authorized_oidc_clients.client_id AND user_authorized_oidc_clients.user_id = ?", userID).
|
||||
Order("user_authorized_oidc_clients.last_used_at " + sortedPaginationRequest.Sort.Direction + " NULLS LAST")
|
||||
Order("user_authorized_oidc_clients.last_used_at " + listRequestOptions.Sort.Direction + " NULLS LAST")
|
||||
}
|
||||
|
||||
response, err = utils.PaginateAndSort(sortedPaginationRequest, query, &clients)
|
||||
response, err = utils.PaginateFilterAndSort(listRequestOptions, query, &clients)
|
||||
if err != nil {
|
||||
return nil, utils.PaginationResponse{}, err
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ func NewUserGroupService(db *gorm.DB, appConfigService *AppConfigService) *UserG
|
||||
return &UserGroupService{db: db, appConfigService: appConfigService}
|
||||
}
|
||||
|
||||
func (s *UserGroupService) List(ctx context.Context, name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
|
||||
func (s *UserGroupService) List(ctx context.Context, name string, listRequestOptions utils.ListRequestOptions) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
|
||||
query := s.db.
|
||||
WithContext(ctx).
|
||||
Preload("CustomClaims").
|
||||
@@ -32,17 +32,14 @@ func (s *UserGroupService) List(ctx context.Context, name string, sortedPaginati
|
||||
}
|
||||
|
||||
// As userCount is not a column we need to manually sort it
|
||||
if sortedPaginationRequest.Sort.Column == "userCount" && utils.IsValidSortDirection(sortedPaginationRequest.Sort.Direction) {
|
||||
if listRequestOptions.Sort.Column == "userCount" && utils.IsValidSortDirection(listRequestOptions.Sort.Direction) {
|
||||
query = query.Select("user_groups.*, COUNT(user_groups_users.user_id)").
|
||||
Joins("LEFT JOIN user_groups_users ON user_groups.id = user_groups_users.user_group_id").
|
||||
Group("user_groups.id").
|
||||
Order("COUNT(user_groups_users.user_id) " + sortedPaginationRequest.Sort.Direction)
|
||||
|
||||
response, err := utils.Paginate(sortedPaginationRequest.Pagination.Page, sortedPaginationRequest.Pagination.Limit, query, &groups)
|
||||
return groups, response, err
|
||||
Order("COUNT(user_groups_users.user_id) " + listRequestOptions.Sort.Direction)
|
||||
}
|
||||
|
||||
response, err = utils.PaginateAndSort(sortedPaginationRequest, query, &groups)
|
||||
response, err = utils.PaginateFilterAndSort(listRequestOptions, query, &groups)
|
||||
return groups, response, err
|
||||
}
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditL
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
|
||||
func (s *UserService) ListUsers(ctx context.Context, searchTerm string, listRequestOptions utils.ListRequestOptions) ([]model.User, utils.PaginationResponse, error) {
|
||||
var users []model.User
|
||||
query := s.db.WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
@@ -60,7 +60,7 @@ func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPa
|
||||
searchPattern, searchPattern, searchPattern, searchPattern)
|
||||
}
|
||||
|
||||
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &users)
|
||||
pagination, err := utils.PaginateFilterAndSort(listRequestOptions, query, &users)
|
||||
|
||||
return users, pagination, err
|
||||
}
|
||||
@@ -794,11 +794,11 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd
|
||||
return user, accessToken, nil
|
||||
}
|
||||
|
||||
func (s *UserService) ListSignupTokens(ctx context.Context, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.SignupToken, utils.PaginationResponse, error) {
|
||||
func (s *UserService) ListSignupTokens(ctx context.Context, listRequestOptions utils.ListRequestOptions) ([]model.SignupToken, utils.PaginationResponse, error) {
|
||||
var tokens []model.SignupToken
|
||||
query := s.db.WithContext(ctx).Model(&model.SignupToken{})
|
||||
|
||||
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &tokens)
|
||||
pagination, err := utils.PaginateFilterAndSort(listRequestOptions, query, &tokens)
|
||||
return tokens, pagination, err
|
||||
}
|
||||
|
||||
|
||||
205
backend/internal/utils/list_request_util.go
Normal file
205
backend/internal/utils/list_request_util.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type PaginationResponse struct {
|
||||
TotalPages int64 `json:"totalPages"`
|
||||
TotalItems int64 `json:"totalItems"`
|
||||
CurrentPage int `json:"currentPage"`
|
||||
ItemsPerPage int `json:"itemsPerPage"`
|
||||
}
|
||||
|
||||
type ListRequestOptions struct {
|
||||
Pagination struct {
|
||||
Page int `form:"pagination[page]"`
|
||||
Limit int `form:"pagination[limit]"`
|
||||
} `form:"pagination"`
|
||||
Sort struct {
|
||||
Column string `form:"sort[column]"`
|
||||
Direction string `form:"sort[direction]"`
|
||||
} `form:"sort"`
|
||||
Filters map[string][]any
|
||||
}
|
||||
|
||||
type FieldMeta struct {
|
||||
ColumnName string
|
||||
IsSortable bool
|
||||
IsFilterable bool
|
||||
}
|
||||
|
||||
func ParseListRequestOptions(ctx *gin.Context) (listRequestOptions ListRequestOptions) {
|
||||
if err := ctx.ShouldBindQuery(&listRequestOptions); err != nil {
|
||||
return listRequestOptions
|
||||
}
|
||||
|
||||
listRequestOptions.Filters = parseNestedFilters(ctx)
|
||||
return listRequestOptions
|
||||
}
|
||||
|
||||
func PaginateFilterAndSort(params ListRequestOptions, query *gorm.DB, result interface{}) (PaginationResponse, error) {
|
||||
meta := extractModelMetadata(result)
|
||||
|
||||
query = applyFilters(params.Filters, query, meta)
|
||||
query = applySorting(params.Sort.Column, params.Sort.Direction, query, meta)
|
||||
|
||||
return Paginate(params.Pagination.Page, params.Pagination.Limit, query, result)
|
||||
}
|
||||
|
||||
func Paginate(page int, pageSize int, query *gorm.DB, result interface{}) (PaginationResponse, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
if pageSize < 1 {
|
||||
pageSize = 20
|
||||
} else if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
|
||||
var totalItems int64
|
||||
if err := query.Count(&totalItems).Error; err != nil {
|
||||
return PaginationResponse{}, err
|
||||
}
|
||||
|
||||
totalPages := (totalItems + int64(pageSize) - 1) / int64(pageSize)
|
||||
if totalItems == 0 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
if int64(page) > totalPages {
|
||||
page = int(totalPages)
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
if err := query.Offset(offset).Limit(pageSize).Find(result).Error; err != nil {
|
||||
return PaginationResponse{}, err
|
||||
}
|
||||
|
||||
return PaginationResponse{
|
||||
TotalPages: totalPages,
|
||||
TotalItems: totalItems,
|
||||
CurrentPage: page,
|
||||
ItemsPerPage: pageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NormalizeSortDirection(direction string) string {
|
||||
d := strings.ToLower(strings.TrimSpace(direction))
|
||||
if d != "asc" && d != "desc" {
|
||||
return "asc"
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func IsValidSortDirection(direction string) bool {
|
||||
d := strings.ToLower(strings.TrimSpace(direction))
|
||||
return d == "asc" || d == "desc"
|
||||
}
|
||||
|
||||
// parseNestedFilters handles ?filters[field][0]=val1&filters[field][1]=val2
|
||||
func parseNestedFilters(ctx *gin.Context) map[string][]any {
|
||||
result := make(map[string][]any)
|
||||
query := ctx.Request.URL.Query()
|
||||
|
||||
for key, values := range query {
|
||||
if !strings.HasPrefix(key, "filters[") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Keys can be "filters[field]" or "filters[field][0]"
|
||||
raw := strings.TrimPrefix(key, "filters[")
|
||||
// Take everything up to the first closing bracket
|
||||
if idx := strings.IndexByte(raw, ']'); idx != -1 {
|
||||
field := raw[:idx]
|
||||
for _, v := range values {
|
||||
result[field] = append(result[field], ConvertStringToType(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// applyFilters applies filtering to the GORM query based on the provided filters
|
||||
func applyFilters(filters map[string][]any, query *gorm.DB, meta map[string]FieldMeta) *gorm.DB {
|
||||
for key, values := range filters {
|
||||
if key == "" || len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldName := CapitalizeFirstLetter(key)
|
||||
fieldMeta, ok := meta[fieldName]
|
||||
if !ok || !fieldMeta.IsFilterable {
|
||||
continue
|
||||
}
|
||||
|
||||
query = query.Where(fieldMeta.ColumnName+" IN ?", values)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// applySorting applies sorting to the GORM query based on the provided column and direction
|
||||
func applySorting(sortColumn string, sortDirection string, query *gorm.DB, meta map[string]FieldMeta) *gorm.DB {
|
||||
fieldName := CapitalizeFirstLetter(sortColumn)
|
||||
fieldMeta, ok := meta[fieldName]
|
||||
if !ok || !fieldMeta.IsSortable {
|
||||
return query
|
||||
}
|
||||
|
||||
sortDirection = NormalizeSortDirection(sortDirection)
|
||||
|
||||
query = query.Clauses(clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{
|
||||
{Column: clause.Column{Name: fieldMeta.ColumnName}, Desc: sortDirection == "desc"},
|
||||
},
|
||||
})
|
||||
return query
|
||||
}
|
||||
|
||||
// extractModelMetadata extracts FieldMeta from the model struct using reflection
|
||||
func extractModelMetadata(model interface{}) map[string]FieldMeta {
|
||||
meta := make(map[string]FieldMeta)
|
||||
|
||||
// Unwrap pointers and slices to get the element struct type
|
||||
t := reflect.TypeOf(model)
|
||||
for t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice {
|
||||
t = t.Elem()
|
||||
if t == nil {
|
||||
return meta
|
||||
}
|
||||
}
|
||||
|
||||
// recursive parser that merges fields from embedded structs
|
||||
var parseStruct func(reflect.Type)
|
||||
parseStruct = func(st reflect.Type) {
|
||||
for i := 0; i < st.NumField(); i++ {
|
||||
field := st.Field(i)
|
||||
ft := field.Type
|
||||
|
||||
// If the field is an embedded/anonymous struct, recurse into it
|
||||
if field.Anonymous && ft.Kind() == reflect.Struct {
|
||||
parseStruct(ft)
|
||||
continue
|
||||
}
|
||||
|
||||
// Normal field: record metadata
|
||||
name := field.Name
|
||||
meta[name] = FieldMeta{
|
||||
ColumnName: CamelCaseToSnakeCase(name),
|
||||
IsSortable: field.Tag.Get("sortable") == "true",
|
||||
IsFilterable: field.Tag.Get("filterable") == "true",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parseStruct(t)
|
||||
return meta
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type PaginationResponse struct {
|
||||
TotalPages int64 `json:"totalPages"`
|
||||
TotalItems int64 `json:"totalItems"`
|
||||
CurrentPage int `json:"currentPage"`
|
||||
ItemsPerPage int `json:"itemsPerPage"`
|
||||
}
|
||||
|
||||
type SortedPaginationRequest struct {
|
||||
Pagination struct {
|
||||
Page int `form:"pagination[page]"`
|
||||
Limit int `form:"pagination[limit]"`
|
||||
} `form:"pagination"`
|
||||
Sort struct {
|
||||
Column string `form:"sort[column]"`
|
||||
Direction string `form:"sort[direction]"`
|
||||
} `form:"sort"`
|
||||
}
|
||||
|
||||
func PaginateAndSort(sortedPaginationRequest SortedPaginationRequest, query *gorm.DB, result interface{}) (PaginationResponse, error) {
|
||||
pagination := sortedPaginationRequest.Pagination
|
||||
sort := sortedPaginationRequest.Sort
|
||||
|
||||
capitalizedSortColumn := CapitalizeFirstLetter(sort.Column)
|
||||
|
||||
sortField, sortFieldFound := reflect.TypeOf(result).Elem().Elem().FieldByName(capitalizedSortColumn)
|
||||
isSortable, _ := strconv.ParseBool(sortField.Tag.Get("sortable"))
|
||||
|
||||
sort.Direction = NormalizeSortDirection(sort.Direction)
|
||||
|
||||
if sortFieldFound && isSortable {
|
||||
columnName := CamelCaseToSnakeCase(sort.Column)
|
||||
query = query.Clauses(clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{
|
||||
{Column: clause.Column{Name: columnName}, Desc: sort.Direction == "desc"},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return Paginate(pagination.Page, pagination.Limit, query, result)
|
||||
}
|
||||
|
||||
func Paginate(page int, pageSize int, query *gorm.DB, result interface{}) (PaginationResponse, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
if pageSize < 1 {
|
||||
pageSize = 20
|
||||
} else if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
var totalItems int64
|
||||
if err := query.Count(&totalItems).Error; err != nil {
|
||||
return PaginationResponse{}, err
|
||||
}
|
||||
|
||||
if err := query.Offset(offset).Limit(pageSize).Find(result).Error; err != nil {
|
||||
return PaginationResponse{}, err
|
||||
}
|
||||
|
||||
totalPages := (totalItems + int64(pageSize) - 1) / int64(pageSize)
|
||||
if totalItems == 0 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
return PaginationResponse{
|
||||
TotalPages: totalPages,
|
||||
TotalItems: totalItems,
|
||||
CurrentPage: page,
|
||||
ItemsPerPage: pageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NormalizeSortDirection(direction string) string {
|
||||
d := strings.ToLower(strings.TrimSpace(direction))
|
||||
if d != "asc" && d != "desc" {
|
||||
return "asc"
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func IsValidSortDirection(direction string) bool {
|
||||
d := strings.ToLower(strings.TrimSpace(direction))
|
||||
return d == "asc" || d == "desc"
|
||||
}
|
||||
@@ -81,26 +81,21 @@ func CapitalizeFirstLetter(str string) string {
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func CamelCaseToSnakeCase(str string) string {
|
||||
result := strings.Builder{}
|
||||
result.Grow(int(float32(len(str)) * 1.1))
|
||||
for i, r := range str {
|
||||
if unicode.IsUpper(r) && i > 0 {
|
||||
result.WriteByte('_')
|
||||
}
|
||||
result.WriteRune(unicode.ToLower(r))
|
||||
}
|
||||
return result.String()
|
||||
var (
|
||||
reAcronymBoundary = regexp.MustCompile(`([A-Z]+)([A-Z][a-z])`) // ABCd -> AB_Cd
|
||||
reLowerToUpper = regexp.MustCompile(`([a-z0-9])([A-Z])`) // aB -> a_B
|
||||
)
|
||||
|
||||
func CamelCaseToSnakeCase(s string) string {
|
||||
s = reAcronymBoundary.ReplaceAllString(s, "${1}_${2}")
|
||||
s = reLowerToUpper.ReplaceAllString(s, "${1}_${2}")
|
||||
return strings.ToLower(s)
|
||||
}
|
||||
|
||||
var camelCaseToScreamingSnakeCaseRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
||||
|
||||
func CamelCaseToScreamingSnakeCase(s string) string {
|
||||
// Insert underscores before uppercase letters (except the first one)
|
||||
snake := camelCaseToScreamingSnakeCaseRe.ReplaceAllString(s, `${1}_${2}`)
|
||||
|
||||
// Convert to uppercase
|
||||
return strings.ToUpper(snake)
|
||||
s = reAcronymBoundary.ReplaceAllString(s, "${1}_${2}")
|
||||
s = reLowerToUpper.ReplaceAllString(s, "${1}_${2}")
|
||||
return strings.ToUpper(s)
|
||||
}
|
||||
|
||||
// GetFirstCharacter returns the first non-whitespace character of the string, correctly handling Unicode
|
||||
|
||||
@@ -86,9 +86,9 @@ func TestCamelCaseToSnakeCase(t *testing.T) {
|
||||
{"simple camelCase", "camelCase", "camel_case"},
|
||||
{"PascalCase", "PascalCase", "pascal_case"},
|
||||
{"multipleWordsInCamelCase", "multipleWordsInCamelCase", "multiple_words_in_camel_case"},
|
||||
{"consecutive uppercase", "HTTPRequest", "h_t_t_p_request"},
|
||||
{"consecutive uppercase", "HTTPRequest", "http_request"},
|
||||
{"single lowercase word", "word", "word"},
|
||||
{"single uppercase word", "WORD", "w_o_r_d"},
|
||||
{"single uppercase word", "WORD", "word"},
|
||||
{"with numbers", "camel123Case", "camel123_case"},
|
||||
{"with numbers in middle", "model2Name", "model2_name"},
|
||||
{"mixed case", "iPhone6sPlus", "i_phone6s_plus"},
|
||||
@@ -104,6 +104,34 @@ func TestCamelCaseToSnakeCase(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCamelCaseToScreamingSnakeCase(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"empty string", "", ""},
|
||||
{"simple camelCase", "camelCase", "CAMEL_CASE"},
|
||||
{"PascalCase", "PascalCase", "PASCAL_CASE"},
|
||||
{"multipleWordsInCamelCase", "multipleWordsInCamelCase", "MULTIPLE_WORDS_IN_CAMEL_CASE"},
|
||||
{"consecutive uppercase", "HTTPRequest", "HTTP_REQUEST"},
|
||||
{"single lowercase word", "word", "WORD"},
|
||||
{"single uppercase word", "WORD", "WORD"},
|
||||
{"with numbers", "camel123Case", "CAMEL123_CASE"},
|
||||
{"with numbers in middle", "model2Name", "MODEL2_NAME"},
|
||||
{"mixed case", "iPhone6sPlus", "I_PHONE6S_PLUS"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := CamelCaseToScreamingSnakeCase(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("CamelCaseToScreamingSnakeCase(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFirstCharacter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
35
backend/internal/utils/type_util.go
Normal file
35
backend/internal/utils/type_util.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ConvertStringToType attempts to convert a string to bool, int, or float.
|
||||
func ConvertStringToType(value string) any {
|
||||
v := strings.TrimSpace(value)
|
||||
if v == "" {
|
||||
return v
|
||||
}
|
||||
|
||||
// Try bool
|
||||
if v == "true" {
|
||||
return true
|
||||
}
|
||||
if v == "false" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Try int
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return i
|
||||
}
|
||||
|
||||
// Try float
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return f
|
||||
}
|
||||
|
||||
// Default: string
|
||||
return v
|
||||
}
|
||||
37
backend/internal/utils/type_util_test.go
Normal file
37
backend/internal/utils/type_util_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConvertStringToType(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected any
|
||||
}{
|
||||
{"true", true},
|
||||
{"false", false},
|
||||
{" true ", true},
|
||||
{" false ", false},
|
||||
{"42", 42},
|
||||
{" 42 ", 42},
|
||||
{"3.14", 3.14},
|
||||
{" 3.14 ", 3.14},
|
||||
{"hello", "hello"},
|
||||
{" hello ", "hello"},
|
||||
{"", ""},
|
||||
{" ", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := ConvertStringToType(tt.input)
|
||||
if result != tt.expected {
|
||||
if f, ok := tt.expected.(float64); ok {
|
||||
if rf, ok := result.(float64); ok && rf == f {
|
||||
continue
|
||||
}
|
||||
}
|
||||
t.Errorf("ConvertStringToType(%q) = %#v (type %T), want %#v (type %T)", tt.input, result, result, tt.expected, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user