diff --git a/backend/cmd/main.go b/backend/cmd/main.go index 9e03be03..eed165f7 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -1,6 +1,8 @@ package main import ( + "log" + "github.com/pocket-id/pocket-id/backend/internal/bootstrap" ) @@ -9,5 +11,8 @@ import ( // @description.markdown func main() { - bootstrap.Bootstrap() + err := bootstrap.Bootstrap() + if err != nil { + log.Fatal(err.Error()) + } } diff --git a/backend/internal/bootstrap/bootstrap.go b/backend/internal/bootstrap/bootstrap.go index 9753d6f2..d07dd9db 100644 --- a/backend/internal/bootstrap/bootstrap.go +++ b/backend/internal/bootstrap/bootstrap.go @@ -2,25 +2,69 @@ package bootstrap import ( "context" + "fmt" + "log" + "time" _ "github.com/golang-migrate/migrate/v4/source/file" - "github.com/pocket-id/pocket-id/backend/internal/service" + "github.com/pocket-id/pocket-id/backend/internal/job" + "github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils/signals" ) -func Bootstrap() { +func Bootstrap() error { // Get a context that is canceled when the application is stopping ctx := signals.SignalContext(context.Background()) initApplicationImages() + // Perform migrations for changes migrateConfigDBConnstring() - - db := newDatabase() - appConfigService := service.NewAppConfigService(ctx, db) - migrateKey() - initRouter(ctx, db, appConfigService) + // Connect to the database + db := newDatabase() + + // Create all services + svc, err := initServices(ctx, db) + if err != nil { + return fmt.Errorf("failed to initialize services: %w", err) + } + + // Init the job scheduler + scheduler, err := job.NewScheduler() + if err != nil { + return fmt.Errorf("failed to create job scheduler: %w", err) + } + err = registerScheduledJobs(ctx, db, svc, scheduler) + if err != nil { + return fmt.Errorf("failed to register scheduled jobs: %w", err) + } + + // Init the router + router := initRouter(db, svc) + + // Run all background serivces + // This call blocks until the context is canceled + err = utils. + NewServiceRunner(router, scheduler.Run). + Run(ctx) + if err != nil { + return fmt.Errorf("failed to run services: %w", err) + } + + // Invoke all shutdown functions + // We give these a timeout of 5s + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + err = utils. + // TODO: Add shutdown services here + NewServiceRunner(). + Run(shutdownCtx) + if err != nil { + log.Printf("Error shutting down services: %v", err) + } + + return nil } diff --git a/backend/internal/bootstrap/e2etest_router_bootstrap.go b/backend/internal/bootstrap/e2etest_router_bootstrap.go index 7ee8e1e0..226a1b62 100644 --- a/backend/internal/bootstrap/e2etest_router_bootstrap.go +++ b/backend/internal/bootstrap/e2etest_router_bootstrap.go @@ -12,9 +12,9 @@ import ( // When building for E2E tests, add the e2etest controller func init() { - registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService){ - func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService) { - testService := service.NewTestService(db, appConfigService, jwtService) + registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services){ + func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) { + testService := service.NewTestService(db, svc.appConfigService, svc.jwtService) controller.NewTestController(apiGroup, testService) }, } diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index 5f40c5e2..e7469681 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -9,27 +9,28 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/pocket-id/pocket-id/backend/internal/common" - "github.com/pocket-id/pocket-id/backend/internal/controller" - "github.com/pocket-id/pocket-id/backend/internal/job" - "github.com/pocket-id/pocket-id/backend/internal/middleware" - "github.com/pocket-id/pocket-id/backend/internal/service" - "github.com/pocket-id/pocket-id/backend/internal/utils/systemd" "golang.org/x/time/rate" "gorm.io/gorm" + + "github.com/pocket-id/pocket-id/backend/internal/common" + "github.com/pocket-id/pocket-id/backend/internal/controller" + "github.com/pocket-id/pocket-id/backend/internal/middleware" + "github.com/pocket-id/pocket-id/backend/internal/utils" + "github.com/pocket-id/pocket-id/backend/internal/utils/systemd" ) // This is used to register additional controllers for tests -var registerTestControllers []func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService) +var registerTestControllers []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) -func initRouter(ctx context.Context, db *gorm.DB, appConfigService *service.AppConfigService) { - err := initRouterInternal(ctx, db, appConfigService) +func initRouter(db *gorm.DB, svc *services) utils.Service { + runner, err := initRouterInternal(db, svc) if err != nil { log.Fatalf("failed to init router: %v", err) } + return runner } -func initRouterInternal(ctx context.Context, db *gorm.DB, appConfigService *service.AppConfigService) error { +func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) { // Set the appropriate Gin mode based on the environment switch common.EnvConfig.AppEnv { case "production": @@ -43,23 +44,6 @@ func initRouterInternal(ctx context.Context, db *gorm.DB, appConfigService *serv r := gin.Default() r.Use(gin.Logger()) - // Initialize services - emailService, err := service.NewEmailService(appConfigService, db) - if err != nil { - return fmt.Errorf("unable to create email service: %w", err) - } - - geoLiteService := service.NewGeoLiteService(ctx) - auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService) - jwtService := service.NewJwtService(appConfigService) - webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService) - userService := service.NewUserService(db, jwtService, auditLogService, emailService, appConfigService) - customClaimService := service.NewCustomClaimService(db) - oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService) - userGroupService := service.NewUserGroupService(db, appConfigService) - ldapService := service.NewLdapService(db, appConfigService, userService, userGroupService) - apiKeyService := service.NewApiKeyService(db, emailService) - rateLimitMiddleware := middleware.NewRateLimitMiddleware() // Setup global middleware @@ -67,56 +51,31 @@ func initRouterInternal(ctx context.Context, db *gorm.DB, appConfigService *serv r.Use(middleware.NewErrorHandlerMiddleware().Add()) r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60)) - scheduler, err := job.NewScheduler() - if err != nil { - return fmt.Errorf("failed to create job scheduler: %w", err) - } - - err = scheduler.RegisterLdapJobs(ctx, ldapService, appConfigService) - if err != nil { - return fmt.Errorf("failed to register LDAP jobs in scheduler: %w", err) - } - err = scheduler.RegisterDbCleanupJobs(ctx, db) - if err != nil { - return fmt.Errorf("failed to register DB cleanup jobs in scheduler: %w", err) - } - err = scheduler.RegisterFileCleanupJobs(ctx, db) - if err != nil { - return fmt.Errorf("failed to register file cleanup jobs in scheduler: %w", err) - } - err = scheduler.RegisterApiKeyExpiryJob(ctx, apiKeyService, appConfigService) - if err != nil { - return fmt.Errorf("failed to register API key expiration jobs in scheduler: %w", err) - } - - // Run the scheduler in a background goroutine, until the context is canceled - go scheduler.Run(ctx) - // Initialize middleware for specific routes - authMiddleware := middleware.NewAuthMiddleware(apiKeyService, userService, jwtService) + authMiddleware := middleware.NewAuthMiddleware(svc.apiKeyService, svc.userService, svc.jwtService) fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware() // Set up API routes apiGroup := r.Group("/api") - controller.NewApiKeyController(apiGroup, authMiddleware, apiKeyService) - controller.NewWebauthnController(apiGroup, authMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, appConfigService) - controller.NewOidcController(apiGroup, authMiddleware, fileSizeLimitMiddleware, oidcService, jwtService) - controller.NewUserController(apiGroup, authMiddleware, middleware.NewRateLimitMiddleware(), userService, appConfigService) - controller.NewAppConfigController(apiGroup, authMiddleware, appConfigService, emailService, ldapService) - controller.NewAuditLogController(apiGroup, auditLogService, authMiddleware) - controller.NewUserGroupController(apiGroup, authMiddleware, userGroupService) - controller.NewCustomClaimController(apiGroup, authMiddleware, customClaimService) + controller.NewApiKeyController(apiGroup, authMiddleware, svc.apiKeyService) + controller.NewWebauthnController(apiGroup, authMiddleware, middleware.NewRateLimitMiddleware(), svc.webauthnService, svc.appConfigService) + controller.NewOidcController(apiGroup, authMiddleware, fileSizeLimitMiddleware, svc.oidcService, svc.jwtService) + controller.NewUserController(apiGroup, authMiddleware, middleware.NewRateLimitMiddleware(), svc.userService, svc.appConfigService) + controller.NewAppConfigController(apiGroup, authMiddleware, svc.appConfigService, svc.emailService, svc.ldapService) + controller.NewAuditLogController(apiGroup, svc.auditLogService, authMiddleware) + controller.NewUserGroupController(apiGroup, authMiddleware, svc.userGroupService) + controller.NewCustomClaimController(apiGroup, authMiddleware, svc.customClaimService) // Add test controller in non-production environments if common.EnvConfig.AppEnv != "production" { for _, f := range registerTestControllers { - f(apiGroup, db, appConfigService, jwtService) + f(apiGroup, db, svc) } } // Set up base routes baseGroup := r.Group("/") - controller.NewWellKnownController(baseGroup, jwtService) + controller.NewWellKnownController(baseGroup, svc.jwtService) // Set up the server srv := &http.Server{ @@ -129,41 +88,46 @@ func initRouterInternal(ctx context.Context, db *gorm.DB, appConfigService *serv // Set up the listener listener, err := net.Listen("tcp", srv.Addr) if err != nil { - return fmt.Errorf("failed to create TCP listener: %w", err) + return nil, fmt.Errorf("failed to create TCP listener: %w", err) } - log.Printf("Server listening on %s", srv.Addr) + // Service runner function + runFn := func(ctx context.Context) error { + log.Printf("Server listening on %s", srv.Addr) - // Notify systemd that we are ready - err = systemd.SdNotifyReady() - if err != nil { - log.Printf("[WARN] Unable to notify systemd that the service is ready: %v", err) - // continue to serve anyway since it's not that important - } + // Start the server in a background goroutine + go func() { + defer listener.Close() - // Start the server in a background goroutine - go func() { - defer listener.Close() + // Next call blocks until the server is shut down + srvErr := srv.Serve(listener) + if srvErr != http.ErrServerClosed { + log.Fatalf("Error starting app server: %v", srvErr) + } + }() - // Next call blocks until the server is shut down - srvErr := srv.Serve(listener) - if srvErr != http.ErrServerClosed { - log.Fatalf("Error starting app server: %v", srvErr) + // Notify systemd that we are ready + err = systemd.SdNotifyReady() + if err != nil { + // Log the error only + log.Printf("[WARN] Unable to notify systemd that the service is ready: %v", err) } - }() - // Block until the context is canceled - <-ctx.Done() + // Block until the context is canceled + <-ctx.Done() - // Handle graceful shutdown - // Note we use the background context here as ctx has been canceled already - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) - shutdownErr := srv.Shutdown(shutdownCtx) //nolint:contextcheck - shutdownCancel() - if shutdownErr != nil { - // Log the error only (could be context canceled) - log.Printf("[WARN] App server shutdown error: %v", shutdownErr) + // Handle graceful shutdown + // Note we use the background context here as ctx has been canceled already + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + shutdownErr := srv.Shutdown(shutdownCtx) //nolint:contextcheck + shutdownCancel() + if shutdownErr != nil { + // Log the error only (could be context canceled) + log.Printf("[WARN] App server shutdown error: %v", shutdownErr) + } + + return nil } - return nil + return runFn, nil } diff --git a/backend/internal/bootstrap/scheduler_bootstrap.go b/backend/internal/bootstrap/scheduler_bootstrap.go new file mode 100644 index 00000000..c6635486 --- /dev/null +++ b/backend/internal/bootstrap/scheduler_bootstrap.go @@ -0,0 +1,35 @@ +package bootstrap + +import ( + "context" + "fmt" + + "gorm.io/gorm" + + "github.com/pocket-id/pocket-id/backend/internal/job" +) + +func registerScheduledJobs(ctx context.Context, db *gorm.DB, svc *services, scheduler *job.Scheduler) error { + err := scheduler.RegisterLdapJobs(ctx, svc.ldapService, svc.appConfigService) + if err != nil { + return fmt.Errorf("failed to register LDAP jobs in scheduler: %w", err) + } + err = scheduler.RegisterGeoLiteUpdateJobs(ctx, svc.geoLiteService) + if err != nil { + return fmt.Errorf("failed to register GeoLite DB update service: %w", err) + } + err = scheduler.RegisterDbCleanupJobs(ctx, db) + if err != nil { + return fmt.Errorf("failed to register DB cleanup jobs in scheduler: %w", err) + } + err = scheduler.RegisterFileCleanupJobs(ctx, db) + if err != nil { + return fmt.Errorf("failed to register file cleanup jobs in scheduler: %w", err) + } + err = scheduler.RegisterApiKeyExpiryJob(ctx, svc.apiKeyService, svc.appConfigService) + if err != nil { + return fmt.Errorf("failed to register API key expiration jobs in scheduler: %w", err) + } + + return nil +} diff --git a/backend/internal/bootstrap/services_bootstrap.go b/backend/internal/bootstrap/services_bootstrap.go new file mode 100644 index 00000000..d4b19433 --- /dev/null +++ b/backend/internal/bootstrap/services_bootstrap.go @@ -0,0 +1,51 @@ +package bootstrap + +import ( + "context" + "fmt" + + "gorm.io/gorm" + + "github.com/pocket-id/pocket-id/backend/internal/service" +) + +type services struct { + appConfigService *service.AppConfigService + emailService *service.EmailService + geoLiteService *service.GeoLiteService + auditLogService *service.AuditLogService + jwtService *service.JwtService + webauthnService *service.WebAuthnService + userService *service.UserService + customClaimService *service.CustomClaimService + oidcService *service.OidcService + userGroupService *service.UserGroupService + ldapService *service.LdapService + apiKeyService *service.ApiKeyService +} + +// Initializes all services +// The context should be used by services only for initialization, and not for running +func initServices(initCtx context.Context, db *gorm.DB) (svc *services, err error) { + svc = &services{} + + svc.appConfigService = service.NewAppConfigService(initCtx, db) + + svc.emailService, err = service.NewEmailService(db, svc.appConfigService) + if err != nil { + return nil, fmt.Errorf("unable to create email service: %w", err) + } + + svc.geoLiteService = service.NewGeoLiteService() + svc.auditLogService = service.NewAuditLogService(db, svc.appConfigService, svc.emailService, svc.geoLiteService) + svc.jwtService = service.NewJwtService(svc.appConfigService) + svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService) + svc.customClaimService = service.NewCustomClaimService(db) + svc.oidcService = service.NewOidcService(db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService) + svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService) + svc.ldapService = service.NewLdapService(db, svc.appConfigService, svc.userService, svc.userGroupService) + svc.apiKeyService = service.NewApiKeyService(db, svc.emailService) + svc.webauthnService = service.NewWebAuthnService(db, svc.jwtService, svc.auditLogService, svc.appConfigService) + + return svc, nil +} diff --git a/backend/internal/job/geoloite_update_job.go b/backend/internal/job/geoloite_update_job.go new file mode 100644 index 00000000..0ad31f82 --- /dev/null +++ b/backend/internal/job/geoloite_update_job.go @@ -0,0 +1,45 @@ +package job + +import ( + "context" + "log" + "time" + + "github.com/pocket-id/pocket-id/backend/internal/service" +) + +type GeoLiteUpdateJobs struct { + geoLiteService *service.GeoLiteService +} + +func (s *Scheduler) RegisterGeoLiteUpdateJobs(ctx context.Context, geoLiteService *service.GeoLiteService) error { + // Check if the service needs periodic updating + if geoLiteService.DisableUpdater() { + // Nothing to do + return nil + } + + jobs := &GeoLiteUpdateJobs{geoLiteService: geoLiteService} + + // Register the job to run every day, at 5 minutes past midnight + err := s.registerJob(ctx, "UpdateGeoLiteDB", "5 * */1 * *", jobs.updateGoeLiteDB) + if err != nil { + return err + } + + // Run the job immediately on startup, with a 1s delay + go func() { + time.Sleep(time.Second) + err = jobs.updateGoeLiteDB(ctx) + if err != nil { + // Log the error only, but don't return it + log.Printf("Failed to Update GeoLite database: %v", err) + } + }() + + return nil +} + +func (j *GeoLiteUpdateJobs) updateGoeLiteDB(ctx context.Context) error { + return j.geoLiteService.UpdateDatabase(ctx) +} diff --git a/backend/internal/job/scheduler.go b/backend/internal/job/scheduler.go index a9e35077..f30e6380 100644 --- a/backend/internal/job/scheduler.go +++ b/backend/internal/job/scheduler.go @@ -26,7 +26,7 @@ func NewScheduler() (*Scheduler, error) { // Run the scheduler. // This function blocks until the context is canceled. -func (s *Scheduler) Run(ctx context.Context) { +func (s *Scheduler) Run(ctx context.Context) error { log.Println("Starting job scheduler") s.scheduler.Start() @@ -39,6 +39,8 @@ func (s *Scheduler) Run(ctx context.Context) { } else { log.Println("Job scheduler shut down") } + + return nil } func (s *Scheduler) registerJob(ctx context.Context, name string, interval string, job func(ctx context.Context) error) error { diff --git a/backend/internal/service/app_config_service.go b/backend/internal/service/app_config_service.go index 700aceb0..3aa7c6a4 100644 --- a/backend/internal/service/app_config_service.go +++ b/backend/internal/service/app_config_service.go @@ -26,12 +26,12 @@ type AppConfigService struct { db *gorm.DB } -func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService { +func NewAppConfigService(initCtx context.Context, db *gorm.DB) *AppConfigService { service := &AppConfigService{ db: db, } - err := service.LoadDbConfig(ctx) + err := service.LoadDbConfig(initCtx) if err != nil { log.Fatalf("Failed to initialize app config service: %v", err) } diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 1a2b1cd9..b5c0760b 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -33,7 +33,7 @@ type EmailService struct { textTemplates map[string]*ttemplate.Template } -func NewEmailService(appConfigService *AppConfigService, db *gorm.DB) (*EmailService, error) { +func NewEmailService(db *gorm.DB, appConfigService *AppConfigService) (*EmailService, error) { htmlTemplates, err := email.PrepareHTMLTemplates(emailTemplatesPaths) if err != nil { return nil, fmt.Errorf("prepare html templates: %w", err) diff --git a/backend/internal/service/geolite_service.go b/backend/internal/service/geolite_service.go index 6e839e91..adca4b4b 100644 --- a/backend/internal/service/geolite_service.go +++ b/backend/internal/service/geolite_service.go @@ -23,7 +23,7 @@ import ( type GeoLiteService struct { disableUpdater bool - mutex sync.Mutex + mutex sync.RWMutex } var localhostIPNets = []*net.IPNet{ @@ -42,25 +42,22 @@ var tailscaleIPNets = []*net.IPNet{ } // NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database. -func NewGeoLiteService(ctx context.Context) *GeoLiteService { +func NewGeoLiteService() *GeoLiteService { service := &GeoLiteService{} if common.EnvConfig.MaxMindLicenseKey == "" && common.EnvConfig.GeoLiteDBUrl == common.MaxMindGeoLiteCityUrl { - // Warn the user, and disable the updater. + // Warn the user, and disable the periodic updater log.Println("MAXMIND_LICENSE_KEY environment variable is empty. The GeoLite2 City database won't be updated.") service.disableUpdater = true } - go func() { - err := service.updateDatabase(ctx) - if err != nil { - log.Printf("Failed to update GeoLite2 City database: %v", err) - } - }() - return service } +func (s *GeoLiteService) DisableUpdater() bool { + return s.disableUpdater +} + // GetLocationByIP returns the country and city of the given IP address. func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string, err error) { // Check the IP address against known private IP ranges @@ -83,8 +80,8 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string } // Race condition between reading and writing the database. - s.mutex.Lock() - defer s.mutex.Unlock() + s.mutex.RLock() + defer s.mutex.RUnlock() db, err := maxminddb.Open(common.EnvConfig.GeoLiteDBPath) if err != nil { @@ -92,7 +89,10 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string } defer db.Close() - addr := netip.MustParseAddr(ipAddress) + addr, err := netip.ParseAddr(ipAddress) + if err != nil { + return "", "", fmt.Errorf("failed to parse IP address: %w", err) + } var record struct { City struct { @@ -112,18 +112,13 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string } // UpdateDatabase checks the age of the database and updates it if it's older than 14 days. -func (s *GeoLiteService) updateDatabase(parentCtx context.Context) error { - if s.disableUpdater { - // Avoid updating the GeoLite2 City database. - return nil - } - +func (s *GeoLiteService) UpdateDatabase(parentCtx context.Context) error { if s.isDatabaseUpToDate() { - log.Println("GeoLite2 City database is up-to-date.") + log.Println("GeoLite2 City database is up-to-date") return nil } - log.Println("Updating GeoLite2 City database...") + log.Println("Updating GeoLite2 City database") downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey) ctx, cancel := context.WithTimeout(parentCtx, 10*time.Minute) @@ -145,7 +140,8 @@ func (s *GeoLiteService) updateDatabase(parentCtx context.Context) error { } // Extract the database file directly to the target path - if err := s.extractDatabase(resp.Body); err != nil { + err = s.extractDatabase(resp.Body) + if err != nil { return fmt.Errorf("failed to extract database: %w", err) } @@ -179,10 +175,9 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error { // Iterate over the files in the tar archive for { header, err := tarReader.Next() - if err == io.EOF { + if errors.Is(err, io.EOF) { break - } - if err != nil { + } else if err != nil { return fmt.Errorf("failed to read tar archive: %w", err) } diff --git a/backend/internal/utils/servicerunner.go b/backend/internal/utils/servicerunner.go new file mode 100644 index 00000000..a1267912 --- /dev/null +++ b/backend/internal/utils/servicerunner.go @@ -0,0 +1,58 @@ +package utils + +import ( + "context" + "errors" +) + +// Source: +// https://github.com/ItalyPaleAle/traefik-forward-auth/blob/v3.5.1/pkg/utils/servicerunner.go +// Copyright (c) 2018, Thom Seddon & Contributors Copyright (c) 2023, Alessandro Segala & Contributors +// License: MIT (https://github.com/ItalyPaleAle/traefik-forward-auth/blob/v3.5.1/LICENSE.md) + +// Service is a background service +type Service func(ctx context.Context) error + +// ServiceRunner oversees a number of services running in background +type ServiceRunner struct { + services []Service +} + +// NewServiceRunner creates a new ServiceRunner +func NewServiceRunner(services ...Service) *ServiceRunner { + return &ServiceRunner{ + services: services, + } +} + +// Run all background services +func (r *ServiceRunner) Run(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + errCh := make(chan error) + for _, service := range r.services { + go func(service Service) { + // Run the service + rErr := service(ctx) + + // Ignore context canceled errors here as they generally indicate that the service is stopping + if rErr != nil && !errors.Is(rErr, context.Canceled) { + errCh <- rErr + return + } + errCh <- nil + }(service) + } + + // Wait for all services to return + errs := make([]error, 0) + for range len(r.services) { + err := <-errCh + if err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} diff --git a/backend/internal/utils/servicerunner_test.go b/backend/internal/utils/servicerunner_test.go new file mode 100644 index 00000000..271a4c4d --- /dev/null +++ b/backend/internal/utils/servicerunner_test.go @@ -0,0 +1,125 @@ +package utils + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// Source: +// https://github.com/ItalyPaleAle/traefik-forward-auth/blob/v3.5.1/pkg/utils/servicerunner.go +// Copyright (c) 2018, Thom Seddon & Contributors Copyright (c) 2023, Alessandro Segala & Contributors +// License: MIT (https://github.com/ItalyPaleAle/traefik-forward-auth/blob/v3.5.1/LICENSE.md) + +func TestServiceRunner_Run(t *testing.T) { + t.Run("successful services", func(t *testing.T) { + // Create a service that just returns no error after 0.2s + successService := func(ctx context.Context) error { + time.Sleep(200 * time.Millisecond) + return nil + } + + // Create a service runner with two success services + runner := NewServiceRunner(successService, successService) + + // Run the services with a timeout to avoid hanging if something goes wrong + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + // Run should return nil when all services succeed + err := runner.Run(ctx) + require.NoError(t, err) + }) + + t.Run("service with error", func(t *testing.T) { + // Create a service that returns an error + expectedErr := errors.New("service failed") + errorService := func(ctx context.Context) error { + return expectedErr + } + + // Create a service runner with one error service and one success service + successService := func(ctx context.Context) error { + time.Sleep(200 * time.Millisecond) + return nil + } + + runner := NewServiceRunner(errorService, successService) + + // Run the services with a timeout + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + // Run should return the error + err := runner.Run(ctx) + require.Error(t, err) + + // The error should contain our expected error + require.ErrorIs(t, err, expectedErr) + }) + + t.Run("context canceled", func(t *testing.T) { + // Create a service that waits until context is canceled + waitingService := func(ctx context.Context) error { + <-ctx.Done() + return ctx.Err() + } + + // Create another service that returns no error quickly + quickService := func(ctx context.Context) error { + return nil + } + + runner := NewServiceRunner(waitingService, quickService) + + // Create a context that we can cancel + ctx, cancel := context.WithCancel(t.Context()) + + // Run the runner in a goroutine + errCh := make(chan error) + go func() { + errCh <- runner.Run(ctx) + }() + + // Cancel the context to trigger service shutdown + cancel() + + // Wait for the runner to finish with a timeout + select { + case err := <-errCh: + require.NoError(t, err, "expected nil error (context.Canceled should be ignored)") + case <-time.After(5 * time.Second): + t.Fatal("test timed out waiting for runner to finish") + } + }) + + t.Run("multiple errors", func(t *testing.T) { + // Create two services that return different errors + err1 := errors.New("error 1") + err2 := errors.New("error 2") + + service1 := func(ctx context.Context) error { + return err1 + } + service2 := func(ctx context.Context) error { + return err2 + } + + runner := NewServiceRunner(service1, service2) + + // Run the services + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + // Run should join all errors + err := runner.Run(ctx) + require.Error(t, err) + + // Check that both errors are included + require.ErrorIs(t, err, err1) + require.ErrorIs(t, err, err2) + }) +}