mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-16 02:03:01 +03:00
fix: use transactions when operations involve multiple database queries (#392)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
committed by
GitHub
parent
c810fec8c4
commit
ec626ee797
@@ -1,19 +1,23 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
)
|
||||
|
||||
func Bootstrap() {
|
||||
ctx := context.TODO()
|
||||
|
||||
initApplicationImages()
|
||||
|
||||
migrateConfigDBConnstring()
|
||||
|
||||
db := newDatabase()
|
||||
appConfigService := service.NewAppConfigService(db)
|
||||
appConfigService := service.NewAppConfigService(ctx, db)
|
||||
|
||||
migrateKey()
|
||||
|
||||
initRouter(db, appConfigService)
|
||||
initRouter(ctx, db, appConfigService)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
@@ -19,7 +20,7 @@ import (
|
||||
// This is used to register additional controllers for tests
|
||||
var registerTestControllers []func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService)
|
||||
|
||||
func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
|
||||
func initRouter(ctx context.Context, db *gorm.DB, appConfigService *service.AppConfigService) {
|
||||
// Set the appropriate Gin mode based on the environment
|
||||
switch common.EnvConfig.AppEnv {
|
||||
case "production":
|
||||
@@ -36,10 +37,10 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
|
||||
// Initialize services
|
||||
emailService, err := service.NewEmailService(appConfigService, db)
|
||||
if err != nil {
|
||||
log.Fatalf("Unable to create email service: %s", err)
|
||||
log.Fatalf("Unable to create email service: %v", err)
|
||||
}
|
||||
|
||||
geoLiteService := service.NewGeoLiteService()
|
||||
geoLiteService := service.NewGeoLiteService(ctx)
|
||||
auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService)
|
||||
jwtService := service.NewJwtService(appConfigService)
|
||||
webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService)
|
||||
@@ -57,9 +58,9 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
|
||||
r.Use(middleware.NewErrorHandlerMiddleware().Add())
|
||||
r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60))
|
||||
|
||||
job.RegisterLdapJobs(ldapService, appConfigService)
|
||||
job.RegisterDbCleanupJobs(db)
|
||||
job.RegisterFileCleanupJobs(db)
|
||||
job.RegisterLdapJobs(ctx, ldapService, appConfigService)
|
||||
job.RegisterDbCleanupJobs(ctx, db)
|
||||
job.RegisterFileCleanupJobs(ctx, db)
|
||||
|
||||
// Initialize middleware for specific routes
|
||||
authMiddleware := middleware.NewAuthMiddleware(apiKeyService, jwtService)
|
||||
|
||||
@@ -17,7 +17,7 @@ type AlreadyInUseError struct {
|
||||
}
|
||||
|
||||
func (e *AlreadyInUseError) Error() string {
|
||||
return fmt.Sprintf("%s is already in use", e.Property)
|
||||
return e.Property + " is already in use"
|
||||
}
|
||||
func (e *AlreadyInUseError) HttpStatusCode() int { return 400 }
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(userID, sortedPaginationRequest)
|
||||
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(ctx.Request.Context(), userID, sortedPaginationRequest)
|
||||
if err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
@@ -87,7 +87,7 @@ func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
apiKey, token, err := c.apiKeyService.CreateApiKey(userID, input)
|
||||
apiKey, token, err := c.apiKeyService.CreateApiKey(ctx.Request.Context(), userID, input)
|
||||
if err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
@@ -116,7 +116,7 @@ func (c *ApiKeyController) revokeApiKeyHandler(ctx *gin.Context) {
|
||||
userID := ctx.GetString("userID")
|
||||
apiKeyID := ctx.Param("id")
|
||||
|
||||
if err := c.apiKeyService.RevokeApiKey(userID, apiKeyID); err != nil {
|
||||
if err := c.apiKeyService.RevokeApiKey(ctx.Request.Context(), userID, apiKeyID); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
@@ -61,7 +60,7 @@ type AppConfigController struct {
|
||||
// @Failure 500 {object} object "{"error": "error message"}"
|
||||
// @Router /application-configuration [get]
|
||||
func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
|
||||
configuration, err := acc.appConfigService.ListAppConfig(false)
|
||||
configuration, err := acc.appConfigService.ListAppConfig(c.Request.Context(), false)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -73,7 +72,7 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, configVariablesDto)
|
||||
c.JSON(http.StatusOK, configVariablesDto)
|
||||
}
|
||||
|
||||
// listAllAppConfigHandler godoc
|
||||
@@ -86,7 +85,7 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
|
||||
// @Security BearerAuth
|
||||
// @Router /application-configuration/all [get]
|
||||
func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
|
||||
configuration, err := acc.appConfigService.ListAppConfig(true)
|
||||
configuration, err := acc.appConfigService.ListAppConfig(c.Request.Context(), true)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -98,7 +97,7 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, configVariablesDto)
|
||||
c.JSON(http.StatusOK, configVariablesDto)
|
||||
}
|
||||
|
||||
// updateAppConfigHandler godoc
|
||||
@@ -118,7 +117,7 @@ func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(input)
|
||||
savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -253,7 +252,7 @@ func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
|
||||
|
||||
// getImage is a helper function to serve image files
|
||||
func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) {
|
||||
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
|
||||
imagePath := common.EnvConfig.UploadPath + "/application-images/" + name + "." + imageType
|
||||
mimeType := utils.GetImageMimeType(imageType)
|
||||
|
||||
c.Header("Content-Type", mimeType)
|
||||
@@ -268,7 +267,7 @@ func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, ol
|
||||
return
|
||||
}
|
||||
|
||||
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
|
||||
err = acc.appConfigService.UpdateImage(c.Request.Context(), file, imageName, oldImageType)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -285,7 +284,7 @@ func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, ol
|
||||
// @Security BearerAuth
|
||||
// @Router /api/application-configuration/sync-ldap [post]
|
||||
func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
|
||||
err := acc.ldapService.SyncAll()
|
||||
err := acc.ldapService.SyncAll(c.Request.Context())
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -304,7 +303,7 @@ func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
|
||||
func (acc *AppConfigController) testEmailHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
|
||||
err := acc.emailService.SendTestEmail(userID)
|
||||
err := acc.emailService.SendTestEmail(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -42,7 +42,9 @@ type AuditLogController struct {
|
||||
// @Router /api/audit-logs [get]
|
||||
func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
|
||||
err := c.ShouldBindQuery(&sortedPaginationRequest)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -50,7 +52,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
|
||||
// Fetch audit logs for the user
|
||||
logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(userID, sortedPaginationRequest)
|
||||
logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(c.Request.Context(), userID, sortedPaginationRequest)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -41,7 +41,7 @@ type CustomClaimController struct {
|
||||
// @Security BearerAuth
|
||||
// @Router /api/custom-claims/suggestions [get]
|
||||
func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
|
||||
claims, err := ccc.customClaimService.GetSuggestions()
|
||||
claims, err := ccc.customClaimService.GetSuggestions(c.Request.Context())
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -69,7 +69,7 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Contex
|
||||
}
|
||||
|
||||
userId := c.Param("userId")
|
||||
claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(userId, input)
|
||||
claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(c.Request.Context(), userId, input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -104,7 +104,7 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.C
|
||||
}
|
||||
|
||||
userGroupId := c.Param("userGroupId")
|
||||
claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userGroupId, input)
|
||||
claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(c.Request.Context(), userGroupId, input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -69,7 +69,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
|
||||
code, callbackURL, err := oc.oidcService.Authorize(c.Request.Context(), input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -100,7 +100,7 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex
|
||||
return
|
||||
}
|
||||
|
||||
hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(input.ClientID, c.GetString("userID"), input.Scope)
|
||||
hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(c.Request.Context(), input.ClientID, c.GetString("userID"), input.Scope)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -153,6 +153,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
idToken, accessToken, refreshToken, expiresIn, err := oc.oidcService.CreateTokens(
|
||||
c.Request.Context(),
|
||||
input.Code,
|
||||
input.GrantType,
|
||||
clientID,
|
||||
@@ -216,7 +217,7 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
||||
_ = c.Error(&common.TokenInvalidError{})
|
||||
return
|
||||
}
|
||||
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientID[0])
|
||||
claims, err := oc.oidcService.GetUserClaimsForClient(c.Request.Context(), userID, clientID[0])
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -254,7 +255,7 @@ func (oc *OidcController) EndSessionHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
callbackURL, err := oc.oidcService.ValidateEndSession(input, c.GetString("userID"))
|
||||
callbackURL, err := oc.oidcService.ValidateEndSession(c.Request.Context(), input, c.GetString("userID"))
|
||||
if err != nil {
|
||||
// If the validation fails, the user has to confirm the logout manually and doesn't get redirected
|
||||
log.Printf("Error getting logout callback URL, the user has to confirm the logout manually: %v", err)
|
||||
@@ -300,7 +301,7 @@ func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) {
|
||||
// @Router /api/oidc/clients/{id}/meta [get]
|
||||
func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
|
||||
clientId := c.Param("id")
|
||||
client, err := oc.oidcService.GetClient(clientId)
|
||||
client, err := oc.oidcService.GetClient(c.Request.Context(), clientId)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -327,7 +328,7 @@ func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
|
||||
// @Router /api/oidc/clients/{id} [get]
|
||||
func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||
clientId := c.Param("id")
|
||||
client, err := oc.oidcService.GetClient(clientId)
|
||||
client, err := oc.oidcService.GetClient(c.Request.Context(), clientId)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -363,7 +364,7 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
clients, pagination, err := oc.oidcService.ListClients(searchTerm, sortedPaginationRequest)
|
||||
clients, pagination, err := oc.oidcService.ListClients(c.Request.Context(), searchTerm, sortedPaginationRequest)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -398,7 +399,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
client, err := oc.oidcService.CreateClient(input, c.GetString("userID"))
|
||||
client, err := oc.oidcService.CreateClient(c.Request.Context(), input, c.GetString("userID"))
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -422,7 +423,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id} [delete]
|
||||
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
|
||||
err := oc.oidcService.DeleteClient(c.Param("id"))
|
||||
err := oc.oidcService.DeleteClient(c.Request.Context(), c.Param("id"))
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -449,7 +450,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
client, err := oc.oidcService.UpdateClient(c.Param("id"), input)
|
||||
client, err := oc.oidcService.UpdateClient(c.Request.Context(), c.Param("id"), input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -474,7 +475,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id}/secret [post]
|
||||
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
||||
secret, err := oc.oidcService.CreateClientSecret(c.Param("id"))
|
||||
secret, err := oc.oidcService.CreateClientSecret(c.Request.Context(), c.Param("id"))
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -494,7 +495,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
||||
// @Success 200 {file} binary "Logo image"
|
||||
// @Router /api/oidc/clients/{id}/logo [get]
|
||||
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
||||
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id"))
|
||||
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Request.Context(), c.Param("id"))
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -521,7 +522,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file)
|
||||
err = oc.oidcService.UpdateClientLogo(c.Request.Context(), c.Param("id"), file)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -539,7 +540,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id}/logo [delete]
|
||||
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
||||
err := oc.oidcService.DeleteClientLogo(c.Param("id"))
|
||||
err := oc.oidcService.DeleteClientLogo(c.Request.Context(), c.Param("id"))
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -566,7 +567,7 @@ func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Param("id"), input)
|
||||
oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Request.Context(), c.Param("id"), input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -65,7 +65,7 @@ type UserController struct {
|
||||
// @Router /api/users/{id}/groups [get]
|
||||
func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
groups, err := uc.userService.GetUserGroups(userID)
|
||||
groups, err := uc.userService.GetUserGroups(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -99,7 +99,7 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
users, pagination, err := uc.userService.ListUsers(searchTerm, sortedPaginationRequest)
|
||||
users, pagination, err := uc.userService.ListUsers(c.Request.Context(), searchTerm, sortedPaginationRequest)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -125,7 +125,7 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
|
||||
// @Success 200 {object} dto.UserDto
|
||||
// @Router /api/users/{id} [get]
|
||||
func (uc *UserController) getUserHandler(c *gin.Context) {
|
||||
user, err := uc.userService.GetUser(c.Param("id"))
|
||||
user, err := uc.userService.GetUser(c.Request.Context(), c.Param("id"))
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -147,7 +147,7 @@ func (uc *UserController) getUserHandler(c *gin.Context) {
|
||||
// @Success 200 {object} dto.UserDto
|
||||
// @Router /api/users/me [get]
|
||||
func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
|
||||
user, err := uc.userService.GetUser(c.GetString("userID"))
|
||||
user, err := uc.userService.GetUser(c.Request.Context(), c.GetString("userID"))
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -170,7 +170,7 @@ func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
|
||||
// @Success 204 "No Content"
|
||||
// @Router /api/users/{id} [delete]
|
||||
func (uc *UserController) deleteUserHandler(c *gin.Context) {
|
||||
if err := uc.userService.DeleteUser(c.Param("id"), false); err != nil {
|
||||
if err := uc.userService.DeleteUser(c.Request.Context(), c.Param("id"), false); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -192,7 +192,7 @@ func (uc *UserController) createUserHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := uc.userService.CreateUser(input)
|
||||
user, err := uc.userService.CreateUser(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -245,7 +245,7 @@ func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
|
||||
func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
|
||||
picture, size, err := uc.userService.GetProfilePicture(userID)
|
||||
picture, size, err := uc.userService.GetProfilePicture(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -332,7 +332,7 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context, own bo
|
||||
if own {
|
||||
input.UserID = c.GetString("userID")
|
||||
}
|
||||
token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
|
||||
token, err := uc.userService.CreateOneTimeAccessToken(c.Request.Context(), input.UserID, input.ExpiresAt)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -364,7 +364,7 @@ func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err := uc.userService.RequestOneTimeAccessEmail(input.Email, input.RedirectPath)
|
||||
err := uc.userService.RequestOneTimeAccessEmail(c.Request.Context(), input.Email, input.RedirectPath)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -381,7 +381,7 @@ func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) {
|
||||
// @Success 200 {object} dto.UserDto
|
||||
// @Router /api/one-time-access-token/{token} [post]
|
||||
func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token"), c.ClientIP(), c.Request.UserAgent())
|
||||
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Request.Context(), c.Param("token"), c.ClientIP(), c.Request.UserAgent())
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -406,7 +406,7 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
// @Success 200 {object} dto.UserDto
|
||||
// @Router /api/one-time-access-token/setup [post]
|
||||
func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
|
||||
user, token, err := uc.userService.SetupInitialAdmin()
|
||||
user, token, err := uc.userService.SetupInitialAdmin(c.Request.Context())
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -439,7 +439,7 @@ func (uc *UserController) updateUserGroups(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := uc.userService.UpdateUserGroups(c.Param("id"), input.UserGroupIds)
|
||||
user, err := uc.userService.UpdateUserGroups(c.Request.Context(), c.Param("id"), input.UserGroupIds)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -469,7 +469,7 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
|
||||
userID = c.Param("id")
|
||||
}
|
||||
|
||||
user, err := uc.userService.UpdateUser(userID, input, updateOwnUser, false)
|
||||
user, err := uc.userService.UpdateUser(c.Request.Context(), userID, input, updateOwnUser, false)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -47,6 +47,8 @@ type UserGroupController struct {
|
||||
// @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount]
|
||||
// @Router /api/user-groups [get]
|
||||
func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
searchTerm := c.Query("search")
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
@@ -54,7 +56,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
groups, pagination, err := ugc.UserGroupService.List(searchTerm, sortedPaginationRequest)
|
||||
groups, pagination, err := ugc.UserGroupService.List(ctx, searchTerm, sortedPaginationRequest)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -68,7 +70,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(group.ID)
|
||||
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(ctx, group.ID)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -93,7 +95,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
// @Security BearerAuth
|
||||
// @Router /api/user-groups/{id} [get]
|
||||
func (ugc *UserGroupController) get(c *gin.Context) {
|
||||
group, err := ugc.UserGroupService.Get(c.Param("id"))
|
||||
group, err := ugc.UserGroupService.Get(c.Request.Context(), c.Param("id"))
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -125,7 +127,7 @@ func (ugc *UserGroupController) create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
group, err := ugc.UserGroupService.Create(input)
|
||||
group, err := ugc.UserGroupService.Create(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -158,7 +160,7 @@ func (ugc *UserGroupController) update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
group, err := ugc.UserGroupService.Update(c.Param("id"), input, false)
|
||||
group, err := ugc.UserGroupService.Update(c.Request.Context(), c.Param("id"), input, false)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -184,7 +186,7 @@ func (ugc *UserGroupController) update(c *gin.Context) {
|
||||
// @Security BearerAuth
|
||||
// @Router /api/user-groups/{id} [delete]
|
||||
func (ugc *UserGroupController) delete(c *gin.Context) {
|
||||
if err := ugc.UserGroupService.Delete(c.Param("id")); err != nil {
|
||||
if err := ugc.UserGroupService.Delete(c.Request.Context(), c.Param("id")); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -210,7 +212,7 @@ func (ugc *UserGroupController) updateUsers(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input.UserIDs)
|
||||
group, err := ugc.UserGroupService.UpdateUsers(c.Request.Context(), c.Param("id"), input.UserIDs)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -37,7 +37,7 @@ type WebauthnController struct {
|
||||
|
||||
func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
options, err := wc.webAuthnService.BeginRegistration(userID)
|
||||
options, err := wc.webAuthnService.BeginRegistration(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -55,7 +55,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
userID := c.GetString("userID")
|
||||
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request)
|
||||
credential, err := wc.webAuthnService.VerifyRegistration(c.Request.Context(), sessionID, userID, c.Request)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -71,7 +71,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
|
||||
options, err := wc.webAuthnService.BeginLogin()
|
||||
options, err := wc.webAuthnService.BeginLogin(c.Request.Context())
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -94,7 +94,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, token, err := wc.webAuthnService.VerifyLogin(sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent())
|
||||
user, token, err := wc.webAuthnService.VerifyLogin(c.Request.Context(), sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent())
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -114,7 +114,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
|
||||
|
||||
func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
credentials, err := wc.webAuthnService.ListCredentials(userID)
|
||||
credentials, err := wc.webAuthnService.ListCredentials(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -133,7 +133,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
credentialID := c.Param("id")
|
||||
|
||||
err := wc.webAuthnService.DeleteCredential(userID, credentialID)
|
||||
err := wc.webAuthnService.DeleteCredential(c.Request.Context(), userID, credentialID)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -152,7 +152,7 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
|
||||
credential, err := wc.webAuthnService.UpdateCredential(c.Request.Context(), userID, credentialID, input.Name)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
package job
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func RegisterDbCleanupJobs(db *gorm.DB) {
|
||||
func RegisterDbCleanupJobs(ctx context.Context, db *gorm.DB) {
|
||||
scheduler, err := gocron.NewScheduler()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create a new scheduler: %s", err)
|
||||
@@ -18,11 +20,11 @@ func RegisterDbCleanupJobs(db *gorm.DB) {
|
||||
|
||||
jobs := &DbCleanupJobs{db: db}
|
||||
|
||||
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
|
||||
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
|
||||
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
|
||||
registerJob(scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
|
||||
registerJob(scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs)
|
||||
registerJob(ctx, scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
|
||||
registerJob(ctx, scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
|
||||
registerJob(ctx, scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
|
||||
registerJob(ctx, scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
|
||||
registerJob(ctx, scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs)
|
||||
scheduler.Start()
|
||||
}
|
||||
|
||||
@@ -31,26 +33,41 @@ type DbCleanupJobs struct {
|
||||
}
|
||||
|
||||
// ClearWebauthnSessions deletes WebAuthn sessions that have expired
|
||||
func (j *DbCleanupJobs) clearWebauthnSessions() error {
|
||||
return j.db.Delete(&model.WebauthnSession{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
|
||||
func (j *DbCleanupJobs) clearWebauthnSessions(ctx context.Context) error {
|
||||
return j.db.
|
||||
WithContext(ctx).
|
||||
Delete(&model.WebauthnSession{}, "expires_at < ?", datatype.DateTime(time.Now())).
|
||||
Error
|
||||
}
|
||||
|
||||
// ClearOneTimeAccessTokens deletes one-time access tokens that have expired
|
||||
func (j *DbCleanupJobs) clearOneTimeAccessTokens() error {
|
||||
return j.db.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
|
||||
func (j *DbCleanupJobs) clearOneTimeAccessTokens(ctx context.Context) error {
|
||||
return j.db.
|
||||
WithContext(ctx).
|
||||
Delete(&model.OneTimeAccessToken{}, "expires_at < ?", datatype.DateTime(time.Now())).
|
||||
Error
|
||||
}
|
||||
|
||||
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
|
||||
func (j *DbCleanupJobs) clearOidcAuthorizationCodes() error {
|
||||
return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
|
||||
func (j *DbCleanupJobs) clearOidcAuthorizationCodes(ctx context.Context) error {
|
||||
return j.db.
|
||||
WithContext(ctx).
|
||||
Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).
|
||||
Error
|
||||
}
|
||||
|
||||
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
|
||||
func (j *DbCleanupJobs) clearOidcRefreshTokens() error {
|
||||
return j.db.Delete(&model.OidcRefreshToken{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
|
||||
func (j *DbCleanupJobs) clearOidcRefreshTokens(ctx context.Context) error {
|
||||
return j.db.
|
||||
WithContext(ctx).
|
||||
Delete(&model.OidcRefreshToken{}, "expires_at < ?", datatype.DateTime(time.Now())).
|
||||
Error
|
||||
}
|
||||
|
||||
// ClearAuditLogs deletes audit logs older than 90 days
|
||||
func (j *DbCleanupJobs) clearAuditLogs() error {
|
||||
return j.db.Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).Error
|
||||
func (j *DbCleanupJobs) clearAuditLogs(ctx context.Context) error {
|
||||
return j.db.
|
||||
WithContext(ctx).
|
||||
Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).
|
||||
Error
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package job
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -8,12 +9,13 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func RegisterFileCleanupJobs(db *gorm.DB) {
|
||||
func RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB) {
|
||||
scheduler, err := gocron.NewScheduler()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create a new scheduler: %s", err)
|
||||
@@ -21,7 +23,7 @@ func RegisterFileCleanupJobs(db *gorm.DB) {
|
||||
|
||||
jobs := &FileCleanupJobs{db: db}
|
||||
|
||||
registerJob(scheduler, "ClearUnusedDefaultProfilePictures", "0 2 * * 0", jobs.clearUnusedDefaultProfilePictures)
|
||||
registerJob(ctx, scheduler, "ClearUnusedDefaultProfilePictures", "0 2 * * 0", jobs.clearUnusedDefaultProfilePictures)
|
||||
|
||||
scheduler.Start()
|
||||
}
|
||||
@@ -31,9 +33,13 @@ type FileCleanupJobs struct {
|
||||
}
|
||||
|
||||
// ClearUnusedDefaultProfilePictures deletes default profile pictures that don't match any user's initials
|
||||
func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures() error {
|
||||
func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures(ctx context.Context) error {
|
||||
var users []model.User
|
||||
if err := j.db.Find(&users).Error; err != nil {
|
||||
err := j.db.
|
||||
WithContext(ctx).
|
||||
Find(&users).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch users: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
package job
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) {
|
||||
func registerJob(ctx context.Context, scheduler gocron.Scheduler, name string, interval string, job func(ctx context.Context) error) {
|
||||
_, err := scheduler.NewJob(
|
||||
gocron.CronJob(interval, false),
|
||||
gocron.NewTask(job),
|
||||
gocron.WithContext(ctx),
|
||||
gocron.WithEventListeners(
|
||||
gocron.AfterJobRuns(func(jobID uuid.UUID, jobName string) {
|
||||
log.Printf("Job %q run successfully", name)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package job
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
@@ -12,28 +13,30 @@ type LdapJobs struct {
|
||||
appConfigService *service.AppConfigService
|
||||
}
|
||||
|
||||
func RegisterLdapJobs(ldapService *service.LdapService, appConfigService *service.AppConfigService) {
|
||||
func RegisterLdapJobs(ctx context.Context, ldapService *service.LdapService, appConfigService *service.AppConfigService) {
|
||||
jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService}
|
||||
|
||||
scheduler, err := gocron.NewScheduler()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create a new scheduler: %s", err)
|
||||
log.Fatalf("Failed to create a new scheduler: %v", err)
|
||||
}
|
||||
|
||||
// Register the job to run every hour
|
||||
registerJob(scheduler, "SyncLdap", "0 * * * *", jobs.syncLdap)
|
||||
registerJob(ctx, scheduler, "SyncLdap", "0 * * * *", jobs.syncLdap)
|
||||
|
||||
// Run the job immediately on startup
|
||||
if err := jobs.syncLdap(); err != nil {
|
||||
log.Printf("Failed to sync LDAP: %s", err)
|
||||
err = jobs.syncLdap(ctx)
|
||||
if err != nil {
|
||||
log.Printf("Failed to sync LDAP: %v", err)
|
||||
}
|
||||
|
||||
scheduler.Start()
|
||||
}
|
||||
|
||||
func (j *LdapJobs) syncLdap() error {
|
||||
if j.appConfigService.DbConfig.LdapEnabled.IsTrue() {
|
||||
return j.ldapService.SyncAll()
|
||||
}
|
||||
func (j *LdapJobs) syncLdap(ctx context.Context) error {
|
||||
if !j.appConfigService.DbConfig.LdapEnabled.IsTrue() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return j.ldapService.SyncAll(ctx)
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func (m *ApiKeyAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
|
||||
func (m *ApiKeyAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (userID string, isAdmin bool, err error) {
|
||||
apiKey := c.GetHeader("X-API-KEY")
|
||||
|
||||
user, err := m.apiKeyService.ValidateApiKey(apiKey)
|
||||
user, err := m.apiKeyService.ValidateApiKey(c.Request.Context(), apiKey)
|
||||
if err != nil {
|
||||
return "", false, &common.NotSignedInError{}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type ApiKeyService struct {
|
||||
@@ -22,8 +23,11 @@ func NewApiKeyService(db *gorm.DB) *ApiKeyService {
|
||||
return &ApiKeyService{db: db}
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
|
||||
query := s.db.Where("user_id = ?", userID).Model(&model.ApiKey{})
|
||||
func (s *ApiKeyService) ListApiKeys(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
|
||||
query := s.db.
|
||||
WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Model(&model.ApiKey{})
|
||||
|
||||
var apiKeys []model.ApiKey
|
||||
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &apiKeys)
|
||||
@@ -34,7 +38,7 @@ func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils
|
||||
return apiKeys, pagination, nil
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) {
|
||||
func (s *ApiKeyService) CreateApiKey(ctx context.Context, userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) {
|
||||
// Check if expiration is in the future
|
||||
if !input.ExpiresAt.ToTime().After(time.Now()) {
|
||||
return model.ApiKey{}, "", &common.APIKeyExpirationDateError{}
|
||||
@@ -54,7 +58,11 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&apiKey).Error; err != nil {
|
||||
err = s.db.
|
||||
WithContext(ctx).
|
||||
Create(&apiKey).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.ApiKey{}, "", err
|
||||
}
|
||||
|
||||
@@ -62,29 +70,44 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
|
||||
return apiKey, token, nil
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) RevokeApiKey(userID, apiKeyID string) error {
|
||||
func (s *ApiKeyService) RevokeApiKey(ctx context.Context, userID, apiKeyID string) error {
|
||||
var apiKey model.ApiKey
|
||||
if err := s.db.Where("id = ? AND user_id = ?", apiKeyID, userID).First(&apiKey).Error; err != nil {
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Where("id = ? AND user_id = ?", apiKeyID, userID).
|
||||
Delete(&apiKey).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return &common.APIKeyNotFoundError{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return s.db.Delete(&apiKey).Error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) {
|
||||
func (s *ApiKeyService) ValidateApiKey(ctx context.Context, apiKey string) (model.User, error) {
|
||||
if apiKey == "" {
|
||||
return model.User{}, &common.NoAPIKeyProvidedError{}
|
||||
}
|
||||
|
||||
var key model.ApiKey
|
||||
now := time.Now()
|
||||
hashedKey := utils.CreateSha256Hash(apiKey)
|
||||
|
||||
if err := s.db.Preload("User").Where("key = ? AND expires_at > ?",
|
||||
hashedKey, datatype.DateTime(time.Now())).Preload("User").First(&key).Error; err != nil {
|
||||
|
||||
var key model.ApiKey
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Model(&model.ApiKey{}).
|
||||
Clauses(clause.Returning{}).
|
||||
Where("key = ? AND expires_at > ?", hashedKey, datatype.DateTime(now)).
|
||||
Updates(&model.ApiKey{
|
||||
LastUsedAt: utils.Ptr(datatype.DateTime(now)),
|
||||
}).
|
||||
Preload("User").
|
||||
First(&key).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return model.User{}, &common.InvalidAPIKeyError{}
|
||||
}
|
||||
@@ -92,12 +115,5 @@ func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
// Update last used time
|
||||
now := datatype.DateTime(time.Now())
|
||||
key.LastUsedAt = &now
|
||||
if err := s.db.Save(&key).Error; err != nil {
|
||||
log.Printf("Failed to update last used time: %v", err)
|
||||
}
|
||||
|
||||
return key.User, nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
@@ -19,12 +20,14 @@ type AppConfigService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAppConfigService(db *gorm.DB) *AppConfigService {
|
||||
func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
|
||||
service := &AppConfigService{
|
||||
DbConfig: &defaultDbConfig,
|
||||
db: db,
|
||||
}
|
||||
if err := service.InitDbConfig(); err != nil {
|
||||
|
||||
err := service.InitDbConfig(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize app config service: %v", err)
|
||||
}
|
||||
|
||||
@@ -197,17 +200,24 @@ var defaultDbConfig = model.AppConfig{
|
||||
},
|
||||
}
|
||||
|
||||
func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
|
||||
func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
|
||||
if common.EnvConfig.UiConfigDisabled {
|
||||
return nil, &common.UiConfigDisabledError{}
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var err error
|
||||
|
||||
rt := reflect.ValueOf(input).Type()
|
||||
rv := reflect.ValueOf(input)
|
||||
|
||||
var savedConfigVariables []model.AppConfigVariable
|
||||
for i := 0; i < rt.NumField(); i++ {
|
||||
savedConfigVariables := make([]model.AppConfigVariable, 0, rt.NumField())
|
||||
for i := range rt.NumField() {
|
||||
field := rt.Field(i)
|
||||
key := field.Tag.Get("json")
|
||||
value := rv.FieldByName(field.Name).String()
|
||||
@@ -220,32 +230,47 @@ func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]mode
|
||||
}
|
||||
|
||||
var appConfigVariable model.AppConfigVariable
|
||||
if err := tx.First(&appConfigVariable, "key = ? AND is_internal = false", key).Error; err != nil {
|
||||
tx.Rollback()
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
First(&appConfigVariable, "key = ? AND is_internal = false", key).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
appConfigVariable.Value = value
|
||||
if err := tx.Save(&appConfigVariable).Error; err != nil {
|
||||
tx.Rollback()
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&appConfigVariable).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
savedConfigVariables = append(savedConfigVariables, appConfigVariable)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.LoadDbConfigFromDb(); err != nil {
|
||||
err = s.LoadDbConfigFromDb()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return savedConfigVariables, nil
|
||||
}
|
||||
|
||||
func (s *AppConfigService) UpdateImageType(imageName string, fileType string) error {
|
||||
key := fmt.Sprintf("%sImageType", imageName)
|
||||
err := s.db.Model(&model.AppConfigVariable{}).Where("key = ?", key).Update("value", fileType).Error
|
||||
func (s *AppConfigService) updateImageType(ctx context.Context, imageName string, fileType string) error {
|
||||
key := imageName + "ImageType"
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Model(&model.AppConfigVariable{}).
|
||||
Where("key = ?", key).
|
||||
Update("value", fileType).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -253,14 +278,17 @@ func (s *AppConfigService) UpdateImageType(imageName string, fileType string) er
|
||||
return s.LoadDbConfigFromDb()
|
||||
}
|
||||
|
||||
func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariable, error) {
|
||||
var configuration []model.AppConfigVariable
|
||||
var err error
|
||||
|
||||
func (s *AppConfigService) ListAppConfig(ctx context.Context, showAll bool) (configuration []model.AppConfigVariable, err error) {
|
||||
if showAll {
|
||||
err = s.db.Find(&configuration).Error
|
||||
err = s.db.
|
||||
WithContext(ctx).
|
||||
Find(&configuration).
|
||||
Error
|
||||
} else {
|
||||
err = s.db.Find(&configuration, "is_public = true").Error
|
||||
err = s.db.
|
||||
WithContext(ctx).
|
||||
Find(&configuration, "is_public = true").
|
||||
Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -271,7 +299,6 @@ func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariabl
|
||||
if common.EnvConfig.UiConfigDisabled {
|
||||
// Set the value to the environment variable if the UI config is disabled
|
||||
configuration[i].Value = s.getConfigVariableFromEnvironmentVariable(configuration[i].Key, configuration[i].DefaultValue)
|
||||
|
||||
} else if configuration[i].Value == "" && configuration[i].DefaultValue != "" {
|
||||
// Set the value to the default value if it is empty
|
||||
configuration[i].Value = configuration[i].DefaultValue
|
||||
@@ -281,7 +308,7 @@ func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariabl
|
||||
return configuration, nil
|
||||
}
|
||||
|
||||
func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, imageName string, oldImageType string) error {
|
||||
func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) {
|
||||
fileType := utils.GetFileExtension(uploadedFile.Filename)
|
||||
mimeType := utils.GetImageMimeType(fileType)
|
||||
if mimeType == "" {
|
||||
@@ -290,19 +317,22 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image
|
||||
|
||||
// Delete the old image if it has a different file type
|
||||
if fileType != oldImageType {
|
||||
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType)
|
||||
if err := os.Remove(oldImagePath); err != nil {
|
||||
oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType
|
||||
err = os.Remove(oldImagePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType)
|
||||
if err := utils.SaveFile(uploadedFile, imagePath); err != nil {
|
||||
imagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + fileType
|
||||
err = utils.SaveFile(uploadedFile, imagePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the file type in the database
|
||||
if err := s.UpdateImageType(imageName, fileType); err != nil {
|
||||
err = s.updateImageType(ctx, imageName, fileType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -312,33 +342,58 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image
|
||||
// InitDbConfig creates the default configuration values in the database if they do not exist,
|
||||
// updates existing configurations if they differ from the default, and deletes any configurations
|
||||
// that are not in the default configuration.
|
||||
func (s *AppConfigService) InitDbConfig() error {
|
||||
func (s *AppConfigService) InitDbConfig(ctx context.Context) (err error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
// Reflect to get the underlying value of DbConfig and its default configuration
|
||||
defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig)
|
||||
defaultKeys := make(map[string]struct{})
|
||||
|
||||
// Iterate over the fields of DbConfig
|
||||
for i := 0; i < defaultConfigReflectValue.NumField(); i++ {
|
||||
for i := range defaultConfigReflectValue.NumField() {
|
||||
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable)
|
||||
|
||||
defaultKeys[defaultConfigVar.Key] = struct{}{}
|
||||
|
||||
var storedConfigVar model.AppConfigVariable
|
||||
if err := s.db.First(&storedConfigVar, "key = ?", defaultConfigVar.Key).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
First(&storedConfigVar, "key = ?", defaultConfigVar.Key).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// If the configuration does not exist, create it
|
||||
if err := s.db.Create(&defaultConfigVar).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&defaultConfigVar).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update existing configuration if it differs from the default
|
||||
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal || storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue {
|
||||
if storedConfigVar.Type != defaultConfigVar.Type ||
|
||||
storedConfigVar.IsPublic != defaultConfigVar.IsPublic ||
|
||||
storedConfigVar.IsInternal != defaultConfigVar.IsInternal ||
|
||||
storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue {
|
||||
// Set values
|
||||
storedConfigVar.Type = defaultConfigVar.Type
|
||||
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
|
||||
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
|
||||
storedConfigVar.DefaultValue = defaultConfigVar.DefaultValue
|
||||
if err := s.db.Save(&storedConfigVar).Error; err != nil {
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&storedConfigVar).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -346,29 +401,54 @@ func (s *AppConfigService) InitDbConfig() error {
|
||||
|
||||
// Delete any configurations not in the default keys
|
||||
var allConfigVars []model.AppConfigVariable
|
||||
if err := s.db.Find(&allConfigVars).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&allConfigVars).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, config := range allConfigVars {
|
||||
if _, exists := defaultKeys[config.Key]; !exists {
|
||||
if err := s.db.Delete(&config).Error; err != nil {
|
||||
if _, exists := defaultKeys[config.Key]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&config).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the changes
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.LoadDbConfigFromDb()
|
||||
|
||||
// Reload the configuration
|
||||
err = s.LoadDbConfigFromDb()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadDbConfigFromDb loads the configuration values from the database into the DbConfig struct.
|
||||
func (s *AppConfigService) LoadDbConfigFromDb() error {
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
|
||||
|
||||
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
|
||||
for i := range dbConfigReflectValue.NumField() {
|
||||
dbConfigField := dbConfigReflectValue.Field(i)
|
||||
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
|
||||
var storedConfigVar model.AppConfigVariable
|
||||
if err := s.db.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
|
||||
err := tx.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -379,10 +459,10 @@ func (s *AppConfigService) LoadDbConfigFromDb() error {
|
||||
}
|
||||
|
||||
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AppConfigService) getConfigVariableFromEnvironmentVariable(key, fallbackValue string) string {
|
||||
|
||||
@@ -25,10 +25,10 @@ func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailSe
|
||||
}
|
||||
|
||||
// Create creates a new audit log entry in the database
|
||||
func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData) model.AuditLog {
|
||||
func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) model.AuditLog {
|
||||
country, city, err := s.geoliteService.GetLocationByIP(ipAddress)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get IP location: %v\n", err)
|
||||
log.Printf("Failed to get IP location: %v", err)
|
||||
}
|
||||
|
||||
auditLog := model.AuditLog{
|
||||
@@ -42,8 +42,12 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent
|
||||
}
|
||||
|
||||
// Save the audit log in the database
|
||||
if err := s.db.Create(&auditLog).Error; err != nil {
|
||||
log.Printf("Failed to create audit log: %v\n", err)
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&auditLog).
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to create audit log: %v", err)
|
||||
return model.AuditLog{}
|
||||
}
|
||||
|
||||
@@ -51,12 +55,17 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent
|
||||
}
|
||||
|
||||
// CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before
|
||||
func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID string) model.AuditLog {
|
||||
createdAuditLog := s.Create(model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{})
|
||||
func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog {
|
||||
createdAuditLog := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
|
||||
|
||||
// Count the number of times the user has logged in from the same device
|
||||
var count int64
|
||||
err := s.db.Model(&model.AuditLog{}).Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).Count(&count).Error
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Model(&model.AuditLog{}).
|
||||
Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).
|
||||
Count(&count).
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to count audit logs: %v\n", err)
|
||||
return createdAuditLog
|
||||
@@ -64,11 +73,23 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
|
||||
|
||||
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email
|
||||
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.IsTrue() && count <= 1 {
|
||||
// We use a background context here as this is running in a goroutine
|
||||
//nolint:contextcheck
|
||||
go func() {
|
||||
var user model.User
|
||||
s.db.Where("id = ?", userID).First(&user)
|
||||
innerCtx := context.Background()
|
||||
|
||||
err := SendEmail(s.emailService, email.Address{
|
||||
// Note we don't use the transaction here because this is running in background
|
||||
var user model.User
|
||||
innerErr := s.db.
|
||||
WithContext(innerCtx).
|
||||
Where("id = ?", userID).
|
||||
First(&user).
|
||||
Error
|
||||
if innerErr != nil {
|
||||
log.Printf("Failed to load user: %v", innerErr)
|
||||
}
|
||||
|
||||
innerErr = SendEmail(innerCtx, s.emailService, email.Address{
|
||||
Name: user.Username,
|
||||
Email: user.Email,
|
||||
}, NewLoginTemplate, &NewLoginTemplateData{
|
||||
@@ -78,8 +99,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
|
||||
Device: s.DeviceStringFromUserAgent(userAgent),
|
||||
DateTime: createdAuditLog.CreatedAt.UTC(),
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to send email to '%s': %v\n", user.Email, err)
|
||||
if innerErr != nil {
|
||||
log.Printf("Failed to send email to '%s': %v", user.Email, innerErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -88,9 +109,12 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
|
||||
}
|
||||
|
||||
// ListAuditLogsForUser retrieves all audit logs for a given user ID
|
||||
func (s *AuditLogService) ListAuditLogsForUser(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) {
|
||||
func (s *AuditLogService) ListAuditLogsForUser(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) {
|
||||
var logs []model.AuditLog
|
||||
query := s.db.Model(&model.AuditLog{}).Where("user_id = ?", userID)
|
||||
query := s.db.
|
||||
WithContext(ctx).
|
||||
Model(&model.AuditLog{}).
|
||||
Where("user_id = ?", userID)
|
||||
|
||||
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs)
|
||||
return logs, pagination, err
|
||||
@@ -162,19 +186,19 @@ func (s *AuditLogService) ListUsernamesWithIds(ctx context.Context) (users map[s
|
||||
}
|
||||
|
||||
func (s *AuditLogService) ListClientNames(ctx context.Context) (clientNames []string, err error) {
|
||||
dialect := s.db.Name()
|
||||
query := s.db.
|
||||
WithContext(ctx).
|
||||
Model(&model.AuditLog{})
|
||||
|
||||
dialect := s.db.Name()
|
||||
switch dialect {
|
||||
case "sqlite":
|
||||
query = query.
|
||||
Select("DISTINCT json_extract(data, '$.clientName') as client_name").
|
||||
Select("DISTINCT json_extract(data, '$.clientName') AS client_name").
|
||||
Where("json_extract(data, '$.clientName') IS NOT NULL")
|
||||
case "postgres":
|
||||
query = query.
|
||||
Select("DISTINCT data->>'clientName' as client_name").
|
||||
Select("DISTINCT data->>'clientName' AS client_name").
|
||||
Where("data->>'clientName' IS NOT NULL")
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database dialect: %s", dialect)
|
||||
|
||||
@@ -1,34 +1,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Reserved claims
|
||||
var reservedClaims = map[string]struct{}{
|
||||
"given_name": {},
|
||||
"family_name": {},
|
||||
"name": {},
|
||||
"email": {},
|
||||
"preferred_username": {},
|
||||
"groups": {},
|
||||
"sub": {},
|
||||
"iss": {},
|
||||
"aud": {},
|
||||
"exp": {},
|
||||
"iat": {},
|
||||
"auth_time": {},
|
||||
"nonce": {},
|
||||
"acr": {},
|
||||
"amr": {},
|
||||
"azp": {},
|
||||
"nbf": {},
|
||||
"jti": {},
|
||||
}
|
||||
|
||||
type CustomClaimService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
@@ -39,8 +19,29 @@ func NewCustomClaimService(db *gorm.DB) *CustomClaimService {
|
||||
|
||||
// isReservedClaim checks if a claim key is reserved e.g. email, preferred_username
|
||||
func isReservedClaim(key string) bool {
|
||||
_, ok := reservedClaims[key]
|
||||
return ok
|
||||
switch key {
|
||||
case "given_name",
|
||||
"family_name",
|
||||
"name",
|
||||
"email",
|
||||
"preferred_username",
|
||||
"groups",
|
||||
"sub",
|
||||
"iss",
|
||||
"aud",
|
||||
"exp",
|
||||
"iat",
|
||||
"auth_time",
|
||||
"nonce",
|
||||
"acr",
|
||||
"amr",
|
||||
"azp",
|
||||
"nbf",
|
||||
"jti":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// idType is the type of the id used to identify the user or user group
|
||||
@@ -52,28 +53,38 @@ const (
|
||||
)
|
||||
|
||||
// UpdateCustomClaimsForUser updates the custom claims for a user
|
||||
func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
return s.updateCustomClaims(UserID, userID, claims)
|
||||
func (s *CustomClaimService) UpdateCustomClaimsForUser(ctx context.Context, userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
return s.updateCustomClaims(ctx, UserID, userID, claims)
|
||||
}
|
||||
|
||||
// UpdateCustomClaimsForUserGroup updates the custom claims for a user group
|
||||
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
return s.updateCustomClaims(UserGroupID, userGroupID, claims)
|
||||
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(ctx context.Context, userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
return s.updateCustomClaims(ctx, UserGroupID, userGroupID, claims)
|
||||
}
|
||||
|
||||
// updateCustomClaims updates the custom claims for a user or user group
|
||||
func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
func (s *CustomClaimService) updateCustomClaims(ctx context.Context, idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
// Check for duplicate keys in the claims slice
|
||||
seenKeys := make(map[string]bool)
|
||||
seenKeys := make(map[string]struct{})
|
||||
for _, claim := range claims {
|
||||
if seenKeys[claim.Key] {
|
||||
if _, ok := seenKeys[claim.Key]; ok {
|
||||
return nil, &common.DuplicateClaimError{Key: claim.Key}
|
||||
}
|
||||
seenKeys[claim.Key] = true
|
||||
seenKeys[claim.Key] = struct{}{}
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var existingClaims []model.CustomClaim
|
||||
err := s.db.Where(string(idType), value).Find(&existingClaims).Error
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where(string(idType), value).
|
||||
Find(&existingClaims).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -87,8 +98,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
err = s.db.Delete(&existingClaim).Error
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&existingClaim).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -113,7 +128,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
|
||||
}
|
||||
|
||||
// Update the claim if it already exists or create a new one
|
||||
err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where(string(idType)+" = ? AND key = ?", value, claim.Key).
|
||||
Assign(&customClaim).
|
||||
FirstOrCreate(&model.CustomClaim{}).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -121,7 +141,16 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
|
||||
|
||||
// Get the updated claims
|
||||
var updatedClaims []model.CustomClaim
|
||||
err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where(string(idType)+" = ?", value).
|
||||
Find(&updatedClaims).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -129,23 +158,31 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
|
||||
return updatedClaims, nil
|
||||
}
|
||||
|
||||
func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.CustomClaim, error) {
|
||||
func (s *CustomClaimService) GetCustomClaimsForUser(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
|
||||
var customClaims []model.CustomClaim
|
||||
err := s.db.Where("user_id = ?", userID).Find(&customClaims).Error
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Find(&customClaims).
|
||||
Error
|
||||
return customClaims, err
|
||||
}
|
||||
|
||||
func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) {
|
||||
func (s *CustomClaimService) GetCustomClaimsForUserGroup(ctx context.Context, userGroupID string, tx *gorm.DB) ([]model.CustomClaim, error) {
|
||||
var customClaims []model.CustomClaim
|
||||
err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("user_group_id = ?", userGroupID).
|
||||
Find(&customClaims).
|
||||
Error
|
||||
return customClaims, err
|
||||
}
|
||||
|
||||
// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of,
|
||||
// prioritizing the user's claims over user group claims with the same key.
|
||||
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) {
|
||||
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
|
||||
// Get the custom claims of the user
|
||||
customClaims, err := s.GetCustomClaimsForUser(userID)
|
||||
customClaims, err := s.GetCustomClaimsForUser(ctx, userID, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -158,7 +195,9 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
|
||||
|
||||
// Get all user groups of the user
|
||||
var userGroupsOfUser []model.UserGroup
|
||||
err = s.db.Preload("CustomClaims").
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("CustomClaims").
|
||||
Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id").
|
||||
Where("user_groups_users.user_id = ?", userID).
|
||||
Find(&userGroupsOfUser).Error
|
||||
@@ -186,10 +225,12 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
|
||||
}
|
||||
|
||||
// GetSuggestions returns a list of custom claim keys that have been used before
|
||||
func (s *CustomClaimService) GetSuggestions() ([]string, error) {
|
||||
func (s *CustomClaimService) GetSuggestions(ctx context.Context) ([]string, error) {
|
||||
var customClaimsKeys []string
|
||||
|
||||
err := s.db.Model(&model.CustomClaim{}).
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Model(&model.CustomClaim{}).
|
||||
Group("key").
|
||||
Order("COUNT(*) DESC").
|
||||
Pluck("key", &customClaimsKeys).Error
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
@@ -34,6 +35,7 @@ func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService
|
||||
return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService}
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *TestService) SeedDatabase() error {
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
users := []model.User{
|
||||
@@ -187,11 +189,8 @@ func (s *TestService) SeedDatabase() error {
|
||||
// openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 | \
|
||||
// openssl pkcs8 -topk8 -nocrypt | tee >(openssl pkey -pubout)
|
||||
|
||||
publicKeyPasskey1, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
|
||||
publicKeyPasskey2, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicKeyPasskey1, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
|
||||
publicKeyPasskey2, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==")
|
||||
webauthnCredentials := []model.WebauthnCredential{
|
||||
{
|
||||
Name: "Passkey 1",
|
||||
@@ -303,7 +302,7 @@ func (s *TestService) ResetApplicationImages() error {
|
||||
|
||||
func (s *TestService) ResetAppConfig() error {
|
||||
// Reseed the config variables
|
||||
if err := s.appConfigService.InitDbConfig(); err != nil {
|
||||
if err := s.appConfigService.InitDbConfig(context.Background()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -320,7 +319,7 @@ func (s *TestService) SetJWTKeys() {
|
||||
const privateKeyString = `{"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}`
|
||||
|
||||
privateKey, _ := jwk.ParseKey([]byte(privateKeyString))
|
||||
s.jwtService.SetKey(privateKey)
|
||||
_ = s.jwtService.SetKey(privateKey)
|
||||
}
|
||||
|
||||
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key
|
||||
|
||||
@@ -2,10 +2,12 @@ package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
htemplate "html/template"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"mime/quotedprintable"
|
||||
"net/textproto"
|
||||
@@ -17,10 +19,11 @@ import (
|
||||
"github.com/emersion/go-sasl"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type EmailService struct {
|
||||
@@ -49,20 +52,24 @@ func NewEmailService(appConfigService *AppConfigService, db *gorm.DB) (*EmailSer
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (srv *EmailService) SendTestEmail(recipientUserId string) error {
|
||||
func (srv *EmailService) SendTestEmail(ctx context.Context, recipientUserId string) error {
|
||||
var user model.User
|
||||
if err := srv.db.First(&user, "id = ?", recipientUserId).Error; err != nil {
|
||||
err := srv.db.
|
||||
WithContext(ctx).
|
||||
First(&user, "id = ?", recipientUserId).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return SendEmail(srv,
|
||||
return SendEmail(ctx, srv,
|
||||
email.Address{
|
||||
Email: user.Email,
|
||||
Name: user.FullName(),
|
||||
}, TestTemplate, nil)
|
||||
}
|
||||
|
||||
func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
|
||||
func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
|
||||
data := &email.TemplateData[V]{
|
||||
AppName: srv.appConfigService.DbConfig.AppName.Value,
|
||||
LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo",
|
||||
@@ -112,6 +119,15 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
|
||||
|
||||
c.Body(body)
|
||||
|
||||
// Check if the context is still valid before attemtping to connect
|
||||
// We need to do this because the smtp library doesn't have context support
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
// All good
|
||||
}
|
||||
|
||||
// Connect to the SMTP server
|
||||
client, err := srv.getSmtpClient()
|
||||
if err != nil {
|
||||
@@ -119,6 +135,14 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Check if the context is still valid before sending the email
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
// All good
|
||||
}
|
||||
|
||||
// Send the email
|
||||
if err := srv.sendEmailContent(client, toEmail, c); err != nil {
|
||||
return fmt.Errorf("send email content: %w", err)
|
||||
@@ -215,7 +239,7 @@ func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Add
|
||||
}
|
||||
|
||||
// Write the email content
|
||||
_, err = w.Write([]byte(c.String()))
|
||||
_, err = io.Copy(w, strings.NewReader(c.String()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write email data: %w", err)
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ var tailscaleIPNets = []*net.IPNet{
|
||||
}
|
||||
|
||||
// NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database.
|
||||
func NewGeoLiteService() *GeoLiteService {
|
||||
func NewGeoLiteService(ctx context.Context) *GeoLiteService {
|
||||
service := &GeoLiteService{}
|
||||
|
||||
if common.EnvConfig.MaxMindLicenseKey == "" && common.EnvConfig.GeoLiteDBUrl == common.MaxMindGeoLiteCityUrl {
|
||||
@@ -52,8 +52,9 @@ func NewGeoLiteService() *GeoLiteService {
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := service.updateDatabase(); err != nil {
|
||||
log.Printf("Failed to update GeoLite2 City database: %v\n", err)
|
||||
err := service.updateDatabase(ctx)
|
||||
if err != nil {
|
||||
log.Printf("Failed to update GeoLite2 City database: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -111,7 +112,7 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
|
||||
}
|
||||
|
||||
// UpdateDatabase checks the age of the database and updates it if it's older than 14 days.
|
||||
func (s *GeoLiteService) updateDatabase() error {
|
||||
func (s *GeoLiteService) updateDatabase(parentCtx context.Context) error {
|
||||
if s.disableUpdater {
|
||||
// Avoid updating the GeoLite2 City database.
|
||||
return nil
|
||||
@@ -125,7 +126,7 @@ func (s *GeoLiteService) updateDatabase() error {
|
||||
log.Println("Updating GeoLite2 City database...")
|
||||
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)
|
||||
|
||||
@@ -38,7 +38,9 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
// Setup LDAP connection
|
||||
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
|
||||
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.IsTrue()
|
||||
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
|
||||
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{
|
||||
InsecureSkipVerify: skipTLSVerify, //nolint:gosec
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
|
||||
}
|
||||
@@ -53,22 +55,31 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) SyncAll() error {
|
||||
err := s.SyncUsers()
|
||||
func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||
// Start a transaction
|
||||
tx := s.db.Begin()
|
||||
|
||||
err := s.SyncUsers(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users: %w", err)
|
||||
}
|
||||
|
||||
err = s.SyncGroups()
|
||||
err = s.SyncGroups(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync groups: %w", err)
|
||||
}
|
||||
|
||||
// Commit the changes
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit changes to database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncGroups() error {
|
||||
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
if err != nil {
|
||||
@@ -112,7 +123,7 @@ func (s *LdapService) SyncGroups() error {
|
||||
|
||||
// Try to find the group in the database
|
||||
var databaseGroup model.UserGroup
|
||||
s.db.Where("ldap_id = ?", ldapId).First(&databaseGroup)
|
||||
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup)
|
||||
|
||||
// Get group members and add to the correct Group
|
||||
groupMembers := value.GetAttributeValues(groupMemberOfAttribute)
|
||||
@@ -122,7 +133,7 @@ func (s *LdapService) SyncGroups() error {
|
||||
singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0]
|
||||
|
||||
var databaseUser model.User
|
||||
err := s.db.Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
|
||||
err := tx.WithContext(ctx).Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// The user collides with a non-LDAP user, so we skip it
|
||||
@@ -143,39 +154,51 @@ func (s *LdapService) SyncGroups() error {
|
||||
}
|
||||
|
||||
if databaseGroup.ID == "" {
|
||||
newGroup, err := s.groupService.Create(syncGroup)
|
||||
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
} else {
|
||||
if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
}
|
||||
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
return err
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get all LDAP groups from the database
|
||||
var ldapGroupsInDb []model.UserGroup
|
||||
if err := s.db.Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
|
||||
fmt.Println(fmt.Errorf("failed to fetch groups from database: %w", err))
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").
|
||||
Select("ldap_id").
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to fetch groups from database: %v", err)
|
||||
}
|
||||
|
||||
// Delete groups that no longer exist in LDAP
|
||||
for _, group := range ldapGroupsInDb {
|
||||
if _, exists := ldapGroupIDs[*group.LdapID]; !exists {
|
||||
if err := s.db.Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to delete group %s with: %v", group.Name, err)
|
||||
} else {
|
||||
log.Printf("Deleted group %s", group.Name)
|
||||
@@ -187,7 +210,7 @@ func (s *LdapService) SyncGroups() error {
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncUsers() error {
|
||||
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
if err != nil {
|
||||
@@ -241,7 +264,7 @@ func (s *LdapService) SyncUsers() error {
|
||||
|
||||
// Get the user from the database
|
||||
var databaseUser model.User
|
||||
s.db.Where("ldap_id = ?", ldapId).First(&databaseUser)
|
||||
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseUser)
|
||||
|
||||
// Check if user is admin by checking if they are in the admin group
|
||||
isAdmin := false
|
||||
@@ -261,68 +284,75 @@ func (s *LdapService) SyncUsers() error {
|
||||
}
|
||||
|
||||
if databaseUser.ID == "" {
|
||||
_, err = s.userService.CreateUser(newUser)
|
||||
_, err = s.userService.createUserInternal(ctx, newUser, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing user %s: %s", newUser.Username, err)
|
||||
log.Printf("Error syncing user %s: %v", newUser.Username, err)
|
||||
}
|
||||
} else {
|
||||
_, err = s.userService.UpdateUser(databaseUser.ID, newUser, false, true)
|
||||
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing user %s: %s", newUser.Username, err)
|
||||
log.Printf("Error syncing user %s: %v", newUser.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save profile picture
|
||||
if pictureString := value.GetAttributeValue(profilePictureAttribute); pictureString != "" {
|
||||
if err := s.SaveProfilePicture(databaseUser.ID, pictureString); err != nil {
|
||||
log.Printf("Error saving profile picture for user %s: %s", newUser.Username, err)
|
||||
if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil {
|
||||
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get all LDAP users from the database
|
||||
var ldapUsersInDb []model.User
|
||||
if err := s.db.Find(&ldapUsersInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
|
||||
fmt.Println(fmt.Errorf("failed to fetch users from database: %w", err))
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
|
||||
Select("ldap_id").
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to fetch users from database: %v", err)
|
||||
}
|
||||
|
||||
// Delete users that no longer exist in LDAP
|
||||
for _, user := range ldapUsersInDb {
|
||||
if _, exists := ldapUserIDs[*user.LdapID]; !exists {
|
||||
if err := s.userService.DeleteUser(user.ID, true); err != nil {
|
||||
if err := s.userService.deleteUserInternal(ctx, user.ID, true, tx); err != nil {
|
||||
log.Printf("Failed to delete user %s with: %v", user.Username, err)
|
||||
} else {
|
||||
log.Printf("Deleted user %s", user.Username)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LdapService) SaveProfilePicture(userId string, pictureString string) error {
|
||||
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
|
||||
var reader io.Reader
|
||||
|
||||
if _, err := url.ParseRequestURI(pictureString); err == nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_, err := url.ParseRequestURI(pictureString)
|
||||
if err == nil {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
response, err := http.DefaultClient.Do(req)
|
||||
var res *http.Response
|
||||
res, err = http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download profile picture: %w", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
reader = response.Body
|
||||
defer res.Body.Close()
|
||||
|
||||
reader = res.Body
|
||||
} else if decodedPhoto, err := base64.StdEncoding.DecodeString(pictureString); err == nil {
|
||||
// If the photo is a base64 encoded string, decode it
|
||||
reader = bytes.NewReader(decodedPhoto)
|
||||
|
||||
} else {
|
||||
// If the photo is a string, we assume that it's a binary string
|
||||
reader = bytes.NewReader([]byte(pictureString))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -39,9 +41,20 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppCo
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
|
||||
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
if err := s.db.Preload("AllowedUserGroups").First(&client, "id = ?", input.ClientID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("AllowedUserGroups").
|
||||
First(&client, "id = ?", input.ClientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
@@ -58,7 +71,12 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
|
||||
|
||||
// Check if the user group is allowed to authorize the client
|
||||
var user model.User
|
||||
if err := s.db.Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("UserGroups").
|
||||
First(&user, "id = ?", userID).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
@@ -67,7 +85,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
|
||||
}
|
||||
|
||||
// Check if the user has already authorized the client with the given scope
|
||||
hasAuthorizedClient, err := s.HasAuthorizedClient(input.ClientID, userID, input.Scope)
|
||||
hasAuthorizedClient, err := s.hasAuthorizedClientInternal(ctx, input.ClientID, userID, input.Scope, tx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
@@ -80,39 +98,55 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
|
||||
Scope: input.Scope,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&userAuthorizedClient).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&userAuthorizedClient).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
// The client has already been authorized but with a different scope so we need to update the scope
|
||||
if err := s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil {
|
||||
if err := tx.
|
||||
WithContext(ctx).
|
||||
Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
} else {
|
||||
} else if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create the authorization code
|
||||
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
|
||||
code, err := s.createAuthorizationCode(ctx, input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod, tx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Log the authorization event
|
||||
if hasAuthorizedClient {
|
||||
s.auditLogService.Create(model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
|
||||
s.auditLogService.Create(ctx, model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}, tx)
|
||||
} else {
|
||||
s.auditLogService.Create(model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
|
||||
s.auditLogService.Create(ctx, model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}, tx)
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return code, callbackURL, nil
|
||||
}
|
||||
|
||||
// HasAuthorizedClient checks if the user has already authorized the client with the given scope
|
||||
func (s *OidcService) HasAuthorizedClient(clientID, userID, scope string) (bool, error) {
|
||||
func (s *OidcService) HasAuthorizedClient(ctx context.Context, clientID, userID, scope string) (bool, error) {
|
||||
return s.hasAuthorizedClientInternal(ctx, clientID, userID, scope, s.db)
|
||||
}
|
||||
|
||||
func (s *OidcService) hasAuthorizedClientInternal(ctx context.Context, clientID, userID, scope string, tx *gorm.DB) (bool, error) {
|
||||
var userAuthorizedOidcClient model.UserAuthorizedOidcClient
|
||||
if err := s.db.First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
@@ -145,21 +179,31 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
|
||||
return isAllowedToAuthorize
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
|
||||
func (s *OidcService) CreateTokens(ctx context.Context, code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
|
||||
switch grantType {
|
||||
case "authorization_code":
|
||||
return s.createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier)
|
||||
return s.createTokenFromAuthorizationCode(ctx, code, clientID, clientSecret, codeVerifier)
|
||||
case "refresh_token":
|
||||
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(refreshToken, clientID, clientSecret)
|
||||
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, refreshToken, clientID, clientSecret)
|
||||
return "", accessToken, newRefreshToken, exp, err
|
||||
default:
|
||||
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
|
||||
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
@@ -176,7 +220,11 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
|
||||
}
|
||||
|
||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("User").
|
||||
First(&authorizationCodeMetaData, "code = ?", code).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
||||
}
|
||||
@@ -192,7 +240,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
|
||||
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
||||
}
|
||||
|
||||
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
|
||||
userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, clientID, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
@@ -203,7 +251,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
|
||||
}
|
||||
|
||||
// Generate a refresh token
|
||||
refreshToken, err = s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope)
|
||||
refreshToken, err = s.createRefreshToken(ctx, clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
@@ -213,19 +261,40 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
s.db.Delete(&authorizationCodeMetaData)
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&authorizationCodeMetaData).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
return idToken, accessToken, refreshToken, 3600, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
|
||||
func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
|
||||
if refreshToken == "" {
|
||||
return "", "", 0, &common.OidcMissingRefreshTokenError{}
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
// Get the client to check if it's public
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
||||
@@ -243,7 +312,9 @@ func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, client
|
||||
|
||||
// Verify refresh token
|
||||
var storedRefreshToken model.OidcRefreshToken
|
||||
err = s.db.Preload("User").
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("User").
|
||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
|
||||
First(&storedRefreshToken).
|
||||
Error
|
||||
@@ -266,29 +337,53 @@ func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, client
|
||||
}
|
||||
|
||||
// Generate a new refresh token and invalidate the old one
|
||||
newRefreshToken, err = s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope)
|
||||
newRefreshToken, err = s.createRefreshToken(ctx, clientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
||||
// Delete the used refresh token
|
||||
s.db.Delete(&storedRefreshToken)
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&storedRefreshToken).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
||||
return accessToken, newRefreshToken, 3600, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
|
||||
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
|
||||
return s.getClientInternal(ctx, clientID, s.db)
|
||||
}
|
||||
|
||||
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB) (model.OidcClient, error) {
|
||||
var client model.OidcClient
|
||||
if err := s.db.Preload("CreatedBy").Preload("AllowedUserGroups").First(&client, "id = ?", clientID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("CreatedBy").
|
||||
Preload("AllowedUserGroups").
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.OidcClient, utils.PaginationResponse, error) {
|
||||
func (s *OidcService) ListClients(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.OidcClient, utils.PaginationResponse, error) {
|
||||
var clients []model.OidcClient
|
||||
|
||||
query := s.db.Preload("CreatedBy").Model(&model.OidcClient{})
|
||||
query := s.db.
|
||||
WithContext(ctx).
|
||||
Preload("CreatedBy").
|
||||
Model(&model.OidcClient{})
|
||||
if searchTerm != "" {
|
||||
searchPattern := "%" + searchTerm + "%"
|
||||
query = query.Where("name LIKE ?", searchPattern)
|
||||
@@ -302,7 +397,7 @@ func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest uti
|
||||
return clients, pagination, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||
client := model.OidcClient{
|
||||
Name: input.Name,
|
||||
CallbackURLs: input.CallbackURLs,
|
||||
@@ -312,16 +407,31 @@ func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string)
|
||||
PkceEnabled: input.IsPublic || input.PkceEnabled,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&client).Error; err != nil {
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Create(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) {
|
||||
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("CreatedBy").
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
@@ -331,29 +441,49 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt
|
||||
client.IsPublic = input.IsPublic
|
||||
client.PkceEnabled = input.IsPublic || input.PkceEnabled
|
||||
|
||||
if err := s.db.Save(&client).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClient(clientID string) error {
|
||||
func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.db.Delete(&client).Error; err != nil {
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Where("id = ?", clientID).
|
||||
Delete(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateClientSecret(clientID string) (string, error) {
|
||||
func (s *OidcService) CreateClientSecret(ctx context.Context, clientID string) (string, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -368,16 +498,29 @@ func (s *OidcService) CreateClientSecret(clientID string) (string, error) {
|
||||
}
|
||||
|
||||
client.Secret = string(hashedSecret)
|
||||
if err := s.db.Save(&client).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return clientSecret, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
|
||||
func (s *OidcService) GetClientLogo(ctx context.Context, clientID string) (string, string, error) {
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
@@ -385,26 +528,36 @@ func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
|
||||
return "", "", errors.New("image not found")
|
||||
}
|
||||
|
||||
imageType := *client.ImageType
|
||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType)
|
||||
mimeType := utils.GetImageMimeType(imageType)
|
||||
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType
|
||||
mimeType := utils.GetImageMimeType(*client.ImageType)
|
||||
|
||||
return imagePath, mimeType, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error {
|
||||
func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, file *multipart.FileHeader) error {
|
||||
fileType := utils.GetFileExtension(file.Filename)
|
||||
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
|
||||
return &common.FileTypeNotSupportedError{}
|
||||
}
|
||||
|
||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType)
|
||||
if err := utils.SaveFile(file, imagePath); err != nil {
|
||||
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + "." + fileType
|
||||
err := utils.SaveFile(file, imagePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -416,16 +569,35 @@ func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHead
|
||||
}
|
||||
|
||||
client.ImageType = &fileType
|
||||
if err := s.db.Save(&client).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClientLogo(clientID string) error {
|
||||
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -433,38 +605,72 @@ func (s *OidcService) DeleteClientLogo(clientID string) error {
|
||||
return errors.New("image not found")
|
||||
}
|
||||
|
||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
|
||||
client.ImageType = nil
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType
|
||||
if err := os.Remove(imagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.ImageType = nil
|
||||
if err := s.db.Save(&client).Error; err != nil {
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (map[string]interface{}, error) {
|
||||
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
|
||||
var authorizedOidcClient model.UserAuthorizedOidcClient
|
||||
if err := s.db.Preload("User.UserGroups").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("User.UserGroups").
|
||||
First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := authorizedOidcClient.User
|
||||
scope := authorizedOidcClient.Scope
|
||||
scopes := strings.Split(authorizedOidcClient.Scope, " ")
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": user.ID,
|
||||
}
|
||||
|
||||
if strings.Contains(scope, "email") {
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
claims["email_verified"] = s.appConfigService.DbConfig.EmailsVerified.IsTrue()
|
||||
}
|
||||
|
||||
if strings.Contains(scope, "groups") {
|
||||
if slices.Contains(scopes, "groups") {
|
||||
userGroups := make([]string, len(user.UserGroups))
|
||||
for i, group := range user.UserGroups {
|
||||
userGroups[i] = group.Name
|
||||
@@ -477,17 +683,17 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
|
||||
"family_name": user.LastName,
|
||||
"name": user.FullName(),
|
||||
"preferred_username": user.Username,
|
||||
"picture": fmt.Sprintf("%s/api/users/%s/profile-picture.png", common.EnvConfig.AppURL, user.ID),
|
||||
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
|
||||
}
|
||||
|
||||
if strings.Contains(scope, "profile") {
|
||||
if slices.Contains(scopes, "profile") {
|
||||
// Add profile claims
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
// Add custom claims
|
||||
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(userID)
|
||||
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, userID, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -505,15 +711,22 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.Contains(scope, "email") {
|
||||
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
|
||||
client, err = s.GetClient(id)
|
||||
func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err = s.getClientInternal(ctx, id, tx)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
@@ -521,18 +734,37 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll
|
||||
// Fetch the user groups based on UserGroupIDs in input
|
||||
var groups []model.UserGroup
|
||||
if len(input.UserGroupIDs) > 0 {
|
||||
if err := s.db.Where("id IN (?)", input.UserGroupIDs).Find(&groups).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("id IN (?)", input.UserGroupIDs).
|
||||
Find(&groups).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the current user groups with the new set of user groups
|
||||
if err := s.db.Model(&client).Association("AllowedUserGroups").Replace(groups); err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Model(&client).
|
||||
Association("AllowedUserGroups").
|
||||
Replace(groups)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
// Save the updated client
|
||||
if err := s.db.Save(&client).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
@@ -540,7 +772,7 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll
|
||||
}
|
||||
|
||||
// ValidateEndSession returns the logout callback URL for the client if all the validations pass
|
||||
func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string) (string, error) {
|
||||
func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogoutDto, userID string) (string, error) {
|
||||
// If no ID token hint is provided, return an error
|
||||
if input.IdTokenHint == "" {
|
||||
return "", &common.TokenInvalidError{}
|
||||
@@ -564,7 +796,12 @@ func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string)
|
||||
|
||||
// Check if the user has authorized the client before
|
||||
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
||||
if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).Error; err != nil {
|
||||
err = s.db.
|
||||
WithContext(ctx).
|
||||
Preload("Client").
|
||||
First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", &common.OidcMissingAuthorizationError{}
|
||||
}
|
||||
|
||||
@@ -582,7 +819,7 @@ func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string)
|
||||
|
||||
}
|
||||
|
||||
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
|
||||
func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string, tx *gorm.DB) (string, error) {
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -601,7 +838,11 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc
|
||||
CodeChallengeMethodSha256: &codeChallengeMethodSha256,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&oidcAuthorizationCode).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -647,7 +888,7 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
|
||||
return "", &common.OidcInvalidCallbackURLError{}
|
||||
}
|
||||
|
||||
func (s *OidcService) createRefreshToken(clientID string, userID string, scope string) (string, error) {
|
||||
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, tx *gorm.DB) (string, error) {
|
||||
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -665,7 +906,11 @@ func (s *OidcService) createRefreshToken(clientID string, userID string, scope s
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&m).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&m).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserGroupService struct {
|
||||
@@ -19,8 +21,11 @@ func NewUserGroupService(db *gorm.DB, appConfigService *AppConfigService) *UserG
|
||||
return &UserGroupService{db: db, appConfigService: appConfigService}
|
||||
}
|
||||
|
||||
func (s *UserGroupService) List(name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
|
||||
query := s.db.Preload("CustomClaims").Model(&model.UserGroup{})
|
||||
func (s *UserGroupService) List(ctx context.Context, name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
|
||||
query := s.db.
|
||||
WithContext(ctx).
|
||||
Preload("CustomClaims").
|
||||
Model(&model.UserGroup{})
|
||||
|
||||
if name != "" {
|
||||
query = query.Where("name LIKE ?", "%"+name+"%")
|
||||
@@ -42,26 +47,59 @@ func (s *UserGroupService) List(name string, sortedPaginationRequest utils.Sorte
|
||||
return groups, response, err
|
||||
}
|
||||
|
||||
func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) {
|
||||
err = s.db.Where("id = ?", id).Preload("CustomClaims").Preload("Users").First(&group).Error
|
||||
func (s *UserGroupService) Get(ctx context.Context, id string) (group model.UserGroup, err error) {
|
||||
return s.getInternal(ctx, id, s.db)
|
||||
}
|
||||
|
||||
func (s *UserGroupService) getInternal(ctx context.Context, id string, tx *gorm.DB) (group model.UserGroup, err error) {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("id = ?", id).
|
||||
Preload("CustomClaims").
|
||||
Preload("Users").
|
||||
First(&group).
|
||||
Error
|
||||
return group, err
|
||||
}
|
||||
|
||||
func (s *UserGroupService) Delete(id string) error {
|
||||
func (s *UserGroupService) Delete(ctx context.Context, id string) error {
|
||||
tx := s.db.Begin()
|
||||
|
||||
var group model.UserGroup
|
||||
if err := s.db.Where("id = ?", id).First(&group).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("id = ?", id).
|
||||
First(&group).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Disallow deleting the group if it is an LDAP group and LDAP is enabled
|
||||
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
|
||||
err = tx.Rollback().Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return &common.LdapUserGroupUpdateError{}
|
||||
}
|
||||
|
||||
return s.db.Delete(&group).Error
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&group).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
func (s *UserGroupService) Create(ctx context.Context, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
|
||||
return s.createInternal(ctx, input, s.db)
|
||||
}
|
||||
|
||||
func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGroupCreateDto, tx *gorm.DB) (group model.UserGroup, err error) {
|
||||
group = model.UserGroup{
|
||||
FriendlyName: input.FriendlyName,
|
||||
Name: input.Name,
|
||||
@@ -71,7 +109,12 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
|
||||
group.LdapID = &input.LdapID
|
||||
}
|
||||
|
||||
if err := s.db.Preload("Users").Create(&group).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("Users").
|
||||
Create(&group).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
|
||||
}
|
||||
@@ -80,8 +123,26 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) {
|
||||
group, err = s.Get(id)
|
||||
func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) {
|
||||
tx := s.db.Begin()
|
||||
|
||||
group, err = s.updateInternal(ctx, id, input, allowLdapUpdate, tx)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool, tx *gorm.DB) (group model.UserGroup, err error) {
|
||||
group, err = s.getInternal(ctx, id, tx)
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
@@ -94,7 +155,12 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
|
||||
group.Name = input.Name
|
||||
group.FriendlyName = input.FriendlyName
|
||||
|
||||
if err := s.db.Preload("Users").Save(&group).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("Users").
|
||||
Save(&group).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
|
||||
}
|
||||
@@ -103,8 +169,26 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model.UserGroup, err error) {
|
||||
group, err = s.Get(id)
|
||||
func (s *UserGroupService) UpdateUsers(ctx context.Context, id string, userIds []string) (group model.UserGroup, err error) {
|
||||
tx := s.db.Begin()
|
||||
|
||||
group, err = s.updateUsersInternal(ctx, id, userIds, tx)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, userIds []string, tx *gorm.DB) (group model.UserGroup, err error) {
|
||||
group, err = s.getInternal(ctx, id, tx)
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
@@ -112,28 +196,59 @@ func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model
|
||||
// Fetch the users based on the userIds
|
||||
var users []model.User
|
||||
if len(userIds) > 0 {
|
||||
if err := s.db.Where("id IN (?)", userIds).Find(&users).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("id IN (?)", userIds).
|
||||
Find(&users).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the current users with the new set of users
|
||||
if err := s.db.Model(&group).Association("Users").Replace(users); err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Model(&group).
|
||||
Association("Users").
|
||||
Replace(users)
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
// Save the updated group
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&group).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *UserGroupService) GetUserCountOfGroup(id string) (int64, error) {
|
||||
func (s *UserGroupService) GetUserCountOfGroup(ctx context.Context, id string) (int64, error) {
|
||||
// We only perform select queries here, so we can rollback in all cases
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var group model.UserGroup
|
||||
if err := s.db.Preload("Users").Where("id = ?", id).First(&group).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("Users").
|
||||
Where("id = ?", id).
|
||||
First(&group).
|
||||
Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return s.db.Model(&group).Association("Users").Count(), nil
|
||||
count := tx.
|
||||
WithContext(ctx).
|
||||
Model(&group).
|
||||
Association("Users").
|
||||
Count()
|
||||
return count, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
@@ -20,7 +21,7 @@ import (
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
|
||||
"gorm.io/gorm"
|
||||
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
@@ -35,9 +36,9 @@ func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditL
|
||||
return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService, appConfigService: appConfigService}
|
||||
}
|
||||
|
||||
func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
|
||||
func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
|
||||
var users []model.User
|
||||
query := s.db.Model(&model.User{})
|
||||
query := s.db.WithContext(ctx).Model(&model.User{})
|
||||
|
||||
if searchTerm != "" {
|
||||
searchPattern := "%" + searchTerm + "%"
|
||||
@@ -48,13 +49,23 @@ func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils
|
||||
return users, pagination, err
|
||||
}
|
||||
|
||||
func (s *UserService) GetUser(userID string) (model.User, error) {
|
||||
func (s *UserService) GetUser(ctx context.Context, userID string) (model.User, error) {
|
||||
return s.getUserInternal(ctx, userID, s.db)
|
||||
}
|
||||
|
||||
func (s *UserService) getUserInternal(ctx context.Context, userID string, tx *gorm.DB) (model.User, error) {
|
||||
var user model.User
|
||||
err := s.db.Preload("UserGroups").Preload("CustomClaims").Where("id = ?", userID).First(&user).Error
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("UserGroups").
|
||||
Preload("CustomClaims").
|
||||
Where("id = ?", userID).
|
||||
First(&user).
|
||||
Error
|
||||
return user, err
|
||||
}
|
||||
|
||||
func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, error) {
|
||||
func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.ReadCloser, int64, error) {
|
||||
// Validate the user ID to prevent directory traversal
|
||||
if err := uuid.Validate(userID); err != nil {
|
||||
return nil, 0, &common.InvalidUUIDError{}
|
||||
@@ -74,7 +85,7 @@ func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, er
|
||||
}
|
||||
|
||||
// If no custom picture exists, get the user's data for creating initials
|
||||
user, err := s.GetUser(userID)
|
||||
user, err := s.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -115,9 +126,15 @@ func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, er
|
||||
return io.NopCloser(bytes.NewReader(defaultPictureBytes)), int64(defaultPicture.Len()), nil
|
||||
}
|
||||
|
||||
func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) {
|
||||
func (s *UserService) GetUserGroups(ctx context.Context, userID string) ([]model.UserGroup, error) {
|
||||
var user model.User
|
||||
if err := s.db.Preload("UserGroups").Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Preload("UserGroups").
|
||||
Where("id = ?", userID).
|
||||
First(&user).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user.UserGroups, nil
|
||||
@@ -152,9 +169,21 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserService) DeleteUser(userID string, allowLdapDelete bool) error {
|
||||
func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
return s.deleteUserInternal(ctx, userID, allowLdapDelete, tx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error {
|
||||
var user model.User
|
||||
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("id = ?", userID).
|
||||
First(&user).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -165,14 +194,35 @@ func (s *UserService) DeleteUser(userID string, allowLdapDelete bool) error {
|
||||
|
||||
// Delete the profile picture
|
||||
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
|
||||
if err := os.Remove(profilePicturePath); err != nil && !os.IsNotExist(err) {
|
||||
err = os.Remove(profilePicturePath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.db.Delete(&user).Error
|
||||
return tx.WithContext(ctx).Delete(&user).Error
|
||||
}
|
||||
|
||||
func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
|
||||
func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (model.User, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
user, err := s.createUserInternal(ctx, input, tx)
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, tx *gorm.DB) (model.User, error) {
|
||||
user := model.User{
|
||||
FirstName: input.FirstName,
|
||||
LastName: input.LastName,
|
||||
@@ -185,18 +235,47 @@ func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
|
||||
user.LdapID = &input.LdapID
|
||||
}
|
||||
|
||||
if err := s.db.Create(&user).Error; err != nil {
|
||||
err := tx.WithContext(ctx).Create(&user).Error
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return model.User{}, s.checkDuplicatedFields(user)
|
||||
}
|
||||
tx.Rollback()
|
||||
|
||||
// If we are here, the transaction is already aborted due to an error, so we pass s.db
|
||||
err = s.checkDuplicatedFields(ctx, user, s.db)
|
||||
return model.User{}, err
|
||||
} else if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) {
|
||||
func (s *UserService) UpdateUser(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
user, err := s.updateUserInternal(ctx, userID, updatedUser, updateOwnUser, allowLdapUpdate, tx)
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool, tx *gorm.DB) (model.User, error) {
|
||||
var user model.User
|
||||
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("id = ?", userID).
|
||||
First(&user).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
@@ -214,24 +293,42 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
|
||||
user.IsAdmin = updatedUser.IsAdmin
|
||||
}
|
||||
|
||||
if err := s.db.Save(&user).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&user).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return user, s.checkDuplicatedFields(user)
|
||||
}
|
||||
tx.Rollback()
|
||||
|
||||
// If we are here, the transaction is already aborted due to an error, so we pass s.db
|
||||
err = s.checkDuplicatedFields(ctx, user, s.db)
|
||||
return user, err
|
||||
} else if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error {
|
||||
func (s *UserService) RequestOneTimeAccessEmail(ctx context.Context, emailAddress, redirectPath string) error {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
isDisabled := !s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.IsTrue()
|
||||
if isDisabled {
|
||||
return &common.OneTimeAccessDisabledError{}
|
||||
}
|
||||
|
||||
var user model.User
|
||||
if err := s.db.Where("email = ?", emailAddress).First(&user).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("email = ?", emailAddress).
|
||||
First(&user).
|
||||
Error
|
||||
if err != nil {
|
||||
// Do not return error if user not found to prevent email enumeration
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil
|
||||
@@ -240,22 +337,31 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin
|
||||
}
|
||||
}
|
||||
|
||||
oneTimeAccessToken, err := s.CreateOneTimeAccessToken(user.ID, time.Now().Add(15*time.Minute))
|
||||
oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, time.Now().Add(15*time.Minute), tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
link := fmt.Sprintf("%s/lc", common.EnvConfig.AppURL)
|
||||
linkWithCode := fmt.Sprintf("%s/%s", link, oneTimeAccessToken)
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We use a background context here as this is running in a goroutine
|
||||
//nolint:contextcheck
|
||||
go func() {
|
||||
innerCtx := context.Background()
|
||||
|
||||
link := common.EnvConfig.AppURL + "/lc"
|
||||
linkWithCode := link + "/" + oneTimeAccessToken
|
||||
|
||||
// Add redirect path to the link
|
||||
if strings.HasPrefix(redirectPath, "/") {
|
||||
encodedRedirectPath := url.QueryEscape(redirectPath)
|
||||
linkWithCode = fmt.Sprintf("%s?redirect=%s", linkWithCode, encodedRedirectPath)
|
||||
linkWithCode = linkWithCode + "?redirect=" + encodedRedirectPath
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := SendEmail(s.emailService, email.Address{
|
||||
errInternal := SendEmail(innerCtx, s.emailService, email.Address{
|
||||
Name: user.Username,
|
||||
Email: user.Email,
|
||||
}, OneTimeAccessTemplate, &OneTimeAccessTemplateData{
|
||||
@@ -263,18 +369,21 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin
|
||||
LoginLink: link,
|
||||
LoginLinkWithCode: linkWithCode,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to send email to '%s': %v\n", user.Email, err)
|
||||
if errInternal != nil {
|
||||
log.Printf("Failed to send email to '%s': %v\n", user.Email, errInternal)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) {
|
||||
tokenLength := 16
|
||||
func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, expiresAt time.Time) (string, error) {
|
||||
return s.createOneTimeAccessTokenInternal(ctx, userID, expiresAt, s.db)
|
||||
}
|
||||
|
||||
func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, expiresAt time.Time, tx *gorm.DB) (string, error) {
|
||||
// If expires at is less than 15 minutes, use an 6 character token instead of 16
|
||||
tokenLength := 16
|
||||
if time.Until(expiresAt) <= 15*time.Minute {
|
||||
tokenLength = 6
|
||||
}
|
||||
@@ -290,16 +399,27 @@ func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Tim
|
||||
Token: randomString,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&oneTimeAccessToken).Error; err != nil {
|
||||
if err := tx.WithContext(ctx).Create(&oneTimeAccessToken).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return oneTimeAccessToken.Token, nil
|
||||
}
|
||||
|
||||
func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAgent string) (model.User, string, error) {
|
||||
func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token string, ipAddress, userAgent string) (model.User, string, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var oneTimeAccessToken model.OneTimeAccessToken
|
||||
if err := s.db.Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").
|
||||
First(&oneTimeAccessToken).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
|
||||
}
|
||||
@@ -310,19 +430,34 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAg
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
if err := s.db.Delete(&oneTimeAccessToken).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&oneTimeAccessToken).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
if ipAddress != "" && userAgent != "" {
|
||||
s.auditLogService.Create(model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{})
|
||||
s.auditLogService.Create(ctx, model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}, tx)
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
return oneTimeAccessToken.User, accessToken, nil
|
||||
}
|
||||
|
||||
func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user model.User, err error) {
|
||||
user, err = s.GetUser(id)
|
||||
func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroupIds []string) (user model.User, err error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
user, err = s.getUserInternal(ctx, id, tx)
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
@@ -330,27 +465,49 @@ func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user m
|
||||
// Fetch the groups based on userGroupIds
|
||||
var groups []model.UserGroup
|
||||
if len(userGroupIds) > 0 {
|
||||
if err := s.db.Where("id IN (?)", userGroupIds).Find(&groups).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("id IN (?)", userGroupIds).
|
||||
Find(&groups).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the current groups with the new set of groups
|
||||
if err := s.db.Model(&user).Association("UserGroups").Replace(groups); err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Model(&user).
|
||||
Association("UserGroups").
|
||||
Replace(groups)
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
// Save the updated user
|
||||
if err := s.db.Save(&user).Error; err != nil {
|
||||
err = tx.WithContext(ctx).Save(&user).Error
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
|
||||
func (s *UserService) SetupInitialAdmin(ctx context.Context) (model.User, string, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var userCount int64
|
||||
if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil {
|
||||
if err := tx.WithContext(ctx).Model(&model.User{}).Count(&userCount).Error; err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
if userCount > 1 {
|
||||
@@ -365,7 +522,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
|
||||
IsAdmin: true,
|
||||
}
|
||||
|
||||
if err := s.db.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
|
||||
if err := tx.WithContext(ctx).Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
@@ -378,16 +535,39 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
func (s *UserService) checkDuplicatedFields(user model.User) error {
|
||||
var existingUser model.User
|
||||
if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil {
|
||||
func (s *UserService) checkDuplicatedFields(ctx context.Context, user model.User, tx *gorm.DB) error {
|
||||
var result struct {
|
||||
Found bool
|
||||
}
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND email = ?) AS found`, user.ID, user.Email).
|
||||
First(&result).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result.Found {
|
||||
return &common.AlreadyInUseError{Property: "email"}
|
||||
}
|
||||
|
||||
if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND username = ?) AS found`, user.ID, user.Username).
|
||||
First(&result).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result.Found {
|
||||
return &common.AlreadyInUseError{Property: "username"}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type WebAuthnService struct {
|
||||
@@ -43,15 +46,31 @@ func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *Au
|
||||
return &WebAuthnService{db: db, webAuthn: wa, jwtService: jwtService, auditLogService: auditLogService, appConfigService: appConfigService}
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCredentialCreationOptions, error) {
|
||||
func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
s.updateWebAuthnConfig()
|
||||
|
||||
var user model.User
|
||||
if err := s.db.Preload("Credentials").Find(&user, "id = ?", userID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("Credentials").
|
||||
Find(&user, "id = ?", userID).
|
||||
Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
options, session, err := s.webAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()))
|
||||
options, session, err := s.webAuthn.BeginRegistration(
|
||||
&user,
|
||||
webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired),
|
||||
webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -62,7 +81,16 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
|
||||
UserVerification: string(session.UserVerification),
|
||||
}
|
||||
|
||||
if err := s.db.Create(&sessionToStore).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&sessionToStore).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -73,9 +101,19 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
|
||||
func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var storedSession model.WebauthnSession
|
||||
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&storedSession, "id = ?", sessionID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
@@ -86,7 +124,11 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
|
||||
}
|
||||
|
||||
var user model.User
|
||||
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&user, "id = ?", userID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
@@ -108,7 +150,16 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
|
||||
BackupEligible: credential.Flags.BackupEligible,
|
||||
BackupState: credential.Flags.BackupState,
|
||||
}
|
||||
if err := s.db.Create(&credentialToStore).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&credentialToStore).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
@@ -125,7 +176,7 @@ func (s *WebAuthnService) determinePasskeyName(aaguid []byte) string {
|
||||
return "New Passkey" // Default fallback
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
|
||||
func (s *WebAuthnService) BeginLogin(ctx context.Context) (*model.PublicKeyCredentialRequestOptions, error) {
|
||||
options, session, err := s.webAuthn.BeginDiscoverableLogin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -137,7 +188,11 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
|
||||
UserVerification: string(session.UserVerification),
|
||||
}
|
||||
|
||||
if err := s.db.Create(&sessionToStore).Error; err != nil {
|
||||
err = s.db.
|
||||
WithContext(ctx).
|
||||
Create(&sessionToStore).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -148,9 +203,19 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
|
||||
func (s *WebAuthnService) VerifyLogin(ctx context.Context, sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var storedSession model.WebauthnSession
|
||||
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&storedSession, "id = ?", sessionID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
@@ -160,9 +225,14 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
|
||||
}
|
||||
|
||||
var user *model.User
|
||||
_, err := s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
|
||||
if err := s.db.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil {
|
||||
return nil, err
|
||||
_, err = s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
|
||||
innerErr := tx.
|
||||
WithContext(ctx).
|
||||
Preload("Credentials").
|
||||
First(&user, "id = ?", string(userHandle)).
|
||||
Error
|
||||
if innerErr != nil {
|
||||
return nil, innerErr
|
||||
}
|
||||
return user, nil
|
||||
}, session, credentialAssertionData)
|
||||
@@ -176,41 +246,70 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
s.auditLogService.CreateNewSignInWithEmail(ipAddress, userAgent, user.ID)
|
||||
s.auditLogService.CreateNewSignInWithEmail(ctx, ipAddress, userAgent, user.ID, tx)
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
return *user, token, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) {
|
||||
func (s *WebAuthnService) ListCredentials(ctx context.Context, userID string) ([]model.WebauthnCredential, error) {
|
||||
var credentials []model.WebauthnCredential
|
||||
if err := s.db.Find(&credentials, "user_id = ?", userID).Error; err != nil {
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Find(&credentials, "user_id = ?", userID).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error {
|
||||
var credential model.WebauthnCredential
|
||||
if err := s.db.First(&credential, "id = ? AND user_id = ?", credentialID, userID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.db.Delete(&credential).Error; err != nil {
|
||||
return err
|
||||
func (s *WebAuthnService) DeleteCredential(ctx context.Context, userID, credentialID string) error {
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Where("id = ? AND user_id = ?", credentialID, userID).
|
||||
Delete(&model.WebauthnCredential{}).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete record: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) {
|
||||
func (s *WebAuthnService) UpdateCredential(ctx context.Context, userID, credentialID, name string) (model.WebauthnCredential, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var credential model.WebauthnCredential
|
||||
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("id = ? AND user_id = ?", credentialID, userID).
|
||||
First(&credential).
|
||||
Error
|
||||
if err != nil {
|
||||
return credential, err
|
||||
}
|
||||
|
||||
credential.Name = name
|
||||
|
||||
if err := s.db.Save(&credential).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&credential).
|
||||
Error
|
||||
if err != nil {
|
||||
return credential, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return credential, err
|
||||
}
|
||||
|
||||
|
||||
@@ -12,9 +12,13 @@ import (
|
||||
|
||||
var (
|
||||
aaguidMap map[string]string
|
||||
aaguidMapOnce sync.Once
|
||||
aaguidMapOnce *sync.Once
|
||||
)
|
||||
|
||||
func init() {
|
||||
aaguidMapOnce = &sync.Once{}
|
||||
}
|
||||
|
||||
// FormatAAGUID converts an AAGUID byte slice to UUID string format
|
||||
func FormatAAGUID(aaguid []byte) string {
|
||||
if len(aaguid) == 0 {
|
||||
|
||||
@@ -47,8 +47,10 @@ func TestFormatAAGUID(t *testing.T) {
|
||||
func TestGetAuthenticatorName(t *testing.T) {
|
||||
// Reset the aaguidMap for testing
|
||||
originalMap := aaguidMap
|
||||
originalOnce := aaguidMapOnce
|
||||
defer func() {
|
||||
aaguidMap = originalMap
|
||||
aaguidMapOnce = originalOnce
|
||||
}()
|
||||
|
||||
// Inject a test AAGUID map
|
||||
@@ -56,7 +58,7 @@ func TestGetAuthenticatorName(t *testing.T) {
|
||||
"adce0002-35bc-c60a-648b-0b25f1f05503": "Test Authenticator",
|
||||
"00000000-0000-0000-0000-000000000000": "Zero Authenticator",
|
||||
}
|
||||
aaguidMapOnce = sync.Once{}
|
||||
aaguidMapOnce = &sync.Once{}
|
||||
aaguidMapOnce.Do(func() {}) // Mark as done to avoid loading from file
|
||||
|
||||
tests := []struct {
|
||||
@@ -99,7 +101,7 @@ func TestGetAuthenticatorName(t *testing.T) {
|
||||
func TestLoadAAGUIDsFromFile(t *testing.T) {
|
||||
// Reset the map and once flag for clean testing
|
||||
aaguidMap = nil
|
||||
aaguidMapOnce = sync.Once{}
|
||||
aaguidMapOnce = &sync.Once{}
|
||||
|
||||
// Trigger loading of AAGUIDs by calling GetAuthenticatorName
|
||||
GetAuthenticatorName([]byte{0x01, 0x02, 0x03, 0x04})
|
||||
|
||||
@@ -4,9 +4,8 @@ import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type PaginationResponse struct {
|
||||
@@ -47,7 +46,6 @@ func PaginateAndSort(sortedPaginationRequest SortedPaginationRequest, query *gor
|
||||
}
|
||||
|
||||
return Paginate(pagination.Page, pagination.Limit, query, result)
|
||||
|
||||
}
|
||||
|
||||
func Paginate(page int, pageSize int, query *gorm.DB, result interface{}) (PaginationResponse, error) {
|
||||
|
||||
5
backend/internal/utils/ptr_util.go
Normal file
5
backend/internal/utils/ptr_util.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package utils
|
||||
|
||||
func Ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
BIN
backend/main
BIN
backend/main
Binary file not shown.
Reference in New Issue
Block a user