fix: use transactions when operations involve multiple database queries (#392)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-04-06 06:04:08 -07:00
committed by GitHub
parent c810fec8c4
commit ec626ee797
33 changed files with 1401 additions and 501 deletions

View File

@@ -1,19 +1,23 @@
package bootstrap
import (
"context"
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/pocket-id/pocket-id/backend/internal/service"
)
func Bootstrap() {
ctx := context.TODO()
initApplicationImages()
migrateConfigDBConnstring()
db := newDatabase()
appConfigService := service.NewAppConfigService(db)
appConfigService := service.NewAppConfigService(ctx, db)
migrateKey()
initRouter(db, appConfigService)
initRouter(ctx, db, appConfigService)
}

View File

@@ -1,6 +1,7 @@
package bootstrap
import (
"context"
"log"
"net"
"time"
@@ -19,7 +20,7 @@ import (
// This is used to register additional controllers for tests
var registerTestControllers []func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService)
func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
func initRouter(ctx context.Context, db *gorm.DB, appConfigService *service.AppConfigService) {
// Set the appropriate Gin mode based on the environment
switch common.EnvConfig.AppEnv {
case "production":
@@ -36,10 +37,10 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
// Initialize services
emailService, err := service.NewEmailService(appConfigService, db)
if err != nil {
log.Fatalf("Unable to create email service: %s", err)
log.Fatalf("Unable to create email service: %v", err)
}
geoLiteService := service.NewGeoLiteService()
geoLiteService := service.NewGeoLiteService(ctx)
auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService)
jwtService := service.NewJwtService(appConfigService)
webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService)
@@ -57,9 +58,9 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
r.Use(middleware.NewErrorHandlerMiddleware().Add())
r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60))
job.RegisterLdapJobs(ldapService, appConfigService)
job.RegisterDbCleanupJobs(db)
job.RegisterFileCleanupJobs(db)
job.RegisterLdapJobs(ctx, ldapService, appConfigService)
job.RegisterDbCleanupJobs(ctx, db)
job.RegisterFileCleanupJobs(ctx, db)
// Initialize middleware for specific routes
authMiddleware := middleware.NewAuthMiddleware(apiKeyService, jwtService)

View File

@@ -17,7 +17,7 @@ type AlreadyInUseError struct {
}
func (e *AlreadyInUseError) Error() string {
return fmt.Sprintf("%s is already in use", e.Property)
return e.Property + " is already in use"
}
func (e *AlreadyInUseError) HttpStatusCode() int { return 400 }

View File

@@ -53,7 +53,7 @@ func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
return
}
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(userID, sortedPaginationRequest)
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(ctx.Request.Context(), userID, sortedPaginationRequest)
if err != nil {
_ = ctx.Error(err)
return
@@ -87,7 +87,7 @@ func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
return
}
apiKey, token, err := c.apiKeyService.CreateApiKey(userID, input)
apiKey, token, err := c.apiKeyService.CreateApiKey(ctx.Request.Context(), userID, input)
if err != nil {
_ = ctx.Error(err)
return
@@ -116,7 +116,7 @@ func (c *ApiKeyController) revokeApiKeyHandler(ctx *gin.Context) {
userID := ctx.GetString("userID")
apiKeyID := ctx.Param("id")
if err := c.apiKeyService.RevokeApiKey(userID, apiKeyID); err != nil {
if err := c.apiKeyService.RevokeApiKey(ctx.Request.Context(), userID, apiKeyID); err != nil {
_ = ctx.Error(err)
return
}

View File

@@ -1,7 +1,6 @@
package controller
import (
"fmt"
"net/http"
"strconv"
@@ -61,7 +60,7 @@ type AppConfigController struct {
// @Failure 500 {object} object "{"error": "error message"}"
// @Router /application-configuration [get]
func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(false)
configuration, err := acc.appConfigService.ListAppConfig(c.Request.Context(), false)
if err != nil {
_ = c.Error(err)
return
@@ -73,7 +72,7 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
return
}
c.JSON(200, configVariablesDto)
c.JSON(http.StatusOK, configVariablesDto)
}
// listAllAppConfigHandler godoc
@@ -86,7 +85,7 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
// @Security BearerAuth
// @Router /application-configuration/all [get]
func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(true)
configuration, err := acc.appConfigService.ListAppConfig(c.Request.Context(), true)
if err != nil {
_ = c.Error(err)
return
@@ -98,7 +97,7 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
return
}
c.JSON(200, configVariablesDto)
c.JSON(http.StatusOK, configVariablesDto)
}
// updateAppConfigHandler godoc
@@ -118,7 +117,7 @@ func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
return
}
savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(input)
savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(c.Request.Context(), input)
if err != nil {
_ = c.Error(err)
return
@@ -253,7 +252,7 @@ func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
// getImage is a helper function to serve image files
func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) {
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
imagePath := common.EnvConfig.UploadPath + "/application-images/" + name + "." + imageType
mimeType := utils.GetImageMimeType(imageType)
c.Header("Content-Type", mimeType)
@@ -268,7 +267,7 @@ func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, ol
return
}
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
err = acc.appConfigService.UpdateImage(c.Request.Context(), file, imageName, oldImageType)
if err != nil {
_ = c.Error(err)
return
@@ -285,7 +284,7 @@ func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, ol
// @Security BearerAuth
// @Router /api/application-configuration/sync-ldap [post]
func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
err := acc.ldapService.SyncAll()
err := acc.ldapService.SyncAll(c.Request.Context())
if err != nil {
_ = c.Error(err)
return
@@ -304,7 +303,7 @@ func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
func (acc *AppConfigController) testEmailHandler(c *gin.Context) {
userID := c.GetString("userID")
err := acc.emailService.SendTestEmail(userID)
err := acc.emailService.SendTestEmail(c.Request.Context(), userID)
if err != nil {
_ = c.Error(err)
return

View File

@@ -42,7 +42,9 @@ type AuditLogController struct {
// @Router /api/audit-logs [get]
func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
err := c.ShouldBindQuery(&sortedPaginationRequest)
if err != nil {
_ = c.Error(err)
return
}
@@ -50,7 +52,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
userID := c.GetString("userID")
// Fetch audit logs for the user
logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(userID, sortedPaginationRequest)
logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(c.Request.Context(), userID, sortedPaginationRequest)
if err != nil {
_ = c.Error(err)
return

View File

@@ -41,7 +41,7 @@ type CustomClaimController struct {
// @Security BearerAuth
// @Router /api/custom-claims/suggestions [get]
func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
claims, err := ccc.customClaimService.GetSuggestions()
claims, err := ccc.customClaimService.GetSuggestions(c.Request.Context())
if err != nil {
_ = c.Error(err)
return
@@ -69,7 +69,7 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Contex
}
userId := c.Param("userId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(userId, input)
claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(c.Request.Context(), userId, input)
if err != nil {
_ = c.Error(err)
return
@@ -104,7 +104,7 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.C
}
userGroupId := c.Param("userGroupId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userGroupId, input)
claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(c.Request.Context(), userGroupId, input)
if err != nil {
_ = c.Error(err)
return

View File

@@ -69,7 +69,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
return
}
code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
code, callbackURL, err := oc.oidcService.Authorize(c.Request.Context(), input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
if err != nil {
_ = c.Error(err)
return
@@ -100,7 +100,7 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex
return
}
hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(input.ClientID, c.GetString("userID"), input.Scope)
hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(c.Request.Context(), input.ClientID, c.GetString("userID"), input.Scope)
if err != nil {
_ = c.Error(err)
return
@@ -153,6 +153,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
}
idToken, accessToken, refreshToken, expiresIn, err := oc.oidcService.CreateTokens(
c.Request.Context(),
input.Code,
input.GrantType,
clientID,
@@ -216,7 +217,7 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
_ = c.Error(&common.TokenInvalidError{})
return
}
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientID[0])
claims, err := oc.oidcService.GetUserClaimsForClient(c.Request.Context(), userID, clientID[0])
if err != nil {
_ = c.Error(err)
return
@@ -254,7 +255,7 @@ func (oc *OidcController) EndSessionHandler(c *gin.Context) {
}
}
callbackURL, err := oc.oidcService.ValidateEndSession(input, c.GetString("userID"))
callbackURL, err := oc.oidcService.ValidateEndSession(c.Request.Context(), input, c.GetString("userID"))
if err != nil {
// If the validation fails, the user has to confirm the logout manually and doesn't get redirected
log.Printf("Error getting logout callback URL, the user has to confirm the logout manually: %v", err)
@@ -300,7 +301,7 @@ func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) {
// @Router /api/oidc/clients/{id}/meta [get]
func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
clientId := c.Param("id")
client, err := oc.oidcService.GetClient(clientId)
client, err := oc.oidcService.GetClient(c.Request.Context(), clientId)
if err != nil {
_ = c.Error(err)
return
@@ -327,7 +328,7 @@ func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
// @Router /api/oidc/clients/{id} [get]
func (oc *OidcController) getClientHandler(c *gin.Context) {
clientId := c.Param("id")
client, err := oc.oidcService.GetClient(clientId)
client, err := oc.oidcService.GetClient(c.Request.Context(), clientId)
if err != nil {
_ = c.Error(err)
return
@@ -363,7 +364,7 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
return
}
clients, pagination, err := oc.oidcService.ListClients(searchTerm, sortedPaginationRequest)
clients, pagination, err := oc.oidcService.ListClients(c.Request.Context(), searchTerm, sortedPaginationRequest)
if err != nil {
_ = c.Error(err)
return
@@ -398,7 +399,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
return
}
client, err := oc.oidcService.CreateClient(input, c.GetString("userID"))
client, err := oc.oidcService.CreateClient(c.Request.Context(), input, c.GetString("userID"))
if err != nil {
_ = c.Error(err)
return
@@ -422,7 +423,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
// @Security BearerAuth
// @Router /api/oidc/clients/{id} [delete]
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
err := oc.oidcService.DeleteClient(c.Param("id"))
err := oc.oidcService.DeleteClient(c.Request.Context(), c.Param("id"))
if err != nil {
_ = c.Error(err)
return
@@ -449,7 +450,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
return
}
client, err := oc.oidcService.UpdateClient(c.Param("id"), input)
client, err := oc.oidcService.UpdateClient(c.Request.Context(), c.Param("id"), input)
if err != nil {
_ = c.Error(err)
return
@@ -474,7 +475,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
// @Security BearerAuth
// @Router /api/oidc/clients/{id}/secret [post]
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
secret, err := oc.oidcService.CreateClientSecret(c.Param("id"))
secret, err := oc.oidcService.CreateClientSecret(c.Request.Context(), c.Param("id"))
if err != nil {
_ = c.Error(err)
return
@@ -494,7 +495,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
// @Success 200 {file} binary "Logo image"
// @Router /api/oidc/clients/{id}/logo [get]
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id"))
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Request.Context(), c.Param("id"))
if err != nil {
_ = c.Error(err)
return
@@ -521,7 +522,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
return
}
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file)
err = oc.oidcService.UpdateClientLogo(c.Request.Context(), c.Param("id"), file)
if err != nil {
_ = c.Error(err)
return
@@ -539,7 +540,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
// @Security BearerAuth
// @Router /api/oidc/clients/{id}/logo [delete]
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
err := oc.oidcService.DeleteClientLogo(c.Param("id"))
err := oc.oidcService.DeleteClientLogo(c.Request.Context(), c.Param("id"))
if err != nil {
_ = c.Error(err)
return
@@ -566,7 +567,7 @@ func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
return
}
oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Param("id"), input)
oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Request.Context(), c.Param("id"), input)
if err != nil {
_ = c.Error(err)
return

View File

@@ -65,7 +65,7 @@ type UserController struct {
// @Router /api/users/{id}/groups [get]
func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
userID := c.Param("id")
groups, err := uc.userService.GetUserGroups(userID)
groups, err := uc.userService.GetUserGroups(c.Request.Context(), userID)
if err != nil {
_ = c.Error(err)
return
@@ -99,7 +99,7 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
return
}
users, pagination, err := uc.userService.ListUsers(searchTerm, sortedPaginationRequest)
users, pagination, err := uc.userService.ListUsers(c.Request.Context(), searchTerm, sortedPaginationRequest)
if err != nil {
_ = c.Error(err)
return
@@ -125,7 +125,7 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
// @Success 200 {object} dto.UserDto
// @Router /api/users/{id} [get]
func (uc *UserController) getUserHandler(c *gin.Context) {
user, err := uc.userService.GetUser(c.Param("id"))
user, err := uc.userService.GetUser(c.Request.Context(), c.Param("id"))
if err != nil {
_ = c.Error(err)
return
@@ -147,7 +147,7 @@ func (uc *UserController) getUserHandler(c *gin.Context) {
// @Success 200 {object} dto.UserDto
// @Router /api/users/me [get]
func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
user, err := uc.userService.GetUser(c.GetString("userID"))
user, err := uc.userService.GetUser(c.Request.Context(), c.GetString("userID"))
if err != nil {
_ = c.Error(err)
return
@@ -170,7 +170,7 @@ func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
// @Success 204 "No Content"
// @Router /api/users/{id} [delete]
func (uc *UserController) deleteUserHandler(c *gin.Context) {
if err := uc.userService.DeleteUser(c.Param("id"), false); err != nil {
if err := uc.userService.DeleteUser(c.Request.Context(), c.Param("id"), false); err != nil {
_ = c.Error(err)
return
}
@@ -192,7 +192,7 @@ func (uc *UserController) createUserHandler(c *gin.Context) {
return
}
user, err := uc.userService.CreateUser(input)
user, err := uc.userService.CreateUser(c.Request.Context(), input)
if err != nil {
_ = c.Error(err)
return
@@ -245,7 +245,7 @@ func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) {
userID := c.Param("id")
picture, size, err := uc.userService.GetProfilePicture(userID)
picture, size, err := uc.userService.GetProfilePicture(c.Request.Context(), userID)
if err != nil {
_ = c.Error(err)
return
@@ -332,7 +332,7 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context, own bo
if own {
input.UserID = c.GetString("userID")
}
token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
token, err := uc.userService.CreateOneTimeAccessToken(c.Request.Context(), input.UserID, input.ExpiresAt)
if err != nil {
_ = c.Error(err)
return
@@ -364,7 +364,7 @@ func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) {
return
}
err := uc.userService.RequestOneTimeAccessEmail(input.Email, input.RedirectPath)
err := uc.userService.RequestOneTimeAccessEmail(c.Request.Context(), input.Email, input.RedirectPath)
if err != nil {
_ = c.Error(err)
return
@@ -381,7 +381,7 @@ func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) {
// @Success 200 {object} dto.UserDto
// @Router /api/one-time-access-token/{token} [post]
func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token"), c.ClientIP(), c.Request.UserAgent())
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Request.Context(), c.Param("token"), c.ClientIP(), c.Request.UserAgent())
if err != nil {
_ = c.Error(err)
return
@@ -406,7 +406,7 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
// @Success 200 {object} dto.UserDto
// @Router /api/one-time-access-token/setup [post]
func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.SetupInitialAdmin()
user, token, err := uc.userService.SetupInitialAdmin(c.Request.Context())
if err != nil {
_ = c.Error(err)
return
@@ -439,7 +439,7 @@ func (uc *UserController) updateUserGroups(c *gin.Context) {
return
}
user, err := uc.userService.UpdateUserGroups(c.Param("id"), input.UserGroupIds)
user, err := uc.userService.UpdateUserGroups(c.Request.Context(), c.Param("id"), input.UserGroupIds)
if err != nil {
_ = c.Error(err)
return
@@ -469,7 +469,7 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
userID = c.Param("id")
}
user, err := uc.userService.UpdateUser(userID, input, updateOwnUser, false)
user, err := uc.userService.UpdateUser(c.Request.Context(), userID, input, updateOwnUser, false)
if err != nil {
_ = c.Error(err)
return

View File

@@ -47,6 +47,8 @@ type UserGroupController struct {
// @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount]
// @Router /api/user-groups [get]
func (ugc *UserGroupController) list(c *gin.Context) {
ctx := c.Request.Context()
searchTerm := c.Query("search")
var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
@@ -54,7 +56,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
return
}
groups, pagination, err := ugc.UserGroupService.List(searchTerm, sortedPaginationRequest)
groups, pagination, err := ugc.UserGroupService.List(ctx, searchTerm, sortedPaginationRequest)
if err != nil {
_ = c.Error(err)
return
@@ -68,7 +70,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
_ = c.Error(err)
return
}
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(group.ID)
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(ctx, group.ID)
if err != nil {
_ = c.Error(err)
return
@@ -93,7 +95,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
// @Security BearerAuth
// @Router /api/user-groups/{id} [get]
func (ugc *UserGroupController) get(c *gin.Context) {
group, err := ugc.UserGroupService.Get(c.Param("id"))
group, err := ugc.UserGroupService.Get(c.Request.Context(), c.Param("id"))
if err != nil {
_ = c.Error(err)
return
@@ -125,7 +127,7 @@ func (ugc *UserGroupController) create(c *gin.Context) {
return
}
group, err := ugc.UserGroupService.Create(input)
group, err := ugc.UserGroupService.Create(c.Request.Context(), input)
if err != nil {
_ = c.Error(err)
return
@@ -158,7 +160,7 @@ func (ugc *UserGroupController) update(c *gin.Context) {
return
}
group, err := ugc.UserGroupService.Update(c.Param("id"), input, false)
group, err := ugc.UserGroupService.Update(c.Request.Context(), c.Param("id"), input, false)
if err != nil {
_ = c.Error(err)
return
@@ -184,7 +186,7 @@ func (ugc *UserGroupController) update(c *gin.Context) {
// @Security BearerAuth
// @Router /api/user-groups/{id} [delete]
func (ugc *UserGroupController) delete(c *gin.Context) {
if err := ugc.UserGroupService.Delete(c.Param("id")); err != nil {
if err := ugc.UserGroupService.Delete(c.Request.Context(), c.Param("id")); err != nil {
_ = c.Error(err)
return
}
@@ -210,7 +212,7 @@ func (ugc *UserGroupController) updateUsers(c *gin.Context) {
return
}
group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input.UserIDs)
group, err := ugc.UserGroupService.UpdateUsers(c.Request.Context(), c.Param("id"), input.UserIDs)
if err != nil {
_ = c.Error(err)
return

View File

@@ -37,7 +37,7 @@ type WebauthnController struct {
func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
userID := c.GetString("userID")
options, err := wc.webAuthnService.BeginRegistration(userID)
options, err := wc.webAuthnService.BeginRegistration(c.Request.Context(), userID)
if err != nil {
_ = c.Error(err)
return
@@ -55,7 +55,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
}
userID := c.GetString("userID")
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request)
credential, err := wc.webAuthnService.VerifyRegistration(c.Request.Context(), sessionID, userID, c.Request)
if err != nil {
_ = c.Error(err)
return
@@ -71,7 +71,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
}
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
options, err := wc.webAuthnService.BeginLogin()
options, err := wc.webAuthnService.BeginLogin(c.Request.Context())
if err != nil {
_ = c.Error(err)
return
@@ -94,7 +94,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
return
}
user, token, err := wc.webAuthnService.VerifyLogin(sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent())
user, token, err := wc.webAuthnService.VerifyLogin(c.Request.Context(), sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent())
if err != nil {
_ = c.Error(err)
return
@@ -114,7 +114,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
userID := c.GetString("userID")
credentials, err := wc.webAuthnService.ListCredentials(userID)
credentials, err := wc.webAuthnService.ListCredentials(c.Request.Context(), userID)
if err != nil {
_ = c.Error(err)
return
@@ -133,7 +133,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
userID := c.GetString("userID")
credentialID := c.Param("id")
err := wc.webAuthnService.DeleteCredential(userID, credentialID)
err := wc.webAuthnService.DeleteCredential(c.Request.Context(), userID, credentialID)
if err != nil {
_ = c.Error(err)
return
@@ -152,7 +152,7 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
return
}
credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
credential, err := wc.webAuthnService.UpdateCredential(c.Request.Context(), userID, credentialID, input.Name)
if err != nil {
_ = c.Error(err)
return

View File

@@ -1,16 +1,18 @@
package job
import (
"context"
"log"
"time"
"github.com/go-co-op/gocron/v2"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"gorm.io/gorm"
)
func RegisterDbCleanupJobs(db *gorm.DB) {
func RegisterDbCleanupJobs(ctx context.Context, db *gorm.DB) {
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err)
@@ -18,11 +20,11 @@ func RegisterDbCleanupJobs(db *gorm.DB) {
jobs := &DbCleanupJobs{db: db}
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
registerJob(scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
registerJob(scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs)
registerJob(ctx, scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
registerJob(ctx, scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
registerJob(ctx, scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
registerJob(ctx, scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
registerJob(ctx, scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs)
scheduler.Start()
}
@@ -31,26 +33,41 @@ type DbCleanupJobs struct {
}
// ClearWebauthnSessions deletes WebAuthn sessions that have expired
func (j *DbCleanupJobs) clearWebauthnSessions() error {
return j.db.Delete(&model.WebauthnSession{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
func (j *DbCleanupJobs) clearWebauthnSessions(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.WebauthnSession{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearOneTimeAccessTokens deletes one-time access tokens that have expired
func (j *DbCleanupJobs) clearOneTimeAccessTokens() error {
return j.db.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
func (j *DbCleanupJobs) clearOneTimeAccessTokens(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.OneTimeAccessToken{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *DbCleanupJobs) clearOidcAuthorizationCodes() error {
return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
func (j *DbCleanupJobs) clearOidcAuthorizationCodes(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *DbCleanupJobs) clearOidcRefreshTokens() error {
return j.db.Delete(&model.OidcRefreshToken{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
func (j *DbCleanupJobs) clearOidcRefreshTokens(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.OidcRefreshToken{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearAuditLogs deletes audit logs older than 90 days
func (j *DbCleanupJobs) clearAuditLogs() error {
return j.db.Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).Error
func (j *DbCleanupJobs) clearAuditLogs(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).
Error
}

View File

@@ -1,6 +1,7 @@
package job
import (
"context"
"fmt"
"log"
"os"
@@ -8,12 +9,13 @@ import (
"strings"
"github.com/go-co-op/gocron/v2"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"gorm.io/gorm"
)
func RegisterFileCleanupJobs(db *gorm.DB) {
func RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB) {
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err)
@@ -21,7 +23,7 @@ func RegisterFileCleanupJobs(db *gorm.DB) {
jobs := &FileCleanupJobs{db: db}
registerJob(scheduler, "ClearUnusedDefaultProfilePictures", "0 2 * * 0", jobs.clearUnusedDefaultProfilePictures)
registerJob(ctx, scheduler, "ClearUnusedDefaultProfilePictures", "0 2 * * 0", jobs.clearUnusedDefaultProfilePictures)
scheduler.Start()
}
@@ -31,9 +33,13 @@ type FileCleanupJobs struct {
}
// ClearUnusedDefaultProfilePictures deletes default profile pictures that don't match any user's initials
func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures() error {
func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures(ctx context.Context) error {
var users []model.User
if err := j.db.Find(&users).Error; err != nil {
err := j.db.
WithContext(ctx).
Find(&users).
Error
if err != nil {
return fmt.Errorf("failed to fetch users: %w", err)
}

View File

@@ -1,16 +1,18 @@
package job
import (
"context"
"log"
"github.com/go-co-op/gocron/v2"
"github.com/google/uuid"
)
func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) {
func registerJob(ctx context.Context, scheduler gocron.Scheduler, name string, interval string, job func(ctx context.Context) error) {
_, err := scheduler.NewJob(
gocron.CronJob(interval, false),
gocron.NewTask(job),
gocron.WithContext(ctx),
gocron.WithEventListeners(
gocron.AfterJobRuns(func(jobID uuid.UUID, jobName string) {
log.Printf("Job %q run successfully", name)

View File

@@ -1,6 +1,7 @@
package job
import (
"context"
"log"
"github.com/go-co-op/gocron/v2"
@@ -12,28 +13,30 @@ type LdapJobs struct {
appConfigService *service.AppConfigService
}
func RegisterLdapJobs(ldapService *service.LdapService, appConfigService *service.AppConfigService) {
func RegisterLdapJobs(ctx context.Context, ldapService *service.LdapService, appConfigService *service.AppConfigService) {
jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService}
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err)
log.Fatalf("Failed to create a new scheduler: %v", err)
}
// Register the job to run every hour
registerJob(scheduler, "SyncLdap", "0 * * * *", jobs.syncLdap)
registerJob(ctx, scheduler, "SyncLdap", "0 * * * *", jobs.syncLdap)
// Run the job immediately on startup
if err := jobs.syncLdap(); err != nil {
log.Printf("Failed to sync LDAP: %s", err)
err = jobs.syncLdap(ctx)
if err != nil {
log.Printf("Failed to sync LDAP: %v", err)
}
scheduler.Start()
}
func (j *LdapJobs) syncLdap() error {
if j.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return j.ldapService.SyncAll()
}
func (j *LdapJobs) syncLdap(ctx context.Context) error {
if !j.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return nil
}
return j.ldapService.SyncAll(ctx)
}

View File

@@ -36,7 +36,7 @@ func (m *ApiKeyAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
func (m *ApiKeyAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (userID string, isAdmin bool, err error) {
apiKey := c.GetHeader("X-API-KEY")
user, err := m.apiKeyService.ValidateApiKey(apiKey)
user, err := m.apiKeyService.ValidateApiKey(c.Request.Context(), apiKey)
if err != nil {
return "", false, &common.NotSignedInError{}
}

View File

@@ -1,8 +1,8 @@
package service
import (
"context"
"errors"
"log"
"time"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
@@ -12,6 +12,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type ApiKeyService struct {
@@ -22,8 +23,11 @@ func NewApiKeyService(db *gorm.DB) *ApiKeyService {
return &ApiKeyService{db: db}
}
func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
query := s.db.Where("user_id = ?", userID).Model(&model.ApiKey{})
func (s *ApiKeyService) ListApiKeys(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
query := s.db.
WithContext(ctx).
Where("user_id = ?", userID).
Model(&model.ApiKey{})
var apiKeys []model.ApiKey
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &apiKeys)
@@ -34,7 +38,7 @@ func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils
return apiKeys, pagination, nil
}
func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) {
func (s *ApiKeyService) CreateApiKey(ctx context.Context, userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) {
// Check if expiration is in the future
if !input.ExpiresAt.ToTime().After(time.Now()) {
return model.ApiKey{}, "", &common.APIKeyExpirationDateError{}
@@ -54,7 +58,11 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
UserID: userID,
}
if err := s.db.Create(&apiKey).Error; err != nil {
err = s.db.
WithContext(ctx).
Create(&apiKey).
Error
if err != nil {
return model.ApiKey{}, "", err
}
@@ -62,29 +70,44 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
return apiKey, token, nil
}
func (s *ApiKeyService) RevokeApiKey(userID, apiKeyID string) error {
func (s *ApiKeyService) RevokeApiKey(ctx context.Context, userID, apiKeyID string) error {
var apiKey model.ApiKey
if err := s.db.Where("id = ? AND user_id = ?", apiKeyID, userID).First(&apiKey).Error; err != nil {
err := s.db.
WithContext(ctx).
Where("id = ? AND user_id = ?", apiKeyID, userID).
Delete(&apiKey).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return &common.APIKeyNotFoundError{}
}
return err
}
return s.db.Delete(&apiKey).Error
return nil
}
func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) {
func (s *ApiKeyService) ValidateApiKey(ctx context.Context, apiKey string) (model.User, error) {
if apiKey == "" {
return model.User{}, &common.NoAPIKeyProvidedError{}
}
var key model.ApiKey
now := time.Now()
hashedKey := utils.CreateSha256Hash(apiKey)
if err := s.db.Preload("User").Where("key = ? AND expires_at > ?",
hashedKey, datatype.DateTime(time.Now())).Preload("User").First(&key).Error; err != nil {
var key model.ApiKey
err := s.db.
WithContext(ctx).
Model(&model.ApiKey{}).
Clauses(clause.Returning{}).
Where("key = ? AND expires_at > ?", hashedKey, datatype.DateTime(now)).
Updates(&model.ApiKey{
LastUsedAt: utils.Ptr(datatype.DateTime(now)),
}).
Preload("User").
First(&key).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, &common.InvalidAPIKeyError{}
}
@@ -92,12 +115,5 @@ func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) {
return model.User{}, err
}
// Update last used time
now := datatype.DateTime(time.Now())
key.LastUsedAt = &now
if err := s.db.Save(&key).Error; err != nil {
log.Printf("Failed to update last used time: %v", err)
}
return key.User, nil
}

View File

@@ -1,7 +1,8 @@
package service
import (
"fmt"
"context"
"errors"
"log"
"mime/multipart"
"os"
@@ -19,12 +20,14 @@ type AppConfigService struct {
db *gorm.DB
}
func NewAppConfigService(db *gorm.DB) *AppConfigService {
func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
service := &AppConfigService{
DbConfig: &defaultDbConfig,
db: db,
}
if err := service.InitDbConfig(); err != nil {
err := service.InitDbConfig(ctx)
if err != nil {
log.Fatalf("Failed to initialize app config service: %v", err)
}
@@ -197,17 +200,24 @@ var defaultDbConfig = model.AppConfig{
},
}
func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
if common.EnvConfig.UiConfigDisabled {
return nil, &common.UiConfigDisabledError{}
}
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var err error
rt := reflect.ValueOf(input).Type()
rv := reflect.ValueOf(input)
var savedConfigVariables []model.AppConfigVariable
for i := 0; i < rt.NumField(); i++ {
savedConfigVariables := make([]model.AppConfigVariable, 0, rt.NumField())
for i := range rt.NumField() {
field := rt.Field(i)
key := field.Tag.Get("json")
value := rv.FieldByName(field.Name).String()
@@ -220,32 +230,47 @@ func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]mode
}
var appConfigVariable model.AppConfigVariable
if err := tx.First(&appConfigVariable, "key = ? AND is_internal = false", key).Error; err != nil {
tx.Rollback()
err = tx.
WithContext(ctx).
First(&appConfigVariable, "key = ? AND is_internal = false", key).
Error
if err != nil {
return nil, err
}
appConfigVariable.Value = value
if err := tx.Save(&appConfigVariable).Error; err != nil {
tx.Rollback()
err = tx.
WithContext(ctx).
Save(&appConfigVariable).
Error
if err != nil {
return nil, err
}
savedConfigVariables = append(savedConfigVariables, appConfigVariable)
}
tx.Commit()
err = tx.Commit().Error
if err != nil {
return nil, err
}
if err := s.LoadDbConfigFromDb(); err != nil {
err = s.LoadDbConfigFromDb()
if err != nil {
return nil, err
}
return savedConfigVariables, nil
}
func (s *AppConfigService) UpdateImageType(imageName string, fileType string) error {
key := fmt.Sprintf("%sImageType", imageName)
err := s.db.Model(&model.AppConfigVariable{}).Where("key = ?", key).Update("value", fileType).Error
func (s *AppConfigService) updateImageType(ctx context.Context, imageName string, fileType string) error {
key := imageName + "ImageType"
err := s.db.
WithContext(ctx).
Model(&model.AppConfigVariable{}).
Where("key = ?", key).
Update("value", fileType).
Error
if err != nil {
return err
}
@@ -253,14 +278,17 @@ func (s *AppConfigService) UpdateImageType(imageName string, fileType string) er
return s.LoadDbConfigFromDb()
}
func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariable, error) {
var configuration []model.AppConfigVariable
var err error
func (s *AppConfigService) ListAppConfig(ctx context.Context, showAll bool) (configuration []model.AppConfigVariable, err error) {
if showAll {
err = s.db.Find(&configuration).Error
err = s.db.
WithContext(ctx).
Find(&configuration).
Error
} else {
err = s.db.Find(&configuration, "is_public = true").Error
err = s.db.
WithContext(ctx).
Find(&configuration, "is_public = true").
Error
}
if err != nil {
@@ -271,7 +299,6 @@ func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariabl
if common.EnvConfig.UiConfigDisabled {
// Set the value to the environment variable if the UI config is disabled
configuration[i].Value = s.getConfigVariableFromEnvironmentVariable(configuration[i].Key, configuration[i].DefaultValue)
} else if configuration[i].Value == "" && configuration[i].DefaultValue != "" {
// Set the value to the default value if it is empty
configuration[i].Value = configuration[i].DefaultValue
@@ -281,7 +308,7 @@ func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariabl
return configuration, nil
}
func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, imageName string, oldImageType string) error {
func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) {
fileType := utils.GetFileExtension(uploadedFile.Filename)
mimeType := utils.GetImageMimeType(fileType)
if mimeType == "" {
@@ -290,19 +317,22 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image
// Delete the old image if it has a different file type
if fileType != oldImageType {
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType)
if err := os.Remove(oldImagePath); err != nil {
oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType
err = os.Remove(oldImagePath)
if err != nil {
return err
}
}
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType)
if err := utils.SaveFile(uploadedFile, imagePath); err != nil {
imagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + fileType
err = utils.SaveFile(uploadedFile, imagePath)
if err != nil {
return err
}
// Update the file type in the database
if err := s.UpdateImageType(imageName, fileType); err != nil {
err = s.updateImageType(ctx, imageName, fileType)
if err != nil {
return err
}
@@ -312,33 +342,58 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image
// InitDbConfig creates the default configuration values in the database if they do not exist,
// updates existing configurations if they differ from the default, and deletes any configurations
// that are not in the default configuration.
func (s *AppConfigService) InitDbConfig() error {
func (s *AppConfigService) InitDbConfig(ctx context.Context) (err error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
// Reflect to get the underlying value of DbConfig and its default configuration
defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig)
defaultKeys := make(map[string]struct{})
// Iterate over the fields of DbConfig
for i := 0; i < defaultConfigReflectValue.NumField(); i++ {
for i := range defaultConfigReflectValue.NumField() {
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable)
defaultKeys[defaultConfigVar.Key] = struct{}{}
var storedConfigVar model.AppConfigVariable
if err := s.db.First(&storedConfigVar, "key = ?", defaultConfigVar.Key).Error; err != nil {
err = tx.
WithContext(ctx).
First(&storedConfigVar, "key = ?", defaultConfigVar.Key).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// If the configuration does not exist, create it
if err := s.db.Create(&defaultConfigVar).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&defaultConfigVar).
Error
if err != nil {
return err
}
continue
} else if err != nil {
return err
}
// Update existing configuration if it differs from the default
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal || storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue {
if storedConfigVar.Type != defaultConfigVar.Type ||
storedConfigVar.IsPublic != defaultConfigVar.IsPublic ||
storedConfigVar.IsInternal != defaultConfigVar.IsInternal ||
storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue {
// Set values
storedConfigVar.Type = defaultConfigVar.Type
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
storedConfigVar.DefaultValue = defaultConfigVar.DefaultValue
if err := s.db.Save(&storedConfigVar).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&storedConfigVar).
Error
if err != nil {
return err
}
}
@@ -346,29 +401,54 @@ func (s *AppConfigService) InitDbConfig() error {
// Delete any configurations not in the default keys
var allConfigVars []model.AppConfigVariable
if err := s.db.Find(&allConfigVars).Error; err != nil {
err = tx.
WithContext(ctx).
Find(&allConfigVars).
Error
if err != nil {
return err
}
for _, config := range allConfigVars {
if _, exists := defaultKeys[config.Key]; !exists {
if err := s.db.Delete(&config).Error; err != nil {
if _, exists := defaultKeys[config.Key]; exists {
continue
}
err = tx.
WithContext(ctx).
Delete(&config).
Error
if err != nil {
return err
}
}
// Commit the changes
err = tx.Commit().Error
if err != nil {
return err
}
return s.LoadDbConfigFromDb()
// Reload the configuration
err = s.LoadDbConfigFromDb()
if err != nil {
return err
}
return nil
}
// LoadDbConfigFromDb loads the configuration values from the database into the DbConfig struct.
func (s *AppConfigService) LoadDbConfigFromDb() error {
return s.db.Transaction(func(tx *gorm.DB) error {
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
for i := range dbConfigReflectValue.NumField() {
dbConfigField := dbConfigReflectValue.Field(i)
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
var storedConfigVar model.AppConfigVariable
if err := s.db.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
err := tx.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error
if err != nil {
return err
}
@@ -379,10 +459,10 @@ func (s *AppConfigService) LoadDbConfigFromDb() error {
}
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
return nil
})
}
func (s *AppConfigService) getConfigVariableFromEnvironmentVariable(key, fallbackValue string) string {

View File

@@ -25,10 +25,10 @@ func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailSe
}
// Create creates a new audit log entry in the database
func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData) model.AuditLog {
func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) model.AuditLog {
country, city, err := s.geoliteService.GetLocationByIP(ipAddress)
if err != nil {
log.Printf("Failed to get IP location: %v\n", err)
log.Printf("Failed to get IP location: %v", err)
}
auditLog := model.AuditLog{
@@ -42,8 +42,12 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent
}
// Save the audit log in the database
if err := s.db.Create(&auditLog).Error; err != nil {
log.Printf("Failed to create audit log: %v\n", err)
err = tx.
WithContext(ctx).
Create(&auditLog).
Error
if err != nil {
log.Printf("Failed to create audit log: %v", err)
return model.AuditLog{}
}
@@ -51,12 +55,17 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent
}
// CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before
func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID string) model.AuditLog {
createdAuditLog := s.Create(model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{})
func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog {
createdAuditLog := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
// Count the number of times the user has logged in from the same device
var count int64
err := s.db.Model(&model.AuditLog{}).Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).Count(&count).Error
err := tx.
WithContext(ctx).
Model(&model.AuditLog{}).
Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).
Count(&count).
Error
if err != nil {
log.Printf("Failed to count audit logs: %v\n", err)
return createdAuditLog
@@ -64,11 +73,23 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.IsTrue() && count <= 1 {
// We use a background context here as this is running in a goroutine
//nolint:contextcheck
go func() {
var user model.User
s.db.Where("id = ?", userID).First(&user)
innerCtx := context.Background()
err := SendEmail(s.emailService, email.Address{
// Note we don't use the transaction here because this is running in background
var user model.User
innerErr := s.db.
WithContext(innerCtx).
Where("id = ?", userID).
First(&user).
Error
if innerErr != nil {
log.Printf("Failed to load user: %v", innerErr)
}
innerErr = SendEmail(innerCtx, s.emailService, email.Address{
Name: user.Username,
Email: user.Email,
}, NewLoginTemplate, &NewLoginTemplateData{
@@ -78,8 +99,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
Device: s.DeviceStringFromUserAgent(userAgent),
DateTime: createdAuditLog.CreatedAt.UTC(),
})
if err != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, err)
if innerErr != nil {
log.Printf("Failed to send email to '%s': %v", user.Email, innerErr)
}
}()
}
@@ -88,9 +109,12 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
}
// ListAuditLogsForUser retrieves all audit logs for a given user ID
func (s *AuditLogService) ListAuditLogsForUser(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) {
func (s *AuditLogService) ListAuditLogsForUser(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) {
var logs []model.AuditLog
query := s.db.Model(&model.AuditLog{}).Where("user_id = ?", userID)
query := s.db.
WithContext(ctx).
Model(&model.AuditLog{}).
Where("user_id = ?", userID)
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs)
return logs, pagination, err
@@ -162,19 +186,19 @@ func (s *AuditLogService) ListUsernamesWithIds(ctx context.Context) (users map[s
}
func (s *AuditLogService) ListClientNames(ctx context.Context) (clientNames []string, err error) {
dialect := s.db.Name()
query := s.db.
WithContext(ctx).
Model(&model.AuditLog{})
dialect := s.db.Name()
switch dialect {
case "sqlite":
query = query.
Select("DISTINCT json_extract(data, '$.clientName') as client_name").
Select("DISTINCT json_extract(data, '$.clientName') AS client_name").
Where("json_extract(data, '$.clientName') IS NOT NULL")
case "postgres":
query = query.
Select("DISTINCT data->>'clientName' as client_name").
Select("DISTINCT data->>'clientName' AS client_name").
Where("data->>'clientName' IS NOT NULL")
default:
return nil, fmt.Errorf("unsupported database dialect: %s", dialect)

View File

@@ -1,34 +1,14 @@
package service
import (
"context"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"gorm.io/gorm"
)
// Reserved claims
var reservedClaims = map[string]struct{}{
"given_name": {},
"family_name": {},
"name": {},
"email": {},
"preferred_username": {},
"groups": {},
"sub": {},
"iss": {},
"aud": {},
"exp": {},
"iat": {},
"auth_time": {},
"nonce": {},
"acr": {},
"amr": {},
"azp": {},
"nbf": {},
"jti": {},
}
type CustomClaimService struct {
db *gorm.DB
}
@@ -39,8 +19,29 @@ func NewCustomClaimService(db *gorm.DB) *CustomClaimService {
// isReservedClaim checks if a claim key is reserved e.g. email, preferred_username
func isReservedClaim(key string) bool {
_, ok := reservedClaims[key]
return ok
switch key {
case "given_name",
"family_name",
"name",
"email",
"preferred_username",
"groups",
"sub",
"iss",
"aud",
"exp",
"iat",
"auth_time",
"nonce",
"acr",
"amr",
"azp",
"nbf",
"jti":
return true
default:
return false
}
}
// idType is the type of the id used to identify the user or user group
@@ -52,28 +53,38 @@ const (
)
// UpdateCustomClaimsForUser updates the custom claims for a user
func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(UserID, userID, claims)
func (s *CustomClaimService) UpdateCustomClaimsForUser(ctx context.Context, userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(ctx, UserID, userID, claims)
}
// UpdateCustomClaimsForUserGroup updates the custom claims for a user group
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(UserGroupID, userGroupID, claims)
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(ctx context.Context, userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(ctx, UserGroupID, userGroupID, claims)
}
// updateCustomClaims updates the custom claims for a user or user group
func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
func (s *CustomClaimService) updateCustomClaims(ctx context.Context, idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
// Check for duplicate keys in the claims slice
seenKeys := make(map[string]bool)
seenKeys := make(map[string]struct{})
for _, claim := range claims {
if seenKeys[claim.Key] {
if _, ok := seenKeys[claim.Key]; ok {
return nil, &common.DuplicateClaimError{Key: claim.Key}
}
seenKeys[claim.Key] = true
seenKeys[claim.Key] = struct{}{}
}
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var existingClaims []model.CustomClaim
err := s.db.Where(string(idType), value).Find(&existingClaims).Error
err := tx.
WithContext(ctx).
Where(string(idType), value).
Find(&existingClaims).
Error
if err != nil {
return nil, err
}
@@ -87,8 +98,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
break
}
}
if !found {
err = s.db.Delete(&existingClaim).Error
err = tx.
WithContext(ctx).
Delete(&existingClaim).
Error
if err != nil {
return nil, err
}
@@ -113,7 +128,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
}
// Update the claim if it already exists or create a new one
err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error
err = tx.
WithContext(ctx).
Where(string(idType)+" = ? AND key = ?", value, claim.Key).
Assign(&customClaim).
FirstOrCreate(&model.CustomClaim{}).
Error
if err != nil {
return nil, err
}
@@ -121,7 +141,16 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
// Get the updated claims
var updatedClaims []model.CustomClaim
err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error
err = tx.
WithContext(ctx).
Where(string(idType)+" = ?", value).
Find(&updatedClaims).
Error
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
@@ -129,23 +158,31 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
return updatedClaims, nil
}
func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.CustomClaim, error) {
func (s *CustomClaimService) GetCustomClaimsForUser(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
var customClaims []model.CustomClaim
err := s.db.Where("user_id = ?", userID).Find(&customClaims).Error
err := tx.
WithContext(ctx).
Where("user_id = ?", userID).
Find(&customClaims).
Error
return customClaims, err
}
func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) {
func (s *CustomClaimService) GetCustomClaimsForUserGroup(ctx context.Context, userGroupID string, tx *gorm.DB) ([]model.CustomClaim, error) {
var customClaims []model.CustomClaim
err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error
err := tx.
WithContext(ctx).
Where("user_group_id = ?", userGroupID).
Find(&customClaims).
Error
return customClaims, err
}
// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of,
// prioritizing the user's claims over user group claims with the same key.
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) {
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
// Get the custom claims of the user
customClaims, err := s.GetCustomClaimsForUser(userID)
customClaims, err := s.GetCustomClaimsForUser(ctx, userID, tx)
if err != nil {
return nil, err
}
@@ -158,7 +195,9 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
// Get all user groups of the user
var userGroupsOfUser []model.UserGroup
err = s.db.Preload("CustomClaims").
err = tx.
WithContext(ctx).
Preload("CustomClaims").
Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id").
Where("user_groups_users.user_id = ?", userID).
Find(&userGroupsOfUser).Error
@@ -186,10 +225,12 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
}
// GetSuggestions returns a list of custom claim keys that have been used before
func (s *CustomClaimService) GetSuggestions() ([]string, error) {
func (s *CustomClaimService) GetSuggestions(ctx context.Context) ([]string, error) {
var customClaimsKeys []string
err := s.db.Model(&model.CustomClaim{}).
err := s.db.
WithContext(ctx).
Model(&model.CustomClaim{}).
Group("key").
Order("COUNT(*) DESC").
Pluck("key", &customClaimsKeys).Error

View File

@@ -3,6 +3,7 @@
package service
import (
"context"
"crypto/ecdsa"
"crypto/x509"
"encoding/base64"
@@ -34,6 +35,7 @@ func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService
return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService}
}
//nolint:gocognit
func (s *TestService) SeedDatabase() error {
return s.db.Transaction(func(tx *gorm.DB) error {
users := []model.User{
@@ -187,11 +189,8 @@ func (s *TestService) SeedDatabase() error {
// openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 | \
// openssl pkcs8 -topk8 -nocrypt | tee >(openssl pkey -pubout)
publicKeyPasskey1, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
publicKeyPasskey2, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==")
if err != nil {
return err
}
publicKeyPasskey1, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
publicKeyPasskey2, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==")
webauthnCredentials := []model.WebauthnCredential{
{
Name: "Passkey 1",
@@ -303,7 +302,7 @@ func (s *TestService) ResetApplicationImages() error {
func (s *TestService) ResetAppConfig() error {
// Reseed the config variables
if err := s.appConfigService.InitDbConfig(); err != nil {
if err := s.appConfigService.InitDbConfig(context.Background()); err != nil {
return err
}
@@ -320,7 +319,7 @@ func (s *TestService) SetJWTKeys() {
const privateKeyString = `{"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}`
privateKey, _ := jwk.ParseKey([]byte(privateKeyString))
s.jwtService.SetKey(privateKey)
_ = s.jwtService.SetKey(privateKey)
}
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key

View File

@@ -2,10 +2,12 @@ package service
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
htemplate "html/template"
"io"
"mime/multipart"
"mime/quotedprintable"
"net/textproto"
@@ -17,10 +19,11 @@ import (
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/google/uuid"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
"gorm.io/gorm"
)
type EmailService struct {
@@ -49,20 +52,24 @@ func NewEmailService(appConfigService *AppConfigService, db *gorm.DB) (*EmailSer
}, nil
}
func (srv *EmailService) SendTestEmail(recipientUserId string) error {
func (srv *EmailService) SendTestEmail(ctx context.Context, recipientUserId string) error {
var user model.User
if err := srv.db.First(&user, "id = ?", recipientUserId).Error; err != nil {
err := srv.db.
WithContext(ctx).
First(&user, "id = ?", recipientUserId).
Error
if err != nil {
return err
}
return SendEmail(srv,
return SendEmail(ctx, srv,
email.Address{
Email: user.Email,
Name: user.FullName(),
}, TestTemplate, nil)
}
func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
data := &email.TemplateData[V]{
AppName: srv.appConfigService.DbConfig.AppName.Value,
LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo",
@@ -112,6 +119,15 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
c.Body(body)
// Check if the context is still valid before attemtping to connect
// We need to do this because the smtp library doesn't have context support
select {
case <-ctx.Done():
return ctx.Err()
default:
// All good
}
// Connect to the SMTP server
client, err := srv.getSmtpClient()
if err != nil {
@@ -119,6 +135,14 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
}
defer client.Close()
// Check if the context is still valid before sending the email
select {
case <-ctx.Done():
return ctx.Err()
default:
// All good
}
// Send the email
if err := srv.sendEmailContent(client, toEmail, c); err != nil {
return fmt.Errorf("send email content: %w", err)
@@ -215,7 +239,7 @@ func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Add
}
// Write the email content
_, err = w.Write([]byte(c.String()))
_, err = io.Copy(w, strings.NewReader(c.String()))
if err != nil {
return fmt.Errorf("failed to write email data: %w", err)
}

View File

@@ -42,7 +42,7 @@ var tailscaleIPNets = []*net.IPNet{
}
// NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database.
func NewGeoLiteService() *GeoLiteService {
func NewGeoLiteService(ctx context.Context) *GeoLiteService {
service := &GeoLiteService{}
if common.EnvConfig.MaxMindLicenseKey == "" && common.EnvConfig.GeoLiteDBUrl == common.MaxMindGeoLiteCityUrl {
@@ -52,8 +52,9 @@ func NewGeoLiteService() *GeoLiteService {
}
go func() {
if err := service.updateDatabase(); err != nil {
log.Printf("Failed to update GeoLite2 City database: %v\n", err)
err := service.updateDatabase(ctx)
if err != nil {
log.Printf("Failed to update GeoLite2 City database: %v", err)
}
}()
@@ -111,7 +112,7 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
}
// UpdateDatabase checks the age of the database and updates it if it's older than 14 days.
func (s *GeoLiteService) updateDatabase() error {
func (s *GeoLiteService) updateDatabase(parentCtx context.Context) error {
if s.disableUpdater {
// Avoid updating the GeoLite2 City database.
return nil
@@ -125,7 +126,7 @@ func (s *GeoLiteService) updateDatabase() error {
log.Println("Updating GeoLite2 City database...")
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 10*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)

View File

@@ -38,7 +38,9 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
// Setup LDAP connection
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.IsTrue()
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: skipTLSVerify, //nolint:gosec
}))
if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
}
@@ -53,22 +55,31 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
return client, nil
}
func (s *LdapService) SyncAll() error {
err := s.SyncUsers()
func (s *LdapService) SyncAll(ctx context.Context) error {
// Start a transaction
tx := s.db.Begin()
err := s.SyncUsers(ctx, tx)
if err != nil {
return fmt.Errorf("failed to sync users: %w", err)
}
err = s.SyncGroups()
err = s.SyncGroups(ctx, tx)
if err != nil {
return fmt.Errorf("failed to sync groups: %w", err)
}
// Commit the changes
err = tx.Commit().Error
if err != nil {
return fmt.Errorf("failed to commit changes to database: %w", err)
}
return nil
}
//nolint:gocognit
func (s *LdapService) SyncGroups() error {
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
@@ -112,7 +123,7 @@ func (s *LdapService) SyncGroups() error {
// Try to find the group in the database
var databaseGroup model.UserGroup
s.db.Where("ldap_id = ?", ldapId).First(&databaseGroup)
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup)
// Get group members and add to the correct Group
groupMembers := value.GetAttributeValues(groupMemberOfAttribute)
@@ -122,7 +133,7 @@ func (s *LdapService) SyncGroups() error {
singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0]
var databaseUser model.User
err := s.db.Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
err := tx.WithContext(ctx).Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// The user collides with a non-LDAP user, so we skip it
@@ -143,39 +154,51 @@ func (s *LdapService) SyncGroups() error {
}
if databaseGroup.ID == "" {
newGroup, err := s.groupService.Create(syncGroup)
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
} else {
if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
}
} else {
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
return err
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
if err != nil {
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
} else {
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
if err != nil {
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
if err != nil {
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
}
}
// Get all LDAP groups from the database
var ldapGroupsInDb []model.UserGroup
if err := s.db.Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
fmt.Println(fmt.Errorf("failed to fetch groups from database: %w", err))
err = tx.
WithContext(ctx).
Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").
Select("ldap_id").
Error
if err != nil {
log.Printf("Failed to fetch groups from database: %v", err)
}
// Delete groups that no longer exist in LDAP
for _, group := range ldapGroupsInDb {
if _, exists := ldapGroupIDs[*group.LdapID]; !exists {
if err := s.db.Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).Error; err != nil {
err = tx.
WithContext(ctx).
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
Error
if err != nil {
log.Printf("Failed to delete group %s with: %v", group.Name, err)
} else {
log.Printf("Deleted group %s", group.Name)
@@ -187,7 +210,7 @@ func (s *LdapService) SyncGroups() error {
}
//nolint:gocognit
func (s *LdapService) SyncUsers() error {
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
@@ -241,7 +264,7 @@ func (s *LdapService) SyncUsers() error {
// Get the user from the database
var databaseUser model.User
s.db.Where("ldap_id = ?", ldapId).First(&databaseUser)
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseUser)
// Check if user is admin by checking if they are in the admin group
isAdmin := false
@@ -261,68 +284,75 @@ func (s *LdapService) SyncUsers() error {
}
if databaseUser.ID == "" {
_, err = s.userService.CreateUser(newUser)
_, err = s.userService.createUserInternal(ctx, newUser, tx)
if err != nil {
log.Printf("Error syncing user %s: %s", newUser.Username, err)
log.Printf("Error syncing user %s: %v", newUser.Username, err)
}
} else {
_, err = s.userService.UpdateUser(databaseUser.ID, newUser, false, true)
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
if err != nil {
log.Printf("Error syncing user %s: %s", newUser.Username, err)
log.Printf("Error syncing user %s: %v", newUser.Username, err)
}
}
// Save profile picture
if pictureString := value.GetAttributeValue(profilePictureAttribute); pictureString != "" {
if err := s.SaveProfilePicture(databaseUser.ID, pictureString); err != nil {
log.Printf("Error saving profile picture for user %s: %s", newUser.Username, err)
if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil {
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
}
}
}
// Get all LDAP users from the database
var ldapUsersInDb []model.User
if err := s.db.Find(&ldapUsersInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
fmt.Println(fmt.Errorf("failed to fetch users from database: %w", err))
err = tx.
WithContext(ctx).
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
Select("ldap_id").
Error
if err != nil {
log.Printf("Failed to fetch users from database: %v", err)
}
// Delete users that no longer exist in LDAP
for _, user := range ldapUsersInDb {
if _, exists := ldapUserIDs[*user.LdapID]; !exists {
if err := s.userService.DeleteUser(user.ID, true); err != nil {
if err := s.userService.deleteUserInternal(ctx, user.ID, true, tx); err != nil {
log.Printf("Failed to delete user %s with: %v", user.Username, err)
} else {
log.Printf("Deleted user %s", user.Username)
}
}
}
return nil
}
func (s *LdapService) SaveProfilePicture(userId string, pictureString string) error {
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
var reader io.Reader
if _, err := url.ParseRequestURI(pictureString); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := url.ParseRequestURI(pictureString)
if err == nil {
ctx, cancel := context.WithTimeout(parentCtx, 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
var req *http.Request
req, err = http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
response, err := http.DefaultClient.Do(req)
var res *http.Response
res, err = http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download profile picture: %w", err)
}
defer response.Body.Close()
reader = response.Body
defer res.Body.Close()
reader = res.Body
} else if decodedPhoto, err := base64.StdEncoding.DecodeString(pictureString); err == nil {
// If the photo is a base64 encoded string, decode it
reader = bytes.NewReader(decodedPhoto)
} else {
// If the photo is a string, we assume that it's a binary string
reader = bytes.NewReader([]byte(pictureString))

View File

@@ -1,6 +1,7 @@
package service
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
@@ -9,6 +10,7 @@ import (
"mime/multipart"
"os"
"regexp"
"slices"
"strings"
"time"
@@ -39,9 +41,20 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppCo
}
}
func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.Preload("AllowedUserGroups").First(&client, "id = ?", input.ClientID).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("AllowedUserGroups").
First(&client, "id = ?", input.ClientID).
Error
if err != nil {
return "", "", err
}
@@ -58,7 +71,12 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
// Check if the user group is allowed to authorize the client
var user model.User
if err := s.db.Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil {
err = tx.
WithContext(ctx).
Preload("UserGroups").
First(&user, "id = ?", userID).
Error
if err != nil {
return "", "", err
}
@@ -67,7 +85,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
}
// Check if the user has already authorized the client with the given scope
hasAuthorizedClient, err := s.HasAuthorizedClient(input.ClientID, userID, input.Scope)
hasAuthorizedClient, err := s.hasAuthorizedClientInternal(ctx, input.ClientID, userID, input.Scope, tx)
if err != nil {
return "", "", err
}
@@ -80,39 +98,55 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
Scope: input.Scope,
}
if err := s.db.Create(&userAuthorizedClient).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&userAuthorizedClient).
Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
// The client has already been authorized but with a different scope so we need to update the scope
if err := s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil {
if err := tx.
WithContext(ctx).
Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil {
return "", "", err
}
} else {
} else if err != nil {
return "", "", err
}
}
}
// Create the authorization code
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
code, err := s.createAuthorizationCode(ctx, input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod, tx)
if err != nil {
return "", "", err
}
// Log the authorization event
if hasAuthorizedClient {
s.auditLogService.Create(model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
s.auditLogService.Create(ctx, model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}, tx)
} else {
s.auditLogService.Create(model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
s.auditLogService.Create(ctx, model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}, tx)
}
err = tx.Commit().Error
if err != nil {
return "", "", err
}
return code, callbackURL, nil
}
// HasAuthorizedClient checks if the user has already authorized the client with the given scope
func (s *OidcService) HasAuthorizedClient(clientID, userID, scope string) (bool, error) {
func (s *OidcService) HasAuthorizedClient(ctx context.Context, clientID, userID, scope string) (bool, error) {
return s.hasAuthorizedClientInternal(ctx, clientID, userID, scope, s.db)
}
func (s *OidcService) hasAuthorizedClientInternal(ctx context.Context, clientID, userID, scope string, tx *gorm.DB) (bool, error) {
var userAuthorizedOidcClient model.UserAuthorizedOidcClient
if err := s.db.First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID).Error; err != nil {
err := tx.
WithContext(ctx).
First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
@@ -145,21 +179,31 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
return isAllowedToAuthorize
}
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
func (s *OidcService) CreateTokens(ctx context.Context, code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
switch grantType {
case "authorization_code":
return s.createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier)
return s.createTokenFromAuthorizationCode(ctx, code, clientID, clientSecret, codeVerifier)
case "refresh_token":
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(refreshToken, clientID, clientSecret)
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, refreshToken, clientID, clientSecret)
return "", accessToken, newRefreshToken, exp, err
default:
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
}
}
func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", "", "", 0, err
}
@@ -176,7 +220,11 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
err = tx.
WithContext(ctx).
Preload("User").
First(&authorizationCodeMetaData, "code = ?", code).
Error
if err != nil {
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
}
@@ -192,7 +240,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
}
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, clientID, tx)
if err != nil {
return "", "", "", 0, err
}
@@ -203,7 +251,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
}
// Generate a refresh token
refreshToken, err = s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope)
refreshToken, err = s.createRefreshToken(ctx, clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
if err != nil {
return "", "", "", 0, err
}
@@ -213,19 +261,40 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
return "", "", "", 0, err
}
s.db.Delete(&authorizationCodeMetaData)
err = tx.
WithContext(ctx).
Delete(&authorizationCodeMetaData).
Error
if err != nil {
return "", "", "", 0, err
}
err = tx.Commit().Error
if err != nil {
return "", "", "", 0, err
}
return idToken, accessToken, refreshToken, 3600, nil
}
func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
if refreshToken == "" {
return "", "", 0, &common.OidcMissingRefreshTokenError{}
}
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
// Get the client to check if it's public
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", "", 0, err
}
@@ -243,7 +312,9 @@ func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, client
// Verify refresh token
var storedRefreshToken model.OidcRefreshToken
err = s.db.Preload("User").
err = tx.
WithContext(ctx).
Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
First(&storedRefreshToken).
Error
@@ -266,29 +337,53 @@ func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, client
}
// Generate a new refresh token and invalidate the old one
newRefreshToken, err = s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope)
newRefreshToken, err = s.createRefreshToken(ctx, clientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
if err != nil {
return "", "", 0, err
}
// Delete the used refresh token
s.db.Delete(&storedRefreshToken)
err = tx.
WithContext(ctx).
Delete(&storedRefreshToken).
Error
if err != nil {
return "", "", 0, err
}
err = tx.Commit().Error
if err != nil {
return "", "", 0, err
}
return accessToken, newRefreshToken, 3600, nil
}
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
return s.getClientInternal(ctx, clientID, s.db)
}
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB) (model.OidcClient, error) {
var client model.OidcClient
if err := s.db.Preload("CreatedBy").Preload("AllowedUserGroups").First(&client, "id = ?", clientID).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("CreatedBy").
Preload("AllowedUserGroups").
First(&client, "id = ?", clientID).
Error
if err != nil {
return model.OidcClient{}, err
}
return client, nil
}
func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.OidcClient, utils.PaginationResponse, error) {
func (s *OidcService) ListClients(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.OidcClient, utils.PaginationResponse, error) {
var clients []model.OidcClient
query := s.db.Preload("CreatedBy").Model(&model.OidcClient{})
query := s.db.
WithContext(ctx).
Preload("CreatedBy").
Model(&model.OidcClient{})
if searchTerm != "" {
searchPattern := "%" + searchTerm + "%"
query = query.Where("name LIKE ?", searchPattern)
@@ -302,7 +397,7 @@ func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest uti
return clients, pagination, nil
}
func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
client := model.OidcClient{
Name: input.Name,
CallbackURLs: input.CallbackURLs,
@@ -312,16 +407,31 @@ func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string)
PkceEnabled: input.IsPublic || input.PkceEnabled,
}
if err := s.db.Create(&client).Error; err != nil {
err := s.db.
WithContext(ctx).
Create(&client).
Error
if err != nil {
return model.OidcClient{}, err
}
return client, nil
}
func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) {
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("CreatedBy").
First(&client, "id = ?", clientID).
Error
if err != nil {
return model.OidcClient{}, err
}
@@ -331,29 +441,49 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt
client.IsPublic = input.IsPublic
client.PkceEnabled = input.IsPublic || input.PkceEnabled
if err := s.db.Save(&client).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return model.OidcClient{}, err
}
err = tx.Commit().Error
if err != nil {
return model.OidcClient{}, err
}
return client, nil
}
func (s *OidcService) DeleteClient(clientID string) error {
func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return err
}
if err := s.db.Delete(&client).Error; err != nil {
err := s.db.
WithContext(ctx).
Where("id = ?", clientID).
Delete(&client).
Error
if err != nil {
return err
}
return nil
}
func (s *OidcService) CreateClientSecret(clientID string) (string, error) {
func (s *OidcService) CreateClientSecret(ctx context.Context, clientID string) (string, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err := tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", err
}
@@ -368,16 +498,29 @@ func (s *OidcService) CreateClientSecret(clientID string) (string, error) {
}
client.Secret = string(hashedSecret)
if err := s.db.Save(&client).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return "", err
}
err = tx.Commit().Error
if err != nil {
return "", err
}
return clientSecret, nil
}
func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
func (s *OidcService) GetClientLogo(ctx context.Context, clientID string) (string, string, error) {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err := s.db.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", "", err
}
@@ -385,26 +528,36 @@ func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
return "", "", errors.New("image not found")
}
imageType := *client.ImageType
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType)
mimeType := utils.GetImageMimeType(imageType)
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType
mimeType := utils.GetImageMimeType(*client.ImageType)
return imagePath, mimeType, nil
}
func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error {
func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, file *multipart.FileHeader) error {
fileType := utils.GetFileExtension(file.Filename)
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
return &common.FileTypeNotSupportedError{}
}
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType)
if err := utils.SaveFile(file, imagePath); err != nil {
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + "." + fileType
err := utils.SaveFile(file, imagePath)
if err != nil {
return err
}
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return err
}
@@ -416,16 +569,35 @@ func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHead
}
client.ImageType = &fileType
if err := s.db.Save(&client).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return err
}
err = tx.Commit().Error
if err != nil {
return err
}
return nil
}
func (s *OidcService) DeleteClientLogo(clientID string) error {
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err := tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return err
}
@@ -433,38 +605,72 @@ func (s *OidcService) DeleteClientLogo(clientID string) error {
return errors.New("image not found")
}
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
client.ImageType = nil
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return err
}
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType
if err := os.Remove(imagePath); err != nil {
return err
}
client.ImageType = nil
if err := s.db.Save(&client).Error; err != nil {
err = tx.Commit().Error
if err != nil {
return err
}
return nil
}
func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (map[string]interface{}, error) {
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
return claims, nil
}
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
var authorizedOidcClient model.UserAuthorizedOidcClient
if err := s.db.Preload("User.UserGroups").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("User.UserGroups").
First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).
Error
if err != nil {
return nil, err
}
user := authorizedOidcClient.User
scope := authorizedOidcClient.Scope
scopes := strings.Split(authorizedOidcClient.Scope, " ")
claims := map[string]interface{}{
"sub": user.ID,
}
if strings.Contains(scope, "email") {
if slices.Contains(scopes, "email") {
claims["email"] = user.Email
claims["email_verified"] = s.appConfigService.DbConfig.EmailsVerified.IsTrue()
}
if strings.Contains(scope, "groups") {
if slices.Contains(scopes, "groups") {
userGroups := make([]string, len(user.UserGroups))
for i, group := range user.UserGroups {
userGroups[i] = group.Name
@@ -477,17 +683,17 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
"family_name": user.LastName,
"name": user.FullName(),
"preferred_username": user.Username,
"picture": fmt.Sprintf("%s/api/users/%s/profile-picture.png", common.EnvConfig.AppURL, user.ID),
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
}
if strings.Contains(scope, "profile") {
if slices.Contains(scopes, "profile") {
// Add profile claims
for k, v := range profileClaims {
claims[k] = v
}
// Add custom claims
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(userID)
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, userID, tx)
if err != nil {
return nil, err
}
@@ -505,15 +711,22 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
}
}
}
if strings.Contains(scope, "email") {
if slices.Contains(scopes, "email") {
claims["email"] = user.Email
}
return claims, nil
}
func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
client, err = s.GetClient(id)
func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
client, err = s.getClientInternal(ctx, id, tx)
if err != nil {
return model.OidcClient{}, err
}
@@ -521,18 +734,37 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll
// Fetch the user groups based on UserGroupIDs in input
var groups []model.UserGroup
if len(input.UserGroupIDs) > 0 {
if err := s.db.Where("id IN (?)", input.UserGroupIDs).Find(&groups).Error; err != nil {
err = tx.
WithContext(ctx).
Where("id IN (?)", input.UserGroupIDs).
Find(&groups).
Error
if err != nil {
return model.OidcClient{}, err
}
}
// Replace the current user groups with the new set of user groups
if err := s.db.Model(&client).Association("AllowedUserGroups").Replace(groups); err != nil {
err = tx.
WithContext(ctx).
Model(&client).
Association("AllowedUserGroups").
Replace(groups)
if err != nil {
return model.OidcClient{}, err
}
// Save the updated client
if err := s.db.Save(&client).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return model.OidcClient{}, err
}
err = tx.Commit().Error
if err != nil {
return model.OidcClient{}, err
}
@@ -540,7 +772,7 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll
}
// ValidateEndSession returns the logout callback URL for the client if all the validations pass
func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string) (string, error) {
func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogoutDto, userID string) (string, error) {
// If no ID token hint is provided, return an error
if input.IdTokenHint == "" {
return "", &common.TokenInvalidError{}
@@ -564,7 +796,12 @@ func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string)
// Check if the user has authorized the client before
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).Error; err != nil {
err = s.db.
WithContext(ctx).
Preload("Client").
First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).
Error
if err != nil {
return "", &common.OidcMissingAuthorizationError{}
}
@@ -582,7 +819,7 @@ func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string)
}
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string, tx *gorm.DB) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return "", err
@@ -601,7 +838,11 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc
CodeChallengeMethodSha256: &codeChallengeMethodSha256,
}
if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&oidcAuthorizationCode).
Error
if err != nil {
return "", err
}
@@ -647,7 +888,7 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
return "", &common.OidcInvalidCallbackURLError{}
}
func (s *OidcService) createRefreshToken(clientID string, userID string, scope string) (string, error) {
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, tx *gorm.DB) (string, error) {
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
if err != nil {
return "", err
@@ -665,7 +906,11 @@ func (s *OidcService) createRefreshToken(clientID string, userID string, scope s
Scope: scope,
}
if err := s.db.Create(&m).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&m).
Error
if err != nil {
return "", err
}

View File

@@ -1,13 +1,15 @@
package service
import (
"context"
"errors"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
)
type UserGroupService struct {
@@ -19,8 +21,11 @@ func NewUserGroupService(db *gorm.DB, appConfigService *AppConfigService) *UserG
return &UserGroupService{db: db, appConfigService: appConfigService}
}
func (s *UserGroupService) List(name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
query := s.db.Preload("CustomClaims").Model(&model.UserGroup{})
func (s *UserGroupService) List(ctx context.Context, name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
query := s.db.
WithContext(ctx).
Preload("CustomClaims").
Model(&model.UserGroup{})
if name != "" {
query = query.Where("name LIKE ?", "%"+name+"%")
@@ -42,26 +47,59 @@ func (s *UserGroupService) List(name string, sortedPaginationRequest utils.Sorte
return groups, response, err
}
func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) {
err = s.db.Where("id = ?", id).Preload("CustomClaims").Preload("Users").First(&group).Error
func (s *UserGroupService) Get(ctx context.Context, id string) (group model.UserGroup, err error) {
return s.getInternal(ctx, id, s.db)
}
func (s *UserGroupService) getInternal(ctx context.Context, id string, tx *gorm.DB) (group model.UserGroup, err error) {
err = tx.
WithContext(ctx).
Where("id = ?", id).
Preload("CustomClaims").
Preload("Users").
First(&group).
Error
return group, err
}
func (s *UserGroupService) Delete(id string) error {
func (s *UserGroupService) Delete(ctx context.Context, id string) error {
tx := s.db.Begin()
var group model.UserGroup
if err := s.db.Where("id = ?", id).First(&group).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id = ?", id).
First(&group).
Error
if err != nil {
return err
}
// Disallow deleting the group if it is an LDAP group and LDAP is enabled
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
err = tx.Rollback().Error
if err != nil {
return err
}
return &common.LdapUserGroupUpdateError{}
}
return s.db.Delete(&group).Error
err = tx.
WithContext(ctx).
Delete(&group).
Error
if err != nil {
return err
}
func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
return tx.Commit().Error
}
func (s *UserGroupService) Create(ctx context.Context, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
return s.createInternal(ctx, input, s.db)
}
func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGroupCreateDto, tx *gorm.DB) (group model.UserGroup, err error) {
group = model.UserGroup{
FriendlyName: input.FriendlyName,
Name: input.Name,
@@ -71,7 +109,12 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
group.LdapID = &input.LdapID
}
if err := s.db.Preload("Users").Create(&group).Error; err != nil {
err = tx.
WithContext(ctx).
Preload("Users").
Create(&group).
Error
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
}
@@ -80,8 +123,26 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
return group, nil
}
func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) {
group, err = s.Get(id)
func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) {
tx := s.db.Begin()
group, err = s.updateInternal(ctx, id, input, allowLdapUpdate, tx)
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
err = tx.Commit().Error
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool, tx *gorm.DB) (group model.UserGroup, err error) {
group, err = s.getInternal(ctx, id, tx)
if err != nil {
return model.UserGroup{}, err
}
@@ -94,7 +155,12 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
group.Name = input.Name
group.FriendlyName = input.FriendlyName
if err := s.db.Preload("Users").Save(&group).Error; err != nil {
err = tx.
WithContext(ctx).
Preload("Users").
Save(&group).
Error
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
}
@@ -103,8 +169,26 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
return group, nil
}
func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model.UserGroup, err error) {
group, err = s.Get(id)
func (s *UserGroupService) UpdateUsers(ctx context.Context, id string, userIds []string) (group model.UserGroup, err error) {
tx := s.db.Begin()
group, err = s.updateUsersInternal(ctx, id, userIds, tx)
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
err = tx.Commit().Error
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, userIds []string, tx *gorm.DB) (group model.UserGroup, err error) {
group, err = s.getInternal(ctx, id, tx)
if err != nil {
return model.UserGroup{}, err
}
@@ -112,28 +196,59 @@ func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model
// Fetch the users based on the userIds
var users []model.User
if len(userIds) > 0 {
if err := s.db.Where("id IN (?)", userIds).Find(&users).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id IN (?)", userIds).
Find(&users).
Error
if err != nil {
return model.UserGroup{}, err
}
}
// Replace the current users with the new set of users
if err := s.db.Model(&group).Association("Users").Replace(users); err != nil {
err = tx.
WithContext(ctx).
Model(&group).
Association("Users").
Replace(users)
if err != nil {
return model.UserGroup{}, err
}
// Save the updated group
if err := s.db.Save(&group).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&group).
Error
if err != nil {
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) GetUserCountOfGroup(id string) (int64, error) {
func (s *UserGroupService) GetUserCountOfGroup(ctx context.Context, id string) (int64, error) {
// We only perform select queries here, so we can rollback in all cases
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var group model.UserGroup
if err := s.db.Preload("Users").Where("id = ?", id).First(&group).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("Users").
Where("id = ?", id).
First(&group).
Error
if err != nil {
return 0, err
}
return s.db.Model(&group).Association("Users").Count(), nil
count := tx.
WithContext(ctx).
Model(&group).
Association("Users").
Count()
return count, nil
}

View File

@@ -2,6 +2,7 @@ package service
import (
"bytes"
"context"
"errors"
"fmt"
"io"
@@ -12,7 +13,7 @@ import (
"time"
"github.com/google/uuid"
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
@@ -20,7 +21,7 @@ import (
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
"gorm.io/gorm"
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
)
type UserService struct {
@@ -35,9 +36,9 @@ func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditL
return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService, appConfigService: appConfigService}
}
func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
var users []model.User
query := s.db.Model(&model.User{})
query := s.db.WithContext(ctx).Model(&model.User{})
if searchTerm != "" {
searchPattern := "%" + searchTerm + "%"
@@ -48,13 +49,23 @@ func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils
return users, pagination, err
}
func (s *UserService) GetUser(userID string) (model.User, error) {
func (s *UserService) GetUser(ctx context.Context, userID string) (model.User, error) {
return s.getUserInternal(ctx, userID, s.db)
}
func (s *UserService) getUserInternal(ctx context.Context, userID string, tx *gorm.DB) (model.User, error) {
var user model.User
err := s.db.Preload("UserGroups").Preload("CustomClaims").Where("id = ?", userID).First(&user).Error
err := tx.
WithContext(ctx).
Preload("UserGroups").
Preload("CustomClaims").
Where("id = ?", userID).
First(&user).
Error
return user, err
}
func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, error) {
func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.ReadCloser, int64, error) {
// Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil {
return nil, 0, &common.InvalidUUIDError{}
@@ -74,7 +85,7 @@ func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, er
}
// If no custom picture exists, get the user's data for creating initials
user, err := s.GetUser(userID)
user, err := s.GetUser(ctx, userID)
if err != nil {
return nil, 0, err
}
@@ -115,9 +126,15 @@ func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, er
return io.NopCloser(bytes.NewReader(defaultPictureBytes)), int64(defaultPicture.Len()), nil
}
func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) {
func (s *UserService) GetUserGroups(ctx context.Context, userID string) ([]model.UserGroup, error) {
var user model.User
if err := s.db.Preload("UserGroups").Where("id = ?", userID).First(&user).Error; err != nil {
err := s.db.
WithContext(ctx).
Preload("UserGroups").
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return nil, err
}
return user.UserGroups, nil
@@ -152,9 +169,21 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error
return nil
}
func (s *UserService) DeleteUser(userID string, allowLdapDelete bool) error {
func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
return s.db.Transaction(func(tx *gorm.DB) error {
return s.deleteUserInternal(ctx, userID, allowLdapDelete, tx)
})
}
func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error {
var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return err
}
@@ -165,14 +194,35 @@ func (s *UserService) DeleteUser(userID string, allowLdapDelete bool) error {
// Delete the profile picture
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
if err := os.Remove(profilePicturePath); err != nil && !os.IsNotExist(err) {
err = os.Remove(profilePicturePath)
if err != nil && !os.IsNotExist(err) {
return err
}
return s.db.Delete(&user).Error
return tx.WithContext(ctx).Delete(&user).Error
}
func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (model.User, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
user, err := s.createUserInternal(ctx, input, tx)
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, tx *gorm.DB) (model.User, error) {
user := model.User{
FirstName: input.FirstName,
LastName: input.LastName,
@@ -185,18 +235,47 @@ func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
user.LdapID = &input.LdapID
}
if err := s.db.Create(&user).Error; err != nil {
err := tx.WithContext(ctx).Create(&user).Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.User{}, s.checkDuplicatedFields(user)
}
tx.Rollback()
// If we are here, the transaction is already aborted due to an error, so we pass s.db
err = s.checkDuplicatedFields(ctx, user, s.db)
return model.User{}, err
} else if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) {
func (s *UserService) UpdateUser(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
user, err := s.updateUserInternal(ctx, userID, updatedUser, updateOwnUser, allowLdapUpdate, tx)
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool, tx *gorm.DB) (model.User, error) {
var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return model.User{}, err
}
@@ -214,24 +293,42 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
user.IsAdmin = updatedUser.IsAdmin
}
if err := s.db.Save(&user).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&user).
Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
return user, s.checkDuplicatedFields(user)
}
tx.Rollback()
// If we are here, the transaction is already aborted due to an error, so we pass s.db
err = s.checkDuplicatedFields(ctx, user, s.db)
return user, err
} else if err != nil {
return user, err
}
return user, nil
}
func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error {
func (s *UserService) RequestOneTimeAccessEmail(ctx context.Context, emailAddress, redirectPath string) error {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
isDisabled := !s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.IsTrue()
if isDisabled {
return &common.OneTimeAccessDisabledError{}
}
var user model.User
if err := s.db.Where("email = ?", emailAddress).First(&user).Error; err != nil {
err := tx.
WithContext(ctx).
Where("email = ?", emailAddress).
First(&user).
Error
if err != nil {
// Do not return error if user not found to prevent email enumeration
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
@@ -240,22 +337,31 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin
}
}
oneTimeAccessToken, err := s.CreateOneTimeAccessToken(user.ID, time.Now().Add(15*time.Minute))
oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, time.Now().Add(15*time.Minute), tx)
if err != nil {
return err
}
link := fmt.Sprintf("%s/lc", common.EnvConfig.AppURL)
linkWithCode := fmt.Sprintf("%s/%s", link, oneTimeAccessToken)
err = tx.Commit().Error
if err != nil {
return err
}
// We use a background context here as this is running in a goroutine
//nolint:contextcheck
go func() {
innerCtx := context.Background()
link := common.EnvConfig.AppURL + "/lc"
linkWithCode := link + "/" + oneTimeAccessToken
// Add redirect path to the link
if strings.HasPrefix(redirectPath, "/") {
encodedRedirectPath := url.QueryEscape(redirectPath)
linkWithCode = fmt.Sprintf("%s?redirect=%s", linkWithCode, encodedRedirectPath)
linkWithCode = linkWithCode + "?redirect=" + encodedRedirectPath
}
go func() {
err := SendEmail(s.emailService, email.Address{
errInternal := SendEmail(innerCtx, s.emailService, email.Address{
Name: user.Username,
Email: user.Email,
}, OneTimeAccessTemplate, &OneTimeAccessTemplateData{
@@ -263,18 +369,21 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin
LoginLink: link,
LoginLinkWithCode: linkWithCode,
})
if err != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, err)
if errInternal != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, errInternal)
}
}()
return nil
}
func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) {
tokenLength := 16
func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, expiresAt time.Time) (string, error) {
return s.createOneTimeAccessTokenInternal(ctx, userID, expiresAt, s.db)
}
func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, expiresAt time.Time, tx *gorm.DB) (string, error) {
// If expires at is less than 15 minutes, use an 6 character token instead of 16
tokenLength := 16
if time.Until(expiresAt) <= 15*time.Minute {
tokenLength = 6
}
@@ -290,16 +399,27 @@ func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Tim
Token: randomString,
}
if err := s.db.Create(&oneTimeAccessToken).Error; err != nil {
if err := tx.WithContext(ctx).Create(&oneTimeAccessToken).Error; err != nil {
return "", err
}
return oneTimeAccessToken.Token, nil
}
func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAgent string) (model.User, string, error) {
func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token string, ipAddress, userAgent string) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var oneTimeAccessToken model.OneTimeAccessToken
if err := s.db.Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil {
err := tx.
WithContext(ctx).
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").
First(&oneTimeAccessToken).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
}
@@ -310,19 +430,34 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAg
return model.User{}, "", err
}
if err := s.db.Delete(&oneTimeAccessToken).Error; err != nil {
err = tx.
WithContext(ctx).
Delete(&oneTimeAccessToken).
Error
if err != nil {
return model.User{}, "", err
}
if ipAddress != "" && userAgent != "" {
s.auditLogService.Create(model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{})
s.auditLogService.Create(ctx, model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}, tx)
}
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
}
return oneTimeAccessToken.User, accessToken, nil
}
func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user model.User, err error) {
user, err = s.GetUser(id)
func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroupIds []string) (user model.User, err error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
user, err = s.getUserInternal(ctx, id, tx)
if err != nil {
return model.User{}, err
}
@@ -330,27 +465,49 @@ func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user m
// Fetch the groups based on userGroupIds
var groups []model.UserGroup
if len(userGroupIds) > 0 {
if err := s.db.Where("id IN (?)", userGroupIds).Find(&groups).Error; err != nil {
err = tx.
WithContext(ctx).
Where("id IN (?)", userGroupIds).
Find(&groups).
Error
if err != nil {
return model.User{}, err
}
}
// Replace the current groups with the new set of groups
if err := s.db.Model(&user).Association("UserGroups").Replace(groups); err != nil {
err = tx.
WithContext(ctx).
Model(&user).
Association("UserGroups").
Replace(groups)
if err != nil {
return model.User{}, err
}
// Save the updated user
if err := s.db.Save(&user).Error; err != nil {
err = tx.WithContext(ctx).Save(&user).Error
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
func (s *UserService) SetupInitialAdmin(ctx context.Context) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var userCount int64
if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil {
if err := tx.WithContext(ctx).Model(&model.User{}).Count(&userCount).Error; err != nil {
return model.User{}, "", err
}
if userCount > 1 {
@@ -365,7 +522,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
IsAdmin: true,
}
if err := s.db.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
if err := tx.WithContext(ctx).Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
return model.User{}, "", err
}
@@ -378,16 +535,39 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
return model.User{}, "", err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
}
return user, token, nil
}
func (s *UserService) checkDuplicatedFields(user model.User) error {
var existingUser model.User
if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil {
func (s *UserService) checkDuplicatedFields(ctx context.Context, user model.User, tx *gorm.DB) error {
var result struct {
Found bool
}
err := tx.
WithContext(ctx).
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND email = ?) AS found`, user.ID, user.Email).
First(&result).
Error
if err != nil {
return err
}
if result.Found {
return &common.AlreadyInUseError{Property: "email"}
}
if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil {
err = tx.
WithContext(ctx).
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND username = ?) AS found`, user.ID, user.Username).
First(&result).
Error
if err != nil {
return err
}
if result.Found {
return &common.AlreadyInUseError{Property: "username"}
}

View File

@@ -1,16 +1,19 @@
package service
import (
"context"
"fmt"
"net/http"
"time"
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
)
type WebAuthnService struct {
@@ -43,15 +46,31 @@ func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *Au
return &WebAuthnService{db: db, webAuthn: wa, jwtService: jwtService, auditLogService: auditLogService, appConfigService: appConfigService}
}
func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCredentialCreationOptions, error) {
func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
s.updateWebAuthnConfig()
var user model.User
if err := s.db.Preload("Credentials").Find(&user, "id = ?", userID).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("Credentials").
Find(&user, "id = ?", userID).
Error
if err != nil {
tx.Rollback()
return nil, err
}
options, session, err := s.webAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()))
options, session, err := s.webAuthn.BeginRegistration(
&user,
webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired),
webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()),
)
if err != nil {
return nil, err
}
@@ -62,7 +81,16 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
UserVerification: string(session.UserVerification),
}
if err := s.db.Create(&sessionToStore).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&sessionToStore).
Error
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
@@ -73,9 +101,19 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
}, nil
}
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
err := tx.
WithContext(ctx).
First(&storedSession, "id = ?", sessionID).
Error
if err != nil {
return model.WebauthnCredential{}, err
}
@@ -86,7 +124,11 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
}
var user model.User
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
err = tx.
WithContext(ctx).
Find(&user, "id = ?", userID).
Error
if err != nil {
return model.WebauthnCredential{}, err
}
@@ -108,7 +150,16 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
BackupEligible: credential.Flags.BackupEligible,
BackupState: credential.Flags.BackupState,
}
if err := s.db.Create(&credentialToStore).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&credentialToStore).
Error
if err != nil {
return model.WebauthnCredential{}, err
}
err = tx.Commit().Error
if err != nil {
return model.WebauthnCredential{}, err
}
@@ -125,7 +176,7 @@ func (s *WebAuthnService) determinePasskeyName(aaguid []byte) string {
return "New Passkey" // Default fallback
}
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
func (s *WebAuthnService) BeginLogin(ctx context.Context) (*model.PublicKeyCredentialRequestOptions, error) {
options, session, err := s.webAuthn.BeginDiscoverableLogin()
if err != nil {
return nil, err
@@ -137,7 +188,11 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
UserVerification: string(session.UserVerification),
}
if err := s.db.Create(&sessionToStore).Error; err != nil {
err = s.db.
WithContext(ctx).
Create(&sessionToStore).
Error
if err != nil {
return nil, err
}
@@ -148,9 +203,19 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
}, nil
}
func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
func (s *WebAuthnService) VerifyLogin(ctx context.Context, sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
err := tx.
WithContext(ctx).
First(&storedSession, "id = ?", sessionID).
Error
if err != nil {
return model.User{}, "", err
}
@@ -160,9 +225,14 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
}
var user *model.User
_, err := s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
if err := s.db.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil {
return nil, err
_, err = s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
innerErr := tx.
WithContext(ctx).
Preload("Credentials").
First(&user, "id = ?", string(userHandle)).
Error
if innerErr != nil {
return nil, innerErr
}
return user, nil
}, session, credentialAssertionData)
@@ -176,41 +246,70 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
return model.User{}, "", err
}
s.auditLogService.CreateNewSignInWithEmail(ipAddress, userAgent, user.ID)
s.auditLogService.CreateNewSignInWithEmail(ctx, ipAddress, userAgent, user.ID, tx)
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
}
return *user, token, nil
}
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) {
func (s *WebAuthnService) ListCredentials(ctx context.Context, userID string) ([]model.WebauthnCredential, error) {
var credentials []model.WebauthnCredential
if err := s.db.Find(&credentials, "user_id = ?", userID).Error; err != nil {
err := s.db.
WithContext(ctx).
Find(&credentials, "user_id = ?", userID).
Error
if err != nil {
return nil, err
}
return credentials, nil
}
func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error {
var credential model.WebauthnCredential
if err := s.db.First(&credential, "id = ? AND user_id = ?", credentialID, userID).Error; err != nil {
return err
}
if err := s.db.Delete(&credential).Error; err != nil {
return err
func (s *WebAuthnService) DeleteCredential(ctx context.Context, userID, credentialID string) error {
err := s.db.
WithContext(ctx).
Where("id = ? AND user_id = ?", credentialID, userID).
Delete(&model.WebauthnCredential{}).
Error
if err != nil {
return fmt.Errorf("failed to delete record: %w", err)
}
return nil
}
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) {
func (s *WebAuthnService) UpdateCredential(ctx context.Context, userID, credentialID, name string) (model.WebauthnCredential, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var credential model.WebauthnCredential
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id = ? AND user_id = ?", credentialID, userID).
First(&credential).
Error
if err != nil {
return credential, err
}
credential.Name = name
if err := s.db.Save(&credential).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&credential).
Error
if err != nil {
return credential, err
}
err = tx.Commit().Error
if err != nil {
return credential, err
}

View File

@@ -12,9 +12,13 @@ import (
var (
aaguidMap map[string]string
aaguidMapOnce sync.Once
aaguidMapOnce *sync.Once
)
func init() {
aaguidMapOnce = &sync.Once{}
}
// FormatAAGUID converts an AAGUID byte slice to UUID string format
func FormatAAGUID(aaguid []byte) string {
if len(aaguid) == 0 {

View File

@@ -47,8 +47,10 @@ func TestFormatAAGUID(t *testing.T) {
func TestGetAuthenticatorName(t *testing.T) {
// Reset the aaguidMap for testing
originalMap := aaguidMap
originalOnce := aaguidMapOnce
defer func() {
aaguidMap = originalMap
aaguidMapOnce = originalOnce
}()
// Inject a test AAGUID map
@@ -56,7 +58,7 @@ func TestGetAuthenticatorName(t *testing.T) {
"adce0002-35bc-c60a-648b-0b25f1f05503": "Test Authenticator",
"00000000-0000-0000-0000-000000000000": "Zero Authenticator",
}
aaguidMapOnce = sync.Once{}
aaguidMapOnce = &sync.Once{}
aaguidMapOnce.Do(func() {}) // Mark as done to avoid loading from file
tests := []struct {
@@ -99,7 +101,7 @@ func TestGetAuthenticatorName(t *testing.T) {
func TestLoadAAGUIDsFromFile(t *testing.T) {
// Reset the map and once flag for clean testing
aaguidMap = nil
aaguidMapOnce = sync.Once{}
aaguidMapOnce = &sync.Once{}
// Trigger loading of AAGUIDs by calling GetAuthenticatorName
GetAuthenticatorName([]byte{0x01, 0x02, 0x03, 0x04})

View File

@@ -4,9 +4,8 @@ import (
"reflect"
"strconv"
"gorm.io/gorm/clause"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type PaginationResponse struct {
@@ -47,7 +46,6 @@ func PaginateAndSort(sortedPaginationRequest SortedPaginationRequest, query *gor
}
return Paginate(pagination.Page, pagination.Limit, query, result)
}
func Paginate(page int, pageSize int, query *gorm.DB, result interface{}) (PaginationResponse, error) {

View File

@@ -0,0 +1,5 @@
package utils
func Ptr[T any](v T) *T {
return &v
}

Binary file not shown.