refactor: simplify app_config service and fix race conditions (#423)

This commit is contained in:
Alessandro (Ale) Segala
2025-04-10 04:41:22 -07:00
committed by GitHub
parent 4ba68938dd
commit f83bab9e17
28 changed files with 1263 additions and 560 deletions

View File

@@ -27,7 +27,7 @@ jobs:
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version-file: backend/go.mod go-version-file: backend/go.mod

View File

@@ -11,7 +11,7 @@ RUN npm run build
RUN npm prune --production RUN npm prune --production
# Stage 2: Build Backend # Stage 2: Build Backend
FROM golang:1.23-alpine AS backend-builder FROM golang:1.24-alpine AS backend-builder
ARG BUILD_TAGS ARG BUILD_TAGS
WORKDIR /app/backend WORKDIR /app/backend
COPY ./backend/go.mod ./backend/go.sum ./ COPY ./backend/go.mod ./backend/go.sum ./

View File

@@ -1,6 +1,6 @@
module github.com/pocket-id/pocket-id/backend module github.com/pocket-id/pocket-id/backend
go 1.23.7 go 1.24
require ( require (
github.com/caarlos0/env/v11 v11.3.1 github.com/caarlos0/env/v11 v11.3.1

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
_ "github.com/golang-migrate/migrate/v4/source/file" _ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/pocket-id/pocket-id/backend/internal/service" "github.com/pocket-id/pocket-id/backend/internal/service"
) )

View File

@@ -60,11 +60,7 @@ type AppConfigController struct {
// @Failure 500 {object} object "{"error": "error message"}" // @Failure 500 {object} object "{"error": "error message"}"
// @Router /application-configuration [get] // @Router /application-configuration [get]
func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) { func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(c.Request.Context(), false) configuration := acc.appConfigService.ListAppConfig(false)
if err != nil {
_ = c.Error(err)
return
}
var configVariablesDto []dto.PublicAppConfigVariableDto var configVariablesDto []dto.PublicAppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
@@ -85,11 +81,7 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /application-configuration/all [get] // @Router /application-configuration/all [get]
func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) { func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(c.Request.Context(), true) configuration := acc.appConfigService.ListAppConfig(true)
if err != nil {
_ = c.Error(err)
return
}
var configVariablesDto []dto.AppConfigVariableDto var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
@@ -143,17 +135,17 @@ func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
// @Success 200 {file} binary "Logo image" // @Success 200 {file} binary "Logo image"
// @Router /api/application-configuration/logo [get] // @Router /api/application-configuration/logo [get]
func (acc *AppConfigController) getLogoHandler(c *gin.Context) { func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
dbConfig := acc.appConfigService.GetDbConfig()
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true")) lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
var imageName string var imageName, imageType string
var imageType string
if lightLogo { if lightLogo {
imageName = "logoLight" imageName = "logoLight"
imageType = acc.appConfigService.DbConfig.LogoLightImageType.Value imageType = dbConfig.LogoLightImageType.Value
} else { } else {
imageName = "logoDark" imageName = "logoDark"
imageType = acc.appConfigService.DbConfig.LogoDarkImageType.Value imageType = dbConfig.LogoDarkImageType.Value
} }
acc.getImage(c, imageName, imageType) acc.getImage(c, imageName, imageType)
@@ -181,7 +173,7 @@ func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
// @Failure 404 {object} object "{"error": "File not found"}" // @Failure 404 {object} object "{"error": "File not found"}"
// @Router /api/application-configuration/background-image [get] // @Router /api/application-configuration/background-image [get]
func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) { func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
acc.getImage(c, "background", imageType) acc.getImage(c, "background", imageType)
} }
@@ -196,17 +188,17 @@ func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /api/application-configuration/logo [put] // @Router /api/application-configuration/logo [put]
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) { func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
dbConfig := acc.appConfigService.GetDbConfig()
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true")) lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
var imageName string var imageName, imageType string
var imageType string
if lightLogo { if lightLogo {
imageName = "logoLight" imageName = "logoLight"
imageType = acc.appConfigService.DbConfig.LogoLightImageType.Value imageType = dbConfig.LogoLightImageType.Value
} else { } else {
imageName = "logoDark" imageName = "logoDark"
imageType = acc.appConfigService.DbConfig.LogoDarkImageType.Value imageType = dbConfig.LogoDarkImageType.Value
} }
acc.updateImage(c, imageName, imageType) acc.updateImage(c, imageName, imageType)
@@ -246,7 +238,7 @@ func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /api/application-configuration/background-image [put] // @Router /api/application-configuration/background-image [put]
func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) { func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
acc.updateImage(c, "background", imageType) acc.updateImage(c, "background", imageType)
} }

View File

@@ -36,7 +36,7 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
return return
} }
if err := tc.TestService.ResetAppConfig(); err != nil { if err := tc.TestService.ResetAppConfig(c.Request.Context()); err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
} }

View File

@@ -227,7 +227,7 @@ func (uc *UserController) updateUserHandler(c *gin.Context) {
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /api/users/me [put] // @Router /api/users/me [put]
func (uc *UserController) updateCurrentUserHandler(c *gin.Context) { func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
if !uc.appConfigService.DbConfig.AllowOwnAccountEdit.IsTrue() { if !uc.appConfigService.GetDbConfig().AllowOwnAccountEdit.IsTrue() {
_ = c.Error(&common.AccountEditNotAllowedError{}) _ = c.Error(&common.AccountEditNotAllowedError{})
return return
} }
@@ -396,7 +396,7 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
return return
} }
maxAge := int(uc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds()) maxAge := int(uc.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes().Seconds())
cookie.AddAccessTokenCookie(c, maxAge, token) cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto) c.JSON(http.StatusOK, userDto)
@@ -421,7 +421,7 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
return return
} }
maxAge := int(uc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds()) maxAge := int(uc.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes().Seconds())
cookie.AddAccessTokenCookie(c, maxAge, token) cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto) c.JSON(http.StatusOK, userDto)

View File

@@ -106,7 +106,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
return return
} }
maxAge := int(wc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds()) maxAge := int(wc.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes().Seconds())
cookie.AddAccessTokenCookie(c, maxAge, token) cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto) c.JSON(http.StatusOK, userDto)

View File

@@ -16,7 +16,7 @@ type AppConfigUpdateDto struct {
SessionDuration string `json:"sessionDuration" binding:"required"` SessionDuration string `json:"sessionDuration" binding:"required"`
EmailsVerified string `json:"emailsVerified" binding:"required"` EmailsVerified string `json:"emailsVerified" binding:"required"`
AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"` AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"`
SmtHost string `json:"smtpHost"` SmtpHost string `json:"smtpHost"`
SmtpPort string `json:"smtpPort"` SmtpPort string `json:"smtpPort"`
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"` SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
SmtpUser string `json:"smtpUser"` SmtpUser string `json:"smtpUser"`

View File

@@ -34,7 +34,7 @@ func RegisterLdapJobs(ctx context.Context, ldapService *service.LdapService, app
} }
func (j *LdapJobs) syncLdap(ctx context.Context) error { func (j *LdapJobs) syncLdap(ctx context.Context) error {
if !j.appConfigService.DbConfig.LdapEnabled.IsTrue() { if !j.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return nil return nil
} }

View File

@@ -1,17 +1,17 @@
package model package model
import ( import (
"errors"
"fmt"
"reflect"
"strconv" "strconv"
"strings"
"time" "time"
) )
type AppConfigVariable struct { type AppConfigVariable struct {
Key string `gorm:"primaryKey;not null"` Key string `gorm:"primaryKey;not null"`
Type string Value string
IsPublic bool
IsInternal bool
Value string
DefaultValue string
} }
// IsTrue returns true if the value is a truthy string, such as "true", "t", "yes", "1", etc. // IsTrue returns true if the value is a truthy string, such as "true", "t", "yes", "1", etc.
@@ -31,41 +31,153 @@ func (a *AppConfigVariable) AsDurationMinutes() time.Duration {
type AppConfig struct { type AppConfig struct {
// General // General
AppName AppConfigVariable AppName AppConfigVariable `key:"appName,public"` // Public
SessionDuration AppConfigVariable SessionDuration AppConfigVariable `key:"sessionDuration"`
EmailsVerified AppConfigVariable EmailsVerified AppConfigVariable `key:"emailsVerified"`
AllowOwnAccountEdit AppConfigVariable AllowOwnAccountEdit AppConfigVariable `key:"allowOwnAccountEdit,public"` // Public
// Internal // Internal
BackgroundImageType AppConfigVariable BackgroundImageType AppConfigVariable `key:"backgroundImageType,internal"` // Internal
LogoLightImageType AppConfigVariable LogoLightImageType AppConfigVariable `key:"logoLightImageType,internal"` // Internal
LogoDarkImageType AppConfigVariable LogoDarkImageType AppConfigVariable `key:"logoDarkImageType,internal"` // Internal
// Email // Email
SmtpHost AppConfigVariable SmtpHost AppConfigVariable `key:"smtpHost"`
SmtpPort AppConfigVariable SmtpPort AppConfigVariable `key:"smtpPort"`
SmtpFrom AppConfigVariable SmtpFrom AppConfigVariable `key:"smtpFrom"`
SmtpUser AppConfigVariable SmtpUser AppConfigVariable `key:"smtpUser"`
SmtpPassword AppConfigVariable SmtpPassword AppConfigVariable `key:"smtpPassword"`
SmtpTls AppConfigVariable SmtpTls AppConfigVariable `key:"smtpTls"`
SmtpSkipCertVerify AppConfigVariable SmtpSkipCertVerify AppConfigVariable `key:"smtpSkipCertVerify"`
EmailLoginNotificationEnabled AppConfigVariable EmailLoginNotificationEnabled AppConfigVariable `key:"emailLoginNotificationEnabled"`
EmailOneTimeAccessEnabled AppConfigVariable EmailOneTimeAccessEnabled AppConfigVariable `key:"emailOneTimeAccessEnabled,public"` // Public
// LDAP // LDAP
LdapEnabled AppConfigVariable LdapEnabled AppConfigVariable `key:"ldapEnabled,public"` // Public
LdapUrl AppConfigVariable LdapUrl AppConfigVariable `key:"ldapUrl"`
LdapBindDn AppConfigVariable LdapBindDn AppConfigVariable `key:"ldapBindDn"`
LdapBindPassword AppConfigVariable LdapBindPassword AppConfigVariable `key:"ldapBindPassword"`
LdapBase AppConfigVariable LdapBase AppConfigVariable `key:"ldapBase"`
LdapUserSearchFilter AppConfigVariable LdapUserSearchFilter AppConfigVariable `key:"ldapUserSearchFilter"`
LdapUserGroupSearchFilter AppConfigVariable LdapUserGroupSearchFilter AppConfigVariable `key:"ldapUserGroupSearchFilter"`
LdapSkipCertVerify AppConfigVariable LdapSkipCertVerify AppConfigVariable `key:"ldapSkipCertVerify"`
LdapAttributeUserUniqueIdentifier AppConfigVariable LdapAttributeUserUniqueIdentifier AppConfigVariable `key:"ldapAttributeUserUniqueIdentifier"`
LdapAttributeUserUsername AppConfigVariable LdapAttributeUserUsername AppConfigVariable `key:"ldapAttributeUserUsername"`
LdapAttributeUserEmail AppConfigVariable LdapAttributeUserEmail AppConfigVariable `key:"ldapAttributeUserEmail"`
LdapAttributeUserFirstName AppConfigVariable LdapAttributeUserFirstName AppConfigVariable `key:"ldapAttributeUserFirstName"`
LdapAttributeUserLastName AppConfigVariable LdapAttributeUserLastName AppConfigVariable `key:"ldapAttributeUserLastName"`
LdapAttributeUserProfilePicture AppConfigVariable LdapAttributeUserProfilePicture AppConfigVariable `key:"ldapAttributeUserProfilePicture"`
LdapAttributeGroupMember AppConfigVariable LdapAttributeGroupMember AppConfigVariable `key:"ldapAttributeGroupMember"`
LdapAttributeGroupUniqueIdentifier AppConfigVariable LdapAttributeGroupUniqueIdentifier AppConfigVariable `key:"ldapAttributeGroupUniqueIdentifier"`
LdapAttributeGroupName AppConfigVariable LdapAttributeGroupName AppConfigVariable `key:"ldapAttributeGroupName"`
LdapAttributeAdminGroup AppConfigVariable LdapAttributeAdminGroup AppConfigVariable `key:"ldapAttributeAdminGroup"`
}
func (c *AppConfig) ToAppConfigVariableSlice(showAll bool) []AppConfigVariable {
// Use reflection to iterate through all fields
cfgValue := reflect.ValueOf(c).Elem()
cfgType := cfgValue.Type()
res := make([]AppConfigVariable, cfgType.NumField())
for i := range cfgType.NumField() {
field := cfgType.Field(i)
key, attrs, _ := strings.Cut(field.Tag.Get("key"), ",")
if key == "" {
continue
}
// If we're only showing public variables and this is not public, skip it
if !showAll && attrs != "public" {
continue
}
fieldValue := cfgValue.Field(i)
res[i] = AppConfigVariable{
Key: key,
Value: fieldValue.FieldByName("Value").String(),
}
}
return res
}
func (c *AppConfig) FieldByKey(key string) (string, error) {
rv := reflect.ValueOf(c).Elem()
rt := rv.Type()
// Find the field in the struct whose "key" tag matches
for i := range rt.NumField() {
// Grab only the first part of the key, if there's a comma with additional properties
tagValue, _, _ := strings.Cut(rt.Field(i).Tag.Get("key"), ",")
if tagValue != key {
continue
}
valueField := rv.Field(i).FieldByName("Value")
return valueField.String(), nil
}
// If we are here, the config key was not found
return "", AppConfigKeyNotFoundError{field: key}
}
func (c *AppConfig) UpdateField(key string, value string, noInternal bool) error {
rv := reflect.ValueOf(c).Elem()
rt := rv.Type()
// Find the field in the struct whose "key" tag matches, then update that
for i := range rt.NumField() {
// Separate the key (before the comma) from any optional attributes after
tagValue, attrs, _ := strings.Cut(rt.Field(i).Tag.Get("key"), ",")
if tagValue != key {
continue
}
// If the field is internal and noInternal is true, we skip that
if noInternal && attrs == "internal" {
return AppConfigInternalForbiddenError{field: key}
}
valueField := rv.Field(i).FieldByName("Value")
if !valueField.CanSet() {
return fmt.Errorf("field Value in AppConfigVariable is not settable for config key '%s'", key)
}
// Update the value
valueField.SetString(value)
// Return once updated
return nil
}
// If we're here, we have not found the right field to update
return AppConfigKeyNotFoundError{field: key}
}
type AppConfigKeyNotFoundError struct {
field string
}
func (e AppConfigKeyNotFoundError) Error() string {
return fmt.Sprintf("cannot find config key '%s'", e.field)
}
func (e AppConfigKeyNotFoundError) Is(target error) bool {
// Ignore the field property when checking if an error is of the type AppConfigKeyNotFoundError
x := AppConfigKeyNotFoundError{}
return errors.As(target, &x)
}
type AppConfigInternalForbiddenError struct {
field string
}
func (e AppConfigInternalForbiddenError) Error() string {
return fmt.Sprintf("field '%s' is internal and can't be updated", e.field)
}
func (e AppConfigInternalForbiddenError) Is(target error) bool {
// Ignore the field property when checking if an error is of the type AppConfigInternalForbiddenError
x := AppConfigInternalForbiddenError{}
return errors.As(target, &x)
} }

View File

@@ -1,10 +1,16 @@
package model // We use model_test here to avoid an import cycle
package model_test
import ( import (
"reflect"
"strings"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
) )
func TestAppConfigVariable_AsMinutesDuration(t *testing.T) { func TestAppConfigVariable_AsMinutesDuration(t *testing.T) {
@@ -48,7 +54,7 @@ func TestAppConfigVariable_AsMinutesDuration(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
configVar := AppConfigVariable{ configVar := model.AppConfigVariable{
Value: tt.value, Value: tt.value,
} }
@@ -58,3 +64,66 @@ func TestAppConfigVariable_AsMinutesDuration(t *testing.T) {
}) })
} }
} }
// This test ensures that the model.AppConfig and dto.AppConfigUpdateDto structs match:
// - They should have the same properties, where the "json" tag of dto.AppConfigUpdateDto should match the "key" tag in model.AppConfig
// - dto.AppConfigDto should not include "internal" fields from model.AppConfig
// This test is primarily meant to catch discrepancies between the two structs as fields are added or removed over time
func TestAppConfigStructMatchesUpdateDto(t *testing.T) {
appConfigType := reflect.TypeOf(model.AppConfig{})
updateDtoType := reflect.TypeOf(dto.AppConfigUpdateDto{})
// Process AppConfig fields
appConfigFields := make(map[string]string)
for i := 0; i < appConfigType.NumField(); i++ {
field := appConfigType.Field(i)
if field.Tag.Get("key") == "" {
// Skip internal fields
continue
}
// Extract the key name from the tag (takes the part before any comma)
keyTag := field.Tag.Get("key")
keyName, _, _ := strings.Cut(keyTag, ",")
appConfigFields[field.Name] = keyName
}
// Process AppConfigUpdateDto fields
dtoFields := make(map[string]string)
for i := 0; i < updateDtoType.NumField(); i++ {
field := updateDtoType.Field(i)
// Extract the json name from the tag (takes the part before any binding constraints)
jsonTag := field.Tag.Get("json")
jsonName, _, _ := strings.Cut(jsonTag, ",")
dtoFields[jsonName] = field.Name
}
// Verify every AppConfig field has a matching DTO field with the same name
for fieldName, keyName := range appConfigFields {
if strings.HasSuffix(fieldName, "ImageType") {
// Skip internal fields that shouldn't be in the DTO
continue
}
// Check if there's a DTO field with a matching JSON tag
_, exists := dtoFields[keyName]
assert.True(t, exists, "Field %s with key '%s' in AppConfig has no matching field in AppConfigUpdateDto", fieldName, keyName)
}
// Verify every DTO field has a matching AppConfig field
for jsonName, fieldName := range dtoFields {
// Find a matching field in AppConfig by key tag
found := false
for _, keyName := range appConfigFields {
if keyName == jsonName {
found = true
break
}
}
assert.True(t, found, "Field %s with json tag '%s' in AppConfigUpdateDto has no matching field in AppConfig", fieldName, jsonName)
}
}

View File

@@ -3,30 +3,35 @@ package service
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"log" "log"
"mime/multipart" "mime/multipart"
"os" "os"
"reflect" "reflect"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
) )
type AppConfigService struct { type AppConfigService struct {
DbConfig *model.AppConfig dbConfig atomic.Pointer[model.AppConfig]
db *gorm.DB db *gorm.DB
} }
func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService { func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
service := &AppConfigService{ service := &AppConfigService{
DbConfig: &defaultDbConfig, db: db,
db: db,
} }
err := service.InitDbConfig(ctx) err := service.LoadDbConfig(ctx)
if err != nil { if err != nil {
log.Fatalf("Failed to initialize app config service: %v", err) log.Fatalf("Failed to initialize app config service: %v", err)
} }
@@ -34,170 +39,109 @@ func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
return service return service
} }
var defaultDbConfig = model.AppConfig{ // GetDbConfig returns the application configuration.
// General // Important: Treat the object as read-only: do not modify its properties directly!
AppName: model.AppConfigVariable{ func (s *AppConfigService) GetDbConfig() *model.AppConfig {
Key: "appName", v := s.dbConfig.Load()
Type: "string", if v == nil {
IsPublic: true, // This indicates a development-time error
DefaultValue: "Pocket ID", panic("called GetDbConfig before DbConfig is loaded")
}, }
SessionDuration: model.AppConfigVariable{
Key: "sessionDuration", return v
Type: "number", }
DefaultValue: "60",
}, func (s *AppConfigService) getDefaultDbConfig() *model.AppConfig {
EmailsVerified: model.AppConfigVariable{ // Values are the default ones
Key: "emailsVerified", return &model.AppConfig{
Type: "bool", // General
DefaultValue: "false", AppName: model.AppConfigVariable{Value: "Pocket ID"},
}, SessionDuration: model.AppConfigVariable{Value: "60"},
AllowOwnAccountEdit: model.AppConfigVariable{ EmailsVerified: model.AppConfigVariable{Value: "false"},
Key: "allowOwnAccountEdit", AllowOwnAccountEdit: model.AppConfigVariable{Value: "true"},
Type: "bool", // Internal
IsPublic: true, BackgroundImageType: model.AppConfigVariable{Value: "jpg"},
DefaultValue: "true", LogoLightImageType: model.AppConfigVariable{Value: "svg"},
}, LogoDarkImageType: model.AppConfigVariable{Value: "svg"},
// Internal // Email
BackgroundImageType: model.AppConfigVariable{ SmtpHost: model.AppConfigVariable{},
Key: "backgroundImageType", SmtpPort: model.AppConfigVariable{},
Type: "string", SmtpFrom: model.AppConfigVariable{},
IsInternal: true, SmtpUser: model.AppConfigVariable{},
DefaultValue: "jpg", SmtpPassword: model.AppConfigVariable{},
}, SmtpTls: model.AppConfigVariable{Value: "none"},
LogoLightImageType: model.AppConfigVariable{ SmtpSkipCertVerify: model.AppConfigVariable{Value: "false"},
Key: "logoLightImageType", EmailLoginNotificationEnabled: model.AppConfigVariable{Value: "false"},
Type: "string", EmailOneTimeAccessEnabled: model.AppConfigVariable{Value: "false"},
IsInternal: true, // LDAP
DefaultValue: "svg", LdapEnabled: model.AppConfigVariable{Value: "false"},
}, LdapUrl: model.AppConfigVariable{},
LogoDarkImageType: model.AppConfigVariable{ LdapBindDn: model.AppConfigVariable{},
Key: "logoDarkImageType", LdapBindPassword: model.AppConfigVariable{},
Type: "string", LdapBase: model.AppConfigVariable{},
IsInternal: true, LdapUserSearchFilter: model.AppConfigVariable{Value: "(objectClass=person)"},
DefaultValue: "svg", LdapUserGroupSearchFilter: model.AppConfigVariable{Value: "(objectClass=groupOfNames)"},
}, LdapSkipCertVerify: model.AppConfigVariable{Value: "false"},
// Email LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{},
SmtpHost: model.AppConfigVariable{ LdapAttributeUserUsername: model.AppConfigVariable{},
Key: "smtpHost", LdapAttributeUserEmail: model.AppConfigVariable{},
Type: "string", LdapAttributeUserFirstName: model.AppConfigVariable{},
}, LdapAttributeUserLastName: model.AppConfigVariable{},
SmtpPort: model.AppConfigVariable{ LdapAttributeUserProfilePicture: model.AppConfigVariable{},
Key: "smtpPort", LdapAttributeGroupMember: model.AppConfigVariable{Value: "member"},
Type: "number", LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{},
}, LdapAttributeGroupName: model.AppConfigVariable{},
SmtpFrom: model.AppConfigVariable{ LdapAttributeAdminGroup: model.AppConfigVariable{},
Key: "smtpFrom", }
Type: "string", }
},
SmtpUser: model.AppConfigVariable{ func (s *AppConfigService) updateAppConfigStartTransaction(ctx context.Context) (tx *gorm.DB, err error) {
Key: "smtpUser", // We start a transaction before doing any work, to ensure that we are the only ones updating the data in the database
Type: "string", // This works across multiple processes too
}, tx = s.db.Begin()
SmtpPassword: model.AppConfigVariable{ err = tx.Error
Key: "smtpPassword", if err != nil {
Type: "string", return nil, fmt.Errorf("failed to begin database transaction: %w", err)
}, }
SmtpTls: model.AppConfigVariable{
Key: "smtpTls", // With SQLite there's nothing else we need to do, because a transaction blocks the entire database
Type: "string", // However, with Postgres we need to manually lock the table to prevent others from doing the same
DefaultValue: "none", switch s.db.Name() {
}, case "postgres":
SmtpSkipCertVerify: model.AppConfigVariable{ // We do not use "NOWAIT" so this blocks until the database is available, or the context is canceled
Key: "smtpSkipCertVerify", // Here we use a context with a 10s timeout in case the database is blocked for longer
Type: "bool", lockCtx, lockCancel := context.WithTimeout(ctx, 10*time.Second)
DefaultValue: "false", defer lockCancel()
}, err = tx.
EmailLoginNotificationEnabled: model.AppConfigVariable{ WithContext(lockCtx).
Key: "emailLoginNotificationEnabled", Exec("LOCK TABLE app_config_variables IN ACCESS EXCLUSIVE MODE").
Type: "bool", Error
DefaultValue: "false", if err != nil {
}, tx.Rollback()
EmailOneTimeAccessEnabled: model.AppConfigVariable{ return nil, fmt.Errorf("failed to acquire lock on app_config_variables table: %w", err)
Key: "emailOneTimeAccessEnabled", }
Type: "bool", default:
IsPublic: true, // Nothing to do here
DefaultValue: "false", }
},
// LDAP return tx, nil
LdapEnabled: model.AppConfigVariable{ }
Key: "ldapEnabled",
Type: "bool", func (s *AppConfigService) updateAppConfigUpdateDatabase(ctx context.Context, tx *gorm.DB, dbUpdate *[]model.AppConfigVariable) error {
IsPublic: true, err := tx.
DefaultValue: "false", WithContext(ctx).
}, Clauses(clause.OnConflict{
LdapUrl: model.AppConfigVariable{ // Perform an "upsert" if the key already exists, replacing the value
Key: "ldapUrl", Columns: []clause.Column{{Name: "key"}},
Type: "string", DoUpdates: clause.AssignmentColumns([]string{"value"}),
}, }).
LdapBindDn: model.AppConfigVariable{ Create(&dbUpdate).
Key: "ldapBindDn", Error
Type: "string", if err != nil {
}, return fmt.Errorf("failed to update config in database: %w", err)
LdapBindPassword: model.AppConfigVariable{ }
Key: "ldapBindPassword",
Type: "string", return nil
},
LdapBase: model.AppConfigVariable{
Key: "ldapBase",
Type: "string",
},
LdapUserSearchFilter: model.AppConfigVariable{
Key: "ldapUserSearchFilter",
Type: "string",
DefaultValue: "(objectClass=person)",
},
LdapUserGroupSearchFilter: model.AppConfigVariable{
Key: "ldapUserGroupSearchFilter",
Type: "string",
DefaultValue: "(objectClass=groupOfNames)",
},
LdapSkipCertVerify: model.AppConfigVariable{
Key: "ldapSkipCertVerify",
Type: "bool",
DefaultValue: "false",
},
LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{
Key: "ldapAttributeUserUniqueIdentifier",
Type: "string",
},
LdapAttributeUserUsername: model.AppConfigVariable{
Key: "ldapAttributeUserUsername",
Type: "string",
},
LdapAttributeUserEmail: model.AppConfigVariable{
Key: "ldapAttributeUserEmail",
Type: "string",
},
LdapAttributeUserFirstName: model.AppConfigVariable{
Key: "ldapAttributeUserFirstName",
Type: "string",
},
LdapAttributeUserLastName: model.AppConfigVariable{
Key: "ldapAttributeUserLastName",
Type: "string",
},
LdapAttributeUserProfilePicture: model.AppConfigVariable{
Key: "ldapAttributeUserProfilePicture",
Type: "string",
},
LdapAttributeGroupMember: model.AppConfigVariable{
Key: "ldapAttributeGroupMember",
Type: "string",
DefaultValue: "member",
},
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{
Key: "ldapAttributeGroupUniqueIdentifier",
Type: "string",
},
LdapAttributeGroupName: model.AppConfigVariable{
Key: "ldapAttributeGroupName",
Type: "string",
},
LdapAttributeAdminGroup: model.AppConfigVariable{
Key: "ldapAttributeAdminGroup",
Type: "string",
},
} }
func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) { func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
@@ -205,106 +149,168 @@ func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppCon
return nil, &common.UiConfigDisabledError{} return nil, &common.UiConfigDisabledError{}
} }
tx := s.db.Begin() // If EmailLoginNotificationEnabled is set to false (explicitly), disable the EmailOneTimeAccessEnabled
if input.EmailLoginNotificationEnabled == "false" {
input.EmailOneTimeAccessEnabled = "false"
}
// Start the transaction
tx, err := s.updateAppConfigStartTransaction(ctx)
if err != nil {
return nil, err
}
defer func() { defer func() {
tx.Rollback() tx.Rollback()
}() }()
var err error // From here onwards, we know we are the only process/goroutine with exclusive access to the config
// Re-load the config from the database to be sure we have the correct data
cfg, err := s.loadDbConfigInternal(ctx, tx)
if err != nil {
return nil, fmt.Errorf("failed to reload config from database: %w", err)
}
defaultCfg := s.getDefaultDbConfig()
// Iterate through all the fields to update
// We update the in-memory data (in the cfg struct) and collect values to update in the database
rt := reflect.ValueOf(input).Type() rt := reflect.ValueOf(input).Type()
rv := reflect.ValueOf(input) rv := reflect.ValueOf(input)
dbUpdate := make([]model.AppConfigVariable, 0, rt.NumField())
savedConfigVariables := make([]model.AppConfigVariable, 0, rt.NumField())
for i := range rt.NumField() { for i := range rt.NumField() {
field := rt.Field(i) field := rt.Field(i)
key := field.Tag.Get("json")
value := rv.FieldByName(field.Name).String() value := rv.FieldByName(field.Name).String()
// If the emailEnabled is set to false, disable the emailOneTimeAccessEnabled // Get the value of the json tag, taking only what's before the comma
if key == s.DbConfig.EmailOneTimeAccessEnabled.Key { key, _, _ := strings.Cut(field.Tag.Get("json"), ",")
if rv.FieldByName("EmailEnabled").String() == "false" {
value = "false" // Update the in-memory config value
} // If the new value is an empty string, then we set the in-memory value to the default one
// Skip values that are internal only and can't be updated
if value == "" {
// Ignore errors here as we know the key exists
defaultValue, _ := defaultCfg.FieldByKey(key)
err = cfg.UpdateField(key, defaultValue, true)
} else {
err = cfg.UpdateField(key, value, true)
} }
var appConfigVariable model.AppConfigVariable // If we tried to update an internal field, ignore the error (and do not update in the DB)
err = tx. if errors.Is(err, model.AppConfigInternalForbiddenError{}) {
WithContext(ctx). continue
First(&appConfigVariable, "key = ? AND is_internal = false", key). } else if err != nil {
Error return nil, fmt.Errorf("failed to update in-memory config for key '%s': %w", key, err)
if err != nil {
return nil, err
} }
appConfigVariable.Value = value // We always save "value" which can be an empty string
err = tx. dbUpdate = append(dbUpdate, model.AppConfigVariable{
WithContext(ctx). Key: key,
Save(&appConfigVariable). Value: value,
Error })
if err != nil {
return nil, err
}
savedConfigVariables = append(savedConfigVariables, appConfigVariable)
} }
// Update the values in the database
err = s.updateAppConfigUpdateDatabase(ctx, tx, &dbUpdate)
if err != nil {
return nil, err
}
// Commit the changes to the DB, then finally save the updated config in the object
err = tx.Commit().Error err = tx.Commit().Error
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to commit transaction: %w", err)
} }
err = s.LoadDbConfigFromDb() s.dbConfig.Store(cfg)
if err != nil {
return nil, err
}
return savedConfigVariables, nil // Return the updated config
res := cfg.ToAppConfigVariableSlice(true)
return res, nil
} }
func (s *AppConfigService) updateImageType(ctx context.Context, imageName string, fileType string) error { // UpdateAppConfigValues
key := imageName + "ImageType" func (s *AppConfigService) UpdateAppConfigValues(ctx context.Context, keysAndValues ...string) error {
err := s.db. if common.EnvConfig.UiConfigDisabled {
WithContext(ctx). return &common.UiConfigDisabledError{}
Model(&model.AppConfigVariable{}). }
Where("key = ?", key).
Update("value", fileType). // Count of keysAndValues must be even
Error if len(keysAndValues)%2 != 0 {
return errors.New("invalid number of arguments received")
}
// Start the transaction
tx, err := s.updateAppConfigStartTransaction(ctx)
if err != nil {
return err
}
defer func() {
tx.Rollback()
}()
// From here onwards, we know we are the only process/goroutine with exclusive access to the config
// Re-load the config from the database to be sure we have the correct data
cfg, err := s.loadDbConfigInternal(ctx, tx)
if err != nil {
return fmt.Errorf("failed to reload config from database: %w", err)
}
defaultCfg := s.getDefaultDbConfig()
// Iterate through all the fields to update
// We update the in-memory data (in the cfg struct) and collect values to update in the database
// (Note the += 2, as we are iterating through key-value pairs)
dbUpdate := make([]model.AppConfigVariable, 0, len(keysAndValues)/2)
for i := 0; i < len(keysAndValues); i += 2 {
key := keysAndValues[i]
value := keysAndValues[i+1]
// Ensure that the field is valid
// We do this by grabbing the default value
var defaultValue string
defaultValue, err = defaultCfg.FieldByKey(key)
if err != nil {
return fmt.Errorf("invalid configuration key '%s': %w", key, err)
}
// Update the in-memory config value
// If the new value is an empty string, then we set the in-memory value to the default one
// Skip values that are internal only and can't be updated
if value == "" {
err = cfg.UpdateField(key, defaultValue, false)
} else {
err = cfg.UpdateField(key, value, false)
}
if err != nil {
return fmt.Errorf("failed to update in-memory config for key '%s': %w", key, err)
}
// We always save "value" which can be an empty string
dbUpdate = append(dbUpdate, model.AppConfigVariable{
Key: key,
Value: value,
})
}
// Update the values in the database
err = s.updateAppConfigUpdateDatabase(ctx, tx, &dbUpdate)
if err != nil { if err != nil {
return err return err
} }
return s.LoadDbConfigFromDb() // Commit the changes to the DB, then finally save the updated config in the object
err = tx.Commit().Error
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
s.dbConfig.Store(cfg)
return nil
} }
func (s *AppConfigService) ListAppConfig(ctx context.Context, showAll bool) (configuration []model.AppConfigVariable, err error) { func (s *AppConfigService) ListAppConfig(showAll bool) []model.AppConfigVariable {
if showAll { return s.GetDbConfig().ToAppConfigVariableSlice(showAll)
err = s.db.
WithContext(ctx).
Find(&configuration).
Error
} else {
err = s.db.
WithContext(ctx).
Find(&configuration, "is_public = true").
Error
}
if err != nil {
return nil, err
}
for i := range configuration {
if common.EnvConfig.UiConfigDisabled {
// Set the value to the environment variable if the UI config is disabled
configuration[i].Value = s.getConfigVariableFromEnvironmentVariable(configuration[i].Key, configuration[i].DefaultValue)
} else if configuration[i].Value == "" && configuration[i].DefaultValue != "" {
// Set the value to the default value if it is empty
configuration[i].Value = configuration[i].DefaultValue
}
}
return configuration, nil
} }
func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) { func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) {
@@ -314,161 +320,108 @@ func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multip
return &common.FileTypeNotSupportedError{} return &common.FileTypeNotSupportedError{}
} }
// Delete the old image if it has a different file type // Save the updated image
if fileType != oldImageType {
oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType
err = os.Remove(oldImagePath)
if err != nil {
return err
}
}
imagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + fileType imagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + fileType
err = utils.SaveFile(uploadedFile, imagePath) err = utils.SaveFile(uploadedFile, imagePath)
if err != nil { if err != nil {
return err return err
} }
// Update the file type in the database // Delete the old image if it has a different file type, then update the type in the database
err = s.updateImageType(ctx, imageName, fileType) if fileType != oldImageType {
if err != nil { oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType
return err err = os.Remove(oldImagePath)
}
return nil
}
// InitDbConfig creates the default configuration values in the database if they do not exist,
// updates existing configurations if they differ from the default, and deletes any configurations
// that are not in the default configuration.
func (s *AppConfigService) InitDbConfig(ctx context.Context) (err error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
// Reflect to get the underlying value of DbConfig and its default configuration
defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig)
defaultKeys := make(map[string]struct{})
// Iterate over the fields of DbConfig
for i := range defaultConfigReflectValue.NumField() {
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable)
defaultKeys[defaultConfigVar.Key] = struct{}{}
var storedConfigVar model.AppConfigVariable
err = tx.
WithContext(ctx).
First(&storedConfigVar, "key = ?", defaultConfigVar.Key).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// If the configuration does not exist, create it
err = tx.
WithContext(ctx).
Create(&defaultConfigVar).
Error
if err != nil {
return err
}
continue
} else if err != nil {
return err
}
// Update existing configuration if it differs from the default
if storedConfigVar.Type != defaultConfigVar.Type ||
storedConfigVar.IsPublic != defaultConfigVar.IsPublic ||
storedConfigVar.IsInternal != defaultConfigVar.IsInternal ||
storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue {
// Set values
storedConfigVar.Type = defaultConfigVar.Type
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
storedConfigVar.DefaultValue = defaultConfigVar.DefaultValue
err = tx.
WithContext(ctx).
Save(&storedConfigVar).
Error
if err != nil {
return err
}
}
}
// Delete any configurations not in the default keys
var allConfigVars []model.AppConfigVariable
err = tx.
WithContext(ctx).
Find(&allConfigVars).
Error
if err != nil {
return err
}
for _, config := range allConfigVars {
if _, exists := defaultKeys[config.Key]; exists {
continue
}
err = tx.
WithContext(ctx).
Delete(&config).
Error
if err != nil { if err != nil {
return err return err
} }
}
// Commit the changes // Update the file type in the database
err = tx.Commit().Error err = s.UpdateAppConfigValues(ctx, imageName+"ImageType", fileType)
if err != nil { if err != nil {
return err return err
} }
// Reload the configuration
err = s.LoadDbConfigFromDb()
if err != nil {
return err
} }
return nil return nil
} }
// LoadDbConfigFromDb loads the configuration values from the database into the DbConfig struct. // LoadDbConfig loads the configuration values from the database into the DbConfig struct.
func (s *AppConfigService) LoadDbConfigFromDb() error { func (s *AppConfigService) LoadDbConfig(ctx context.Context) (err error) {
return s.db.Transaction(func(tx *gorm.DB) error { var dest *model.AppConfig
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
for i := range dbConfigReflectValue.NumField() { // If the UI config is disabled, only load from the env
dbConfigField := dbConfigReflectValue.Field(i) if common.EnvConfig.UiConfigDisabled {
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable) dest, err = s.loadDbConfigFromEnv()
var storedConfigVar model.AppConfigVariable } else {
err := tx.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error dest, err = s.loadDbConfigInternal(ctx, s.db)
if err != nil { }
return err if err != nil {
} return err
if common.EnvConfig.UiConfigDisabled {
storedConfigVar.Value = s.getConfigVariableFromEnvironmentVariable(currentConfigVar.Key, storedConfigVar.DefaultValue)
} else if storedConfigVar.Value == "" && storedConfigVar.DefaultValue != "" {
storedConfigVar.Value = storedConfigVar.DefaultValue
}
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
return nil
})
}
func (s *AppConfigService) getConfigVariableFromEnvironmentVariable(key, fallbackValue string) string {
environmentVariableName := utils.CamelCaseToScreamingSnakeCase(key)
if value, exists := os.LookupEnv(environmentVariableName); exists {
return value
} }
return fallbackValue // Update the value in the object
s.dbConfig.Store(dest)
return nil
}
func (s *AppConfigService) loadDbConfigFromEnv() (*model.AppConfig, error) {
// First, start from the default configuration
dest := s.getDefaultDbConfig()
// Iterate through each field
rt := reflect.ValueOf(dest).Elem().Type()
rv := reflect.ValueOf(dest).Elem()
for i := range rt.NumField() {
field := rt.Field(i)
// Get the value of the key tag, taking only what's before the comma
// The env var name is the key converted to SCREAMING_SNAKE_CASE
key, _, _ := strings.Cut(field.Tag.Get("key"), ",")
envVarName := utils.CamelCaseToScreamingSnakeCase(key)
// Set the value if it's set
value, ok := os.LookupEnv(envVarName)
if ok {
rv.Field(i).FieldByName("Value").SetString(value)
}
}
return dest, nil
}
func (s *AppConfigService) loadDbConfigInternal(ctx context.Context, tx *gorm.DB) (*model.AppConfig, error) {
// First, start from the default configuration
dest := s.getDefaultDbConfig()
// Load all configuration values from the database
// This loads all values in a single shot
loaded := []model.AppConfigVariable{}
queryCtx, queryCancel := context.WithTimeout(ctx, 10*time.Second)
defer queryCancel()
err := tx.
WithContext(queryCtx).
Find(&loaded).Error
if err != nil {
return nil, fmt.Errorf("failed to load configuration from the database: %w", err)
}
// Iterate through all values loaded from the database
for _, v := range loaded {
// If the value is empty, it means we are using the default value
if v.Value == "" {
continue
}
// Find the field in the struct whose "key" tag matches, then update that
err = dest.UpdateField(v.Key, v.Value, false)
// We ignore the case of fields that don't exist, as there may be leftover data in the database
if err != nil && !errors.Is(err, model.AppConfigKeyNotFoundError{}) {
return nil, fmt.Errorf("failed to process config for key '%s': %w", v.Key, err)
}
}
return dest, nil
} }

View File

@@ -0,0 +1,561 @@
package service
import (
"sync/atomic"
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/stretchr/testify/require"
)
// NewTestAppConfigService is a function used by tests to create AppConfigService objects with pre-defined configuration values
func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
service := &AppConfigService{
dbConfig: atomic.Pointer[model.AppConfig]{},
}
service.dbConfig.Store(config)
return service
}
func TestLoadDbConfig(t *testing.T) {
t.Run("empty config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
// Load the config
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be equal to default config
require.Equal(t, service.GetDbConfig(), service.getDefaultDbConfig())
})
t.Run("loads value from config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Populate the config table with some initial values
err := db.
Create([]model.AppConfigVariable{
// Should be set to the default value because it's an empty string
{Key: "appName", Value: ""},
// Overrides default value
{Key: "sessionDuration", Value: "5"},
// Does not have a default value
{Key: "smtpHost", Value: "example"},
}).
Error
require.NoError(t, err)
// Load the config
service := &AppConfigService{
db: db,
}
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Values should match expected ones
expect := service.getDefaultDbConfig()
expect.SessionDuration.Value = "5"
expect.SmtpHost.Value = "example"
require.Equal(t, service.GetDbConfig(), expect)
})
t.Run("ignores unknown config keys", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Add an entry with a key that doesn't exist in the config struct
err := db.Create([]model.AppConfigVariable{
{Key: "__nonExistentKey", Value: "some value"},
{Key: "appName", Value: "TestApp"}, // This one should still be loaded
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// This should not fail, just ignore the unknown key
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
config := service.GetDbConfig()
require.Equal(t, "TestApp", config.AppName.Value)
})
t.Run("loading config multiple times", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Initial state
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "InitialApp"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
require.Equal(t, "InitialApp", service.GetDbConfig().AppName.Value)
// Update the database value
err = db.Model(&model.AppConfigVariable{}).
Where("key = ?", "appName").
Update("value", "UpdatedApp").Error
require.NoError(t, err)
// Load the config again, it should reflect the updated value
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
require.Equal(t, "UpdatedApp", service.GetDbConfig().AppName.Value)
})
t.Run("loads config from env when UiConfigDisabled is true", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Set environment variables for testing
t.Setenv("APP_NAME", "EnvTest App")
t.Setenv("SESSION_DURATION", "45")
// Enable UiConfigDisabled to load from env
common.EnvConfig.UiConfigDisabled = true
// Create database with config that should be ignored
db := newAppConfigTestDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// Load the config
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be loaded from env, not DB
config := service.GetDbConfig()
require.Equal(t, "EnvTest App", config.AppName.Value, "Should load appName from env")
require.Equal(t, "45", config.SessionDuration.Value, "Should load sessionDuration from env")
})
t.Run("ignores env vars when UiConfigDisabled is false", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Set environment variables that should be ignored
t.Setenv("APP_NAME", "EnvTest App")
t.Setenv("SESSION_DURATION", "45")
// Make sure UiConfigDisabled is false to load from DB
common.EnvConfig.UiConfigDisabled = false
// Create database with config values that should take precedence
db := newAppConfigTestDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// Load the config
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be loaded from DB, not env
config := service.GetDbConfig()
require.Equal(t, "DB App", config.AppName.Value, "Should load appName from DB, not env")
require.Equal(t, "120", config.SessionDuration.Value, "Should load sessionDuration from DB, not env")
})
}
func TestUpdateAppConfigValues(t *testing.T) {
t.Run("update single value", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Update a single config value
err = service.UpdateAppConfigValues(t.Context(), "appName", "Test App")
require.NoError(t, err)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Test App", config.AppName.Value)
// Verify database was updated
var dbValue model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&dbValue).Error
require.NoError(t, err)
require.Equal(t, "Test App", dbValue.Value)
})
t.Run("update multiple values", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Update multiple config values
err = service.UpdateAppConfigValues(
t.Context(),
"appName", "Test App",
"sessionDuration", "30",
"smtpHost", "mail.example.com",
)
require.NoError(t, err)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Test App", config.AppName.Value)
require.Equal(t, "30", config.SessionDuration.Value)
require.Equal(t, "mail.example.com", config.SmtpHost.Value)
// Verify database was updated
var count int64
db.Model(&model.AppConfigVariable{}).Count(&count)
require.Equal(t, int64(3), count)
var appName, sessionDuration, smtpHost model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&appName).Error
require.NoError(t, err)
require.Equal(t, "Test App", appName.Value)
err = db.Where("key = ?", "sessionDuration").First(&sessionDuration).Error
require.NoError(t, err)
require.Equal(t, "30", sessionDuration.Value)
err = db.Where("key = ?", "smtpHost").First(&smtpHost).Error
require.NoError(t, err)
require.Equal(t, "mail.example.com", smtpHost.Value)
})
t.Run("empty value resets to default", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// First change the value
err = service.UpdateAppConfigValues(t.Context(), "sessionDuration", "30")
require.NoError(t, err)
require.Equal(t, "30", service.GetDbConfig().SessionDuration.Value)
// Now set it to empty which should use default value
err = service.UpdateAppConfigValues(t.Context(), "sessionDuration", "")
require.NoError(t, err)
require.Equal(t, "60", service.GetDbConfig().SessionDuration.Value) // Default value from getDefaultDbConfig
})
t.Run("error with odd number of arguments", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update with odd number of arguments
err = service.UpdateAppConfigValues(t.Context(), "appName", "Test App", "sessionDuration")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid number of arguments")
})
t.Run("error with invalid key", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update with invalid key
err = service.UpdateAppConfigValues(t.Context(), "nonExistentKey", "some value")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid configuration key")
})
}
func TestUpdateAppConfig(t *testing.T) {
t.Run("updates configuration values from DTO", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Create update DTO
input := dto.AppConfigUpdateDto{
AppName: "Updated App Name",
SessionDuration: "120",
SmtpHost: "smtp.example.com",
SmtpPort: "587",
}
// Update config
updatedVars, err := service.UpdateAppConfig(t.Context(), input)
require.NoError(t, err)
// Verify returned updated variables
require.NotEmpty(t, updatedVars)
var foundAppName, foundSessionDuration, foundSmtpHost, foundSmtpPort bool
for _, v := range updatedVars {
switch v.Key {
case "appName":
require.Equal(t, "Updated App Name", v.Value)
foundAppName = true
case "sessionDuration":
require.Equal(t, "120", v.Value)
foundSessionDuration = true
case "smtpHost":
require.Equal(t, "smtp.example.com", v.Value)
foundSmtpHost = true
case "smtpPort":
require.Equal(t, "587", v.Value)
foundSmtpPort = true
}
}
require.True(t, foundAppName)
require.True(t, foundSessionDuration)
require.True(t, foundSmtpHost)
require.True(t, foundSmtpPort)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Updated App Name", config.AppName.Value)
require.Equal(t, "120", config.SessionDuration.Value)
require.Equal(t, "smtp.example.com", config.SmtpHost.Value)
require.Equal(t, "587", config.SmtpPort.Value)
// Verify database was updated
var appName, sessionDuration, smtpHost, smtpPort model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&appName).Error
require.NoError(t, err)
require.Equal(t, "Updated App Name", appName.Value)
err = db.Where("key = ?", "sessionDuration").First(&sessionDuration).Error
require.NoError(t, err)
require.Equal(t, "120", sessionDuration.Value)
err = db.Where("key = ?", "smtpHost").First(&smtpHost).Error
require.NoError(t, err)
require.Equal(t, "smtp.example.com", smtpHost.Value)
err = db.Where("key = ?", "smtpPort").First(&smtpPort).Error
require.NoError(t, err)
require.Equal(t, "587", smtpPort.Value)
})
t.Run("empty values reset to defaults", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config and modify some values
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// First set some non-default values
err = service.UpdateAppConfigValues(t.Context(),
"appName", "Custom App",
"sessionDuration", "120",
)
require.NoError(t, err)
// Create update DTO with empty values to reset to defaults
input := dto.AppConfigUpdateDto{
AppName: "", // Should reset to default "Pocket ID"
SessionDuration: "", // Should reset to default "60"
}
// Update config
updatedVars, err := service.UpdateAppConfig(t.Context(), input)
require.NoError(t, err)
// Verify returned updated variables (they should be empty strings in DB)
var foundAppName, foundSessionDuration bool
for _, v := range updatedVars {
switch v.Key {
case "appName":
require.Equal(t, "Pocket ID", v.Value) // Returns the default value
foundAppName = true
case "sessionDuration":
require.Equal(t, "60", v.Value) // Returns the default value
foundSessionDuration = true
}
}
require.True(t, foundAppName)
require.True(t, foundSessionDuration)
// Verify in-memory config was reset to defaults
config := service.GetDbConfig()
require.Equal(t, "Pocket ID", config.AppName.Value) // Default value
require.Equal(t, "60", config.SessionDuration.Value) // Default value
// Verify database was updated with empty values
for _, key := range []string{"appName", "sessionDuration"} {
var loaded model.AppConfigVariable
err = db.Where("key = ?", key).First(&loaded).Error
require.NoErrorf(t, err, "Failed to load DB value for key '%s'", key)
require.Emptyf(t, loaded.Value, "Loaded value for key '%s' is not empty", key)
}
})
t.Run("auto disables EmailOneTimeAccessEnabled when EmailLoginNotificationEnabled is false", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// First enable both settings
err = service.UpdateAppConfigValues(t.Context(),
"emailLoginNotificationEnabled", "true",
"emailOneTimeAccessEnabled", "true",
)
require.NoError(t, err)
// Verify both are enabled
config := service.GetDbConfig()
require.True(t, config.EmailLoginNotificationEnabled.IsTrue())
require.True(t, config.EmailOneTimeAccessEnabled.IsTrue())
// Now disable EmailLoginNotificationEnabled
input := dto.AppConfigUpdateDto{
EmailLoginNotificationEnabled: "false",
// Don't set EmailOneTimeAccessEnabled, it should be auto-disabled
}
// Update config
_, err = service.UpdateAppConfig(t.Context(), input)
require.NoError(t, err)
// Verify EmailOneTimeAccessEnabled was automatically disabled
config = service.GetDbConfig()
require.False(t, config.EmailLoginNotificationEnabled.IsTrue())
require.False(t, config.EmailOneTimeAccessEnabled.IsTrue())
})
t.Run("cannot update when UiConfigDisabled is true", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Disable UI config
common.EnvConfig.UiConfigDisabled = true
db := newAppConfigTestDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update config
_, err = service.UpdateAppConfig(t.Context(), dto.AppConfigUpdateDto{
AppName: "Should Not Update",
})
// Should get a UiConfigDisabledError
require.Error(t, err)
var uiConfigDisabledErr *common.UiConfigDisabledError
require.ErrorAs(t, err, &uiConfigDisabledErr)
})
}
// Implements gorm's logger.Writer interface
type testLoggerAdapter struct {
t *testing.T
}
func (l testLoggerAdapter) Printf(format string, args ...any) {
l.t.Logf(format, args...)
}
func newAppConfigTestDatabaseForTest(t *testing.T) *gorm.DB {
t.Helper()
// Get a name for this in-memory database that is specific to the test
dbName := utils.CreateSha256Hash(t.Name())
// Connect to a new in-memory SQL database
db, err := gorm.Open(
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
&gorm.Config{
TranslateError: true,
Logger: logger.New(
testLoggerAdapter{t: t},
logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: false,
ParameterizedQueries: false,
Colorful: false,
},
),
})
require.NoError(t, err, "Failed to connect to test database")
// Create the app_config_variables table
err = db.Exec(`
CREATE TABLE app_config_variables
(
key VARCHAR(100) NOT NULL PRIMARY KEY,
value TEXT NOT NULL
)
`).Error
require.NoError(t, err, "Failed to create test config table")
return db
}

View File

@@ -72,7 +72,7 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddres
} }
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email // If the user hasn't logged in from the same device before and email notifications are enabled, send an email
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.IsTrue() && count <= 1 { if s.appConfigService.GetDbConfig().EmailLoginNotificationEnabled.IsTrue() && count <= 1 {
// We use a background context here as this is running in a goroutine // We use a background context here as this is running in a goroutine
//nolint:contextcheck //nolint:contextcheck
go func() { go func() {

View File

@@ -300,19 +300,15 @@ func (s *TestService) ResetApplicationImages() error {
return nil return nil
} }
func (s *TestService) ResetAppConfig() error { func (s *TestService) ResetAppConfig(ctx context.Context) error {
// Reseed the config variables // Reset all app config variables to their default values in the database
if err := s.appConfigService.InitDbConfig(context.Background()); err != nil { err := s.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&model.AppConfigVariable{}).Update("value", "").Error
return err if err != nil {
}
// Reset all app config variables to their default values
if err := s.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&model.AppConfigVariable{}).Update("value", "").Error; err != nil {
return err return err
} }
// Reload the app config from the database after resetting the values // Reload the app config from the database after resetting the values
return s.appConfigService.LoadDbConfigFromDb() return s.appConfigService.LoadDbConfig(ctx)
} }
func (s *TestService) SetJWTKeys() { func (s *TestService) SetJWTKeys() {

View File

@@ -70,8 +70,10 @@ func (srv *EmailService) SendTestEmail(ctx context.Context, recipientUserId stri
} }
func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error { func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
dbConfig := srv.appConfigService.GetDbConfig()
data := &email.TemplateData[V]{ data := &email.TemplateData[V]{
AppName: srv.appConfigService.DbConfig.AppName.Value, AppName: dbConfig.AppName.Value,
LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo", LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo",
Data: tData, Data: tData,
} }
@@ -86,8 +88,8 @@ func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Addr
c.AddHeader("Subject", template.Title(data)) c.AddHeader("Subject", template.Title(data))
c.AddAddressHeader("From", []email.Address{ c.AddAddressHeader("From", []email.Address{
{ {
Email: srv.appConfigService.DbConfig.SmtpFrom.Value, Email: dbConfig.SmtpFrom.Value,
Name: srv.appConfigService.DbConfig.AppName.Value, Name: dbConfig.AppName.Value,
}, },
}) })
c.AddAddressHeader("To", []email.Address{toEmail}) c.AddAddressHeader("To", []email.Address{toEmail})
@@ -102,7 +104,7 @@ func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Addr
// so we use the domain of the from address instead (the same as Thunderbird does) // so we use the domain of the from address instead (the same as Thunderbird does)
// if the address does not have an @ (which would be unusual), we use hostname // if the address does not have an @ (which would be unusual), we use hostname
from_address := srv.appConfigService.DbConfig.SmtpFrom.Value from_address := dbConfig.SmtpFrom.Value
domain := "" domain := ""
if strings.Contains(from_address, "@") { if strings.Contains(from_address, "@") {
domain = strings.Split(from_address, "@")[1] domain = strings.Split(from_address, "@")[1]
@@ -152,16 +154,18 @@ func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Addr
} }
func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) { func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
port := srv.appConfigService.DbConfig.SmtpPort.Value dbConfig := srv.appConfigService.GetDbConfig()
smtpAddress := srv.appConfigService.DbConfig.SmtpHost.Value + ":" + port
port := dbConfig.SmtpPort.Value
smtpAddress := dbConfig.SmtpHost.Value + ":" + port
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec InsecureSkipVerify: dbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec
ServerName: srv.appConfigService.DbConfig.SmtpHost.Value, ServerName: dbConfig.SmtpHost.Value,
} }
// Connect to the SMTP server based on TLS setting // Connect to the SMTP server based on TLS setting
switch srv.appConfigService.DbConfig.SmtpTls.Value { switch dbConfig.SmtpTls.Value {
case "none": case "none":
client, err = smtp.Dial(smtpAddress) client, err = smtp.Dial(smtpAddress)
case "tls": case "tls":
@@ -172,7 +176,7 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
tlsConfig, tlsConfig,
) )
default: default:
return nil, fmt.Errorf("invalid SMTP TLS setting: %s", srv.appConfigService.DbConfig.SmtpTls.Value) return nil, fmt.Errorf("invalid SMTP TLS setting: %s", dbConfig.SmtpTls.Value)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err) return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
@@ -186,8 +190,8 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
} }
// Set up the authentication if user or password are set // Set up the authentication if user or password are set
smtpUser := srv.appConfigService.DbConfig.SmtpUser.Value smtpUser := dbConfig.SmtpUser.Value
smtpPassword := srv.appConfigService.DbConfig.SmtpPassword.Value smtpPassword := dbConfig.SmtpPassword.Value
if smtpUser != "" || smtpPassword != "" { if smtpUser != "" || smtpPassword != "" {
// Authenticate with plain auth // Authenticate with plain auth
@@ -223,7 +227,7 @@ func (srv *EmailService) sendHelloCommand(client *smtp.Client) error {
func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Address, c *email.Composer) error { func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Address, c *email.Composer) error {
// Set the sender // Set the sender
if err := client.Mail(srv.appConfigService.DbConfig.SmtpFrom.Value, nil); err != nil { if err := client.Mail(srv.appConfigService.GetDbConfig().SmtpFrom.Value, nil); err != nil {
return fmt.Errorf("failed to set sender: %w", err) return fmt.Errorf("failed to set sender: %w", err)
} }

View File

@@ -182,7 +182,7 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
now := time.Now() now := time.Now()
token, err := jwt.NewBuilder(). token, err := jwt.NewBuilder().
Subject(user.ID). Subject(user.ID).
Expiration(now.Add(s.appConfigService.DbConfig.SessionDuration.AsDurationMinutes())). Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
IssuedAt(now). IssuedAt(now).
Issuer(common.EnvConfig.AppURL). Issuer(common.EnvConfig.AppURL).
Build() Build()

View File

@@ -26,11 +26,9 @@ import (
) )
func TestJwtService_Init(t *testing.T) { func TestJwtService_Init(t *testing.T) {
mockConfig := &AppConfigService{ mockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes })
},
}
t.Run("should generate new key when none exists", func(t *testing.T) { t.Run("should generate new key when none exists", func(t *testing.T) {
// Create a temporary directory for the test // Create a temporary directory for the test
@@ -142,11 +140,9 @@ func TestJwtService_Init(t *testing.T) {
} }
func TestJwtService_GetPublicJWK(t *testing.T) { func TestJwtService_GetPublicJWK(t *testing.T) {
mockConfig := &AppConfigService{ mockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes })
},
}
t.Run("returns public key when private key is initialized", func(t *testing.T) { t.Run("returns public key when private key is initialized", func(t *testing.T) {
// Create a temporary directory for the test // Create a temporary directory for the test
@@ -276,11 +272,9 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
tempDir := t.TempDir() tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService // Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{ mockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes })
},
}
// Setup the environment variable required by the token verification // Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL originalAppURL := common.EnvConfig.AppURL
@@ -366,11 +360,9 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
t.Run("uses session duration from config", func(t *testing.T) { t.Run("uses session duration from config", func(t *testing.T) {
// Create a JWT service with a different session duration // Create a JWT service with a different session duration
customMockConfig := &AppConfigService{ customMockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "30"}, // 30 minutes
SessionDuration: model.AppConfigVariable{Value: "30"}, // 30 minutes })
},
}
service := &JwtService{} service := &JwtService{}
err := service.init(customMockConfig, tempDir) err := service.init(customMockConfig, tempDir)
@@ -567,11 +559,9 @@ func TestGenerateVerifyIdToken(t *testing.T) {
tempDir := t.TempDir() tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService // Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{ mockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes })
},
}
// Setup the environment variable required by the token verification // Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL originalAppURL := common.EnvConfig.AppURL
@@ -900,11 +890,9 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
tempDir := t.TempDir() tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService // Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{ mockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes })
},
}
// Setup the environment variable required by the token verification // Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL originalAppURL := common.EnvConfig.AppURL

View File

@@ -32,12 +32,15 @@ func NewLdapService(db *gorm.DB, appConfigService *AppConfigService, userService
} }
func (s *LdapService) createClient() (*ldap.Conn, error) { func (s *LdapService) createClient() (*ldap.Conn, error) {
if !s.appConfigService.DbConfig.LdapEnabled.IsTrue() { dbConfig := s.appConfigService.GetDbConfig()
if !dbConfig.LdapEnabled.IsTrue() {
return nil, fmt.Errorf("LDAP is not enabled") return nil, fmt.Errorf("LDAP is not enabled")
} }
// Setup LDAP connection // Setup LDAP connection
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value ldapURL := dbConfig.LdapUrl.Value
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.IsTrue() skipTLSVerify := dbConfig.LdapSkipCertVerify.IsTrue()
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{ client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: skipTLSVerify, //nolint:gosec InsecureSkipVerify: skipTLSVerify, //nolint:gosec
})) }))
@@ -46,8 +49,8 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
} }
// Bind as service account // Bind as service account
bindDn := s.appConfigService.DbConfig.LdapBindDn.Value bindDn := dbConfig.LdapBindDn.Value
bindPassword := s.appConfigService.DbConfig.LdapBindPassword.Value bindPassword := dbConfig.LdapBindPassword.Value
err = client.Bind(bindDn, bindPassword) err = client.Bind(bindDn, bindPassword)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to bind to LDAP: %w", err) return nil, fmt.Errorf("failed to bind to LDAP: %w", err)
@@ -83,6 +86,8 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
//nolint:gocognit //nolint:gocognit
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error { func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
dbConfig := s.appConfigService.GetDbConfig()
// Setup LDAP connection // Setup LDAP connection
client, err := s.createClient() client, err := s.createClient()
if err != nil { if err != nil {
@@ -90,19 +95,20 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
} }
defer client.Close() defer client.Close()
baseDN := s.appConfigService.DbConfig.LdapBase.Value
nameAttribute := s.appConfigService.DbConfig.LdapAttributeGroupName.Value
uniqueIdentifierAttribute := s.appConfigService.DbConfig.LdapAttributeGroupUniqueIdentifier.Value
groupMemberOfAttribute := s.appConfigService.DbConfig.LdapAttributeGroupMember.Value
filter := s.appConfigService.DbConfig.LdapUserGroupSearchFilter.Value
searchAttrs := []string{ searchAttrs := []string{
nameAttribute, dbConfig.LdapAttributeGroupName.Value,
uniqueIdentifierAttribute, dbConfig.LdapAttributeGroupUniqueIdentifier.Value,
groupMemberOfAttribute, dbConfig.LdapAttributeGroupMember.Value,
} }
searchReq := ldap.NewSearchRequest(baseDN, ldap.ScopeWholeSubtree, 0, 0, 0, false, filter, searchAttrs, []ldap.Control{}) searchReq := ldap.NewSearchRequest(
dbConfig.LdapBase.Value,
ldap.ScopeWholeSubtree,
0, 0, 0, false,
dbConfig.LdapUserGroupSearchFilter.Value,
searchAttrs,
[]ldap.Control{},
)
result, err := client.Search(searchReq) result, err := client.Search(searchReq)
if err != nil { if err != nil {
return fmt.Errorf("failed to query LDAP: %w", err) return fmt.Errorf("failed to query LDAP: %w", err)
@@ -114,11 +120,11 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
for _, value := range result.Entries { for _, value := range result.Entries {
var membersUserId []string var membersUserId []string
ldapId := value.GetAttributeValue(uniqueIdentifierAttribute) ldapId := value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
// Skip groups without a valid LDAP ID // Skip groups without a valid LDAP ID
if ldapId == "" { if ldapId == "" {
log.Printf("Skipping LDAP group without a valid unique identifier (attribute: %s)", uniqueIdentifierAttribute) log.Printf("Skipping LDAP group without a valid unique identifier (attribute: %s)", dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
continue continue
} }
@@ -129,7 +135,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup) tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup)
// Get group members and add to the correct Group // Get group members and add to the correct Group
groupMembers := value.GetAttributeValues(groupMemberOfAttribute) groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value)
for _, member := range groupMembers { for _, member := range groupMembers {
// Normal output of this would be CN=username,ou=people,dc=example,dc=com // Normal output of this would be CN=username,ou=people,dc=example,dc=com
// Splitting at the "=" and "," then just grabbing the username for that string // Splitting at the "=" and "," then just grabbing the username for that string
@@ -151,9 +157,9 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
} }
syncGroup := dto.UserGroupCreateDto{ syncGroup := dto.UserGroupCreateDto{
Name: value.GetAttributeValue(nameAttribute), Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
FriendlyName: value.GetAttributeValue(nameAttribute), FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
LdapID: value.GetAttributeValue(uniqueIdentifierAttribute), LdapID: value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value),
} }
if databaseGroup.ID == "" { if databaseGroup.ID == "" {
@@ -214,6 +220,8 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
//nolint:gocognit //nolint:gocognit
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error { func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
dbConfig := s.appConfigService.GetDbConfig()
// Setup LDAP connection // Setup LDAP connection
client, err := s.createClient() client, err := s.createClient()
if err != nil { if err != nil {
@@ -221,30 +229,27 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
} }
defer client.Close() defer client.Close()
baseDN := s.appConfigService.DbConfig.LdapBase.Value
uniqueIdentifierAttribute := s.appConfigService.DbConfig.LdapAttributeUserUniqueIdentifier.Value
usernameAttribute := s.appConfigService.DbConfig.LdapAttributeUserUsername.Value
emailAttribute := s.appConfigService.DbConfig.LdapAttributeUserEmail.Value
firstNameAttribute := s.appConfigService.DbConfig.LdapAttributeUserFirstName.Value
lastNameAttribute := s.appConfigService.DbConfig.LdapAttributeUserLastName.Value
profilePictureAttribute := s.appConfigService.DbConfig.LdapAttributeUserProfilePicture.Value
adminGroupAttribute := s.appConfigService.DbConfig.LdapAttributeAdminGroup.Value
filter := s.appConfigService.DbConfig.LdapUserSearchFilter.Value
searchAttrs := []string{ searchAttrs := []string{
"memberOf", "memberOf",
"sn", "sn",
"cn", "cn",
uniqueIdentifierAttribute, dbConfig.LdapAttributeUserUniqueIdentifier.Value,
usernameAttribute, dbConfig.LdapAttributeUserUsername.Value,
emailAttribute, dbConfig.LdapAttributeUserEmail.Value,
firstNameAttribute, dbConfig.LdapAttributeUserFirstName.Value,
lastNameAttribute, dbConfig.LdapAttributeUserLastName.Value,
profilePictureAttribute, dbConfig.LdapAttributeUserProfilePicture.Value,
} }
// Filters must start and finish with ()! // Filters must start and finish with ()!
searchReq := ldap.NewSearchRequest(baseDN, ldap.ScopeWholeSubtree, 0, 0, 0, false, filter, searchAttrs, []ldap.Control{}) searchReq := ldap.NewSearchRequest(
dbConfig.LdapBase.Value,
ldap.ScopeWholeSubtree,
0, 0, 0, false,
dbConfig.LdapUserSearchFilter.Value,
searchAttrs,
[]ldap.Control{},
)
result, err := client.Search(searchReq) result, err := client.Search(searchReq)
if err != nil { if err != nil {
@@ -255,11 +260,11 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
ldapUserIDs := make(map[string]bool) ldapUserIDs := make(map[string]bool)
for _, value := range result.Entries { for _, value := range result.Entries {
ldapId := value.GetAttributeValue(uniqueIdentifierAttribute) ldapId := value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)
// Skip users without a valid LDAP ID // Skip users without a valid LDAP ID
if ldapId == "" { if ldapId == "" {
log.Printf("Skipping LDAP user without a valid unique identifier (attribute: %s)", uniqueIdentifierAttribute) log.Printf("Skipping LDAP user without a valid unique identifier (attribute: %s)", dbConfig.LdapAttributeUserUniqueIdentifier.Value)
continue continue
} }
@@ -272,16 +277,16 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
// Check if user is admin by checking if they are in the admin group // Check if user is admin by checking if they are in the admin group
isAdmin := false isAdmin := false
for _, group := range value.GetAttributeValues("memberOf") { for _, group := range value.GetAttributeValues("memberOf") {
if strings.Contains(group, adminGroupAttribute) { if strings.Contains(group, dbConfig.LdapAttributeAdminGroup.Value) {
isAdmin = true isAdmin = true
} }
} }
newUser := dto.UserCreateDto{ newUser := dto.UserCreateDto{
Username: value.GetAttributeValue(usernameAttribute), Username: value.GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value),
Email: value.GetAttributeValue(emailAttribute), Email: value.GetAttributeValue(dbConfig.LdapAttributeUserEmail.Value),
FirstName: value.GetAttributeValue(firstNameAttribute), FirstName: value.GetAttributeValue(dbConfig.LdapAttributeUserFirstName.Value),
LastName: value.GetAttributeValue(lastNameAttribute), LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value),
IsAdmin: isAdmin, IsAdmin: isAdmin,
LdapID: ldapId, LdapID: ldapId,
} }
@@ -299,7 +304,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
} }
// Save profile picture // Save profile picture
if pictureString := value.GetAttributeValue(profilePictureAttribute); pictureString != "" { if pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value); pictureString != "" {
if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil { if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil {
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err) log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
} }

View File

@@ -748,7 +748,7 @@ func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID
if slices.Contains(scopes, "email") { if slices.Contains(scopes, "email") {
claims["email"] = user.Email claims["email"] = user.Email
claims["email_verified"] = s.appConfigService.DbConfig.EmailsVerified.IsTrue() claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
} }
if slices.Contains(scopes, "groups") { if slices.Contains(scopes, "groups") {

View File

@@ -79,7 +79,7 @@ func (s *UserGroupService) Delete(ctx context.Context, id string) error {
} }
// Disallow deleting the group if it is an LDAP group and LDAP is enabled // Disallow deleting the group if it is an LDAP group and LDAP is enabled
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() { if group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return &common.LdapUserGroupUpdateError{} return &common.LdapUserGroupUpdateError{}
} }
@@ -148,7 +148,7 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input
} }
// Disallow updating the group if it is an LDAP group and LDAP is enabled // Disallow updating the group if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() { if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return model.UserGroup{}, &common.LdapUserGroupUpdateError{} return model.UserGroup{}, &common.LdapUserGroupUpdateError{}
} }

View File

@@ -188,7 +188,7 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
} }
// Disallow deleting the user if it is an LDAP user and LDAP is enabled // Disallow deleting the user if it is an LDAP user and LDAP is enabled
if !allowLdapDelete && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() { if !allowLdapDelete && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return &common.LdapUserUpdateError{} return &common.LdapUserUpdateError{}
} }
@@ -278,7 +278,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
} }
// Disallow updating the user if it is an LDAP group and LDAP is enabled // Disallow updating the user if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() { if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return model.User{}, &common.LdapUserUpdateError{} return model.User{}, &common.LdapUserUpdateError{}
} }
@@ -314,7 +314,7 @@ func (s *UserService) RequestOneTimeAccessEmail(ctx context.Context, emailAddres
tx.Rollback() tx.Rollback()
}() }()
isDisabled := !s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.IsTrue() isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessEnabled.IsTrue()
if isDisabled { if isDisabled {
return &common.OneTimeAccessDisabledError{} return &common.OneTimeAccessDisabledError{}
} }

View File

@@ -26,7 +26,7 @@ type WebAuthnService struct {
func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService) *WebAuthnService { func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService) *WebAuthnService {
webauthnConfig := &webauthn.Config{ webauthnConfig := &webauthn.Config{
RPDisplayName: appConfigService.DbConfig.AppName.Value, RPDisplayName: appConfigService.GetDbConfig().AppName.Value,
RPID: utils.GetHostnameFromURL(common.EnvConfig.AppURL), RPID: utils.GetHostnameFromURL(common.EnvConfig.AppURL),
RPOrigins: []string{common.EnvConfig.AppURL}, RPOrigins: []string{common.EnvConfig.AppURL},
Timeouts: webauthn.TimeoutsConfig{ Timeouts: webauthn.TimeoutsConfig{
@@ -43,7 +43,13 @@ func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *Au
}, },
} }
wa, _ := webauthn.New(webauthnConfig) wa, _ := webauthn.New(webauthnConfig)
return &WebAuthnService{db: db, webAuthn: wa, jwtService: jwtService, auditLogService: auditLogService, appConfigService: appConfigService} return &WebAuthnService{
db: db,
webAuthn: wa,
jwtService: jwtService,
auditLogService: auditLogService,
appConfigService: appConfigService,
}
} }
func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) { func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) {
@@ -314,5 +320,5 @@ func (s *WebAuthnService) UpdateCredential(ctx context.Context, userID, credenti
// updateWebAuthnConfig updates the WebAuthn configuration with the app name as it can change during runtime // updateWebAuthnConfig updates the WebAuthn configuration with the app name as it can change during runtime
func (s *WebAuthnService) updateWebAuthnConfig() { func (s *WebAuthnService) updateWebAuthnConfig() {
s.webAuthn.Config.RPDisplayName = s.appConfigService.DbConfig.AppName.Value s.webAuthn.Config.RPDisplayName = s.appConfigService.GetDbConfig().AppName.Value
} }

View File

@@ -0,0 +1,4 @@
ALTER TABLE app_config_variables ADD type VARCHAR(20) NOT NULL,;
ALTER TABLE app_config_variables ADD is_public BOOLEAN DEFAULT FALSE NOT NULL,;
ALTER TABLE app_config_variables ADD is_internal BOOLEAN DEFAULT FALSE NOT NULL,;
ALTER TABLE app_config_variables ADD default_value TEXT;

View File

@@ -0,0 +1,4 @@
ALTER TABLE app_config_variables DROP type;
ALTER TABLE app_config_variables DROP is_public;
ALTER TABLE app_config_variables DROP is_internal;
ALTER TABLE app_config_variables DROP default_value;

View File

@@ -0,0 +1,4 @@
ALTER TABLE app_config_variables ADD type VARCHAR(20) NOT NULL,;
ALTER TABLE app_config_variables ADD is_public BOOLEAN DEFAULT FALSE NOT NULL,;
ALTER TABLE app_config_variables ADD is_internal BOOLEAN DEFAULT FALSE NOT NULL,;
ALTER TABLE app_config_variables ADD default_value TEXT;

View File

@@ -0,0 +1,4 @@
ALTER TABLE app_config_variables DROP type;
ALTER TABLE app_config_variables DROP is_public;
ALTER TABLE app_config_variables DROP is_internal;
ALTER TABLE app_config_variables DROP default_value;