mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-15 17:53:03 +03:00
refactor: complete conversion of log calls to slog (#787)
This commit is contained in:
committed by
GitHub
parent
78266e3e4c
commit
42a861d206
@@ -27,7 +27,10 @@ func Bootstrap(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connect to the database
|
// Connect to the database
|
||||||
db := NewDatabase()
|
db, err := NewDatabase()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Create all services
|
// Create all services
|
||||||
svc, err := initServices(ctx, db, httpClient)
|
svc, err := initServices(ctx, db, httpClient)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package bootstrap
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -25,14 +24,14 @@ import (
|
|||||||
"github.com/pocket-id/pocket-id/backend/resources"
|
"github.com/pocket-id/pocket-id/backend/resources"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewDatabase() (db *gorm.DB) {
|
func NewDatabase() (db *gorm.DB, err error) {
|
||||||
db, err := connectDatabase()
|
db, err = connectDatabase()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to connect to database: %v", err)
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||||
}
|
}
|
||||||
sqlDb, err := db.DB()
|
sqlDb, err := db.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to get sql.DB: %v", err)
|
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Choose the correct driver for the database provider
|
// Choose the correct driver for the database provider
|
||||||
@@ -44,18 +43,18 @@ func NewDatabase() (db *gorm.DB) {
|
|||||||
driver, err = postgresMigrate.WithInstance(sqlDb, &postgresMigrate.Config{})
|
driver, err = postgresMigrate.WithInstance(sqlDb, &postgresMigrate.Config{})
|
||||||
default:
|
default:
|
||||||
// Should never happen at this point
|
// Should never happen at this point
|
||||||
log.Fatalf("unsupported database provider: %s", common.EnvConfig.DbProvider)
|
return nil, fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create migration driver: %v", err)
|
return nil, fmt.Errorf("failed to create migration driver: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run migrations
|
// Run migrations
|
||||||
if err := migrateDatabase(driver); err != nil {
|
if err := migrateDatabase(driver); err != nil {
|
||||||
log.Fatalf("failed to run migrations: %v", err)
|
return nil, fmt.Errorf("failed to run migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return db
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func migrateDatabase(driver database.Driver) error {
|
func migrateDatabase(driver database.Driver) error {
|
||||||
|
|||||||
@@ -3,7 +3,8 @@
|
|||||||
package bootstrap
|
package bootstrap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -18,7 +19,8 @@ func init() {
|
|||||||
func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) {
|
func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) {
|
||||||
testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
|
testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to initialize test service: %v", err)
|
slog.Error("Failed to initialize test service", slog.Any("error", err))
|
||||||
|
os.Exit(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,10 @@ type services struct {
|
|||||||
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
|
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
|
||||||
svc = &services{}
|
svc = &services{}
|
||||||
|
|
||||||
svc.appConfigService = service.NewAppConfigService(ctx, db)
|
svc.appConfigService, err = service.NewAppConfigService(ctx, db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create app config service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
|
svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -54,7 +57,11 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (sv
|
|||||||
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
|
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
|
||||||
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
|
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
|
||||||
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
||||||
svc.webauthnService = service.NewWebAuthnService(db, svc.jwtService, svc.auditLogService, svc.appConfigService)
|
|
||||||
|
svc.webauthnService, err = service.NewWebAuthnService(db, svc.jwtService, svc.auditLogService, svc.appConfigService)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create WebAuthn service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return svc, nil
|
return svc, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,7 +30,10 @@ func init() {
|
|||||||
Use: "key-rotate",
|
Use: "key-rotate",
|
||||||
Short: "Generates a new token signing key and replaces the current one",
|
Short: "Generates a new token signing key and replaces the current one",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
db := bootstrap.NewDatabase()
|
db, err := bootstrap.NewDatabase()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return keyRotate(cmd.Context(), flags, db, &common.EnvConfig)
|
return keyRotate(cmd.Context(), flags, db, &common.EnvConfig)
|
||||||
},
|
},
|
||||||
@@ -80,7 +83,10 @@ func keyRotate(ctx context.Context, flags keyRotateFlags, db *gorm.DB, envConfig
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Init the services we need
|
// Init the services we need
|
||||||
appConfigService := service.NewAppConfigService(ctx, db)
|
appConfigService, err := service.NewAppConfigService(ctx, db)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create app config service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Get the key provider
|
// Get the key provider
|
||||||
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, appConfigService.GetDbConfig().InstanceID.Value)
|
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, appConfigService.GetDbConfig().InstanceID.Value)
|
||||||
|
|||||||
@@ -97,7 +97,8 @@ func testKeyRotateWithFileStorage(t *testing.T, flags keyRotateFlags, wantErr bo
|
|||||||
db := testingutils.NewDatabaseForTest(t)
|
db := testingutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Initialize app config service and create instance
|
// Initialize app config service and create instance
|
||||||
appConfigService := service.NewAppConfigService(t.Context(), db)
|
appConfigService, err := service.NewAppConfigService(t.Context(), db)
|
||||||
|
require.NoError(t, err)
|
||||||
instanceID := appConfigService.GetDbConfig().InstanceID.Value
|
instanceID := appConfigService.GetDbConfig().InstanceID.Value
|
||||||
|
|
||||||
// Check if key exists before rotation
|
// Check if key exists before rotation
|
||||||
@@ -147,7 +148,8 @@ func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantEr
|
|||||||
db := testingutils.NewDatabaseForTest(t)
|
db := testingutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Initialize app config service and create instance
|
// Initialize app config service and create instance
|
||||||
appConfigService := service.NewAppConfigService(t.Context(), db)
|
appConfigService, err := service.NewAppConfigService(t.Context(), db)
|
||||||
|
require.NoError(t, err)
|
||||||
instanceID := appConfigService.GetDbConfig().InstanceID.Value
|
instanceID := appConfigService.GetDbConfig().InstanceID.Value
|
||||||
|
|
||||||
// Get key provider
|
// Get key provider
|
||||||
|
|||||||
@@ -24,11 +24,14 @@ var oneTimeAccessTokenCmd = &cobra.Command{
|
|||||||
userArg := args[0]
|
userArg := args[0]
|
||||||
|
|
||||||
// Connect to the database
|
// Connect to the database
|
||||||
db := bootstrap.NewDatabase()
|
db, err := bootstrap.NewDatabase()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Create the access token
|
// Create the access token
|
||||||
var oneTimeAccessToken *model.OneTimeAccessToken
|
var oneTimeAccessToken *model.OneTimeAccessToken
|
||||||
err := db.Transaction(func(tx *gorm.DB) error {
|
err = db.Transaction(func(tx *gorm.DB) error {
|
||||||
// Load the user to retrieve the user ID
|
// Load the user to retrieve the user ID
|
||||||
var user model.User
|
var user model.User
|
||||||
queryCtx, queryCancel := context.WithTimeout(cmd.Context(), 10*time.Second)
|
queryCtx, queryCancel := context.WithTimeout(cmd.Context(), 10*time.Second)
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ package common
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log/slog"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/caarlos0/env/v11"
|
"github.com/caarlos0/env/v11"
|
||||||
_ "github.com/joho/godotenv/autoload"
|
_ "github.com/joho/godotenv/autoload"
|
||||||
@@ -57,7 +58,8 @@ var EnvConfig = defaultConfig()
|
|||||||
func init() {
|
func init() {
|
||||||
err := parseEnvConfig()
|
err := parseEnvConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Configuration error: %v", err)
|
slog.Error("Configuration error", slog.Any("error", err))
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tc.TestService.ResetApplicationImages(); err != nil {
|
if err := tc.TestService.ResetApplicationImages(c.Request.Context()); err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
@@ -23,7 +24,9 @@ func NewWellKnownController(group *gin.RouterGroup, jwtService *service.JwtServi
|
|||||||
var err error
|
var err error
|
||||||
wkc.oidcConfig, err = wkc.computeOIDCConfiguration()
|
wkc.oidcConfig, err = wkc.computeOIDCConfiguration()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to pre-compute OpenID Connect configuration document: %v", err)
|
slog.Error("Failed to pre-compute OpenID Connect configuration document", slog.Any("error", err))
|
||||||
|
os.Exit(1)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
group.GET("/.well-known/jwks.json", wkc.jwksHandler)
|
group.GET("/.well-known/jwks.json", wkc.jwksHandler)
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log/slog"
|
||||||
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin/binding"
|
"github.com/gin-gonic/gin/binding"
|
||||||
@@ -18,9 +19,11 @@ var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
v, _ := binding.Validator.Engine().(*validator.Validate)
|
||||||
if err := v.RegisterValidation("username", validateUsername); err != nil {
|
err := v.RegisterValidation("username", validateUsername)
|
||||||
log.Fatalf("Failed to register custom validation: %v", err)
|
if err != nil {
|
||||||
}
|
slog.Error("Failed to register custom validation", slog.Any("error", err))
|
||||||
|
os.Exit(1)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -28,22 +27,22 @@ type AppConfigService struct {
|
|||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
|
func NewAppConfigService(ctx context.Context, db *gorm.DB) (*AppConfigService, error) {
|
||||||
service := &AppConfigService{
|
service := &AppConfigService{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.LoadDbConfig(ctx)
|
err := service.LoadDbConfig(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to initialize app config service: %v", err)
|
return nil, fmt.Errorf("failed to initialize app config service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.initInstanceID(ctx)
|
err = service.initInstanceID(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to initialize instance ID: %v", err)
|
return nil, fmt.Errorf("failed to initialize instance ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return service
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDbConfig returns the application configuration.
|
// GetDbConfig returns the application configuration.
|
||||||
|
|||||||
@@ -22,7 +22,12 @@ type AuditLogService struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailService *EmailService, geoliteService *GeoLiteService) *AuditLogService {
|
func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailService *EmailService, geoliteService *GeoLiteService) *AuditLogService {
|
||||||
return &AuditLogService{db: db, appConfigService: appConfigService, emailService: emailService, geoliteService: geoliteService}
|
return &AuditLogService{
|
||||||
|
db: db,
|
||||||
|
appConfigService: appConfigService,
|
||||||
|
emailService: emailService,
|
||||||
|
geoliteService: geoliteService,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates a new audit log entry in the database
|
// Create creates a new audit log entry in the database
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
@@ -402,9 +402,9 @@ func (s *TestService) ResetDatabase() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TestService) ResetApplicationImages() error {
|
func (s *TestService) ResetApplicationImages(ctx context.Context) error {
|
||||||
if err := os.RemoveAll(common.EnvConfig.UploadPath); err != nil {
|
if err := os.RemoveAll(common.EnvConfig.UploadPath); err != nil {
|
||||||
log.Printf("Error removing directory: %v", err)
|
slog.ErrorContext(ctx, "Error removing directory", slog.Any("error", err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,8 @@ func NewGeoLiteService(httpClient *http.Client) *GeoLiteService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize IPv6 local ranges
|
// Initialize IPv6 local ranges
|
||||||
if err := service.initializeIPv6LocalRanges(); err != nil {
|
err := service.initializeIPv6LocalRanges()
|
||||||
|
if err != nil {
|
||||||
slog.Warn("Failed to initialize IPv6 local ranges", slog.Any("error", err))
|
slog.Warn("Failed to initialize IPv6 local ranges", slog.Any("error", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -169,13 +168,13 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
|||||||
|
|
||||||
userResult, err := client.Search(userSearchReq)
|
userResult, err := client.Search(userSearchReq)
|
||||||
if err != nil || len(userResult.Entries) == 0 {
|
if err != nil || len(userResult.Entries) == 0 {
|
||||||
log.Printf("Could not resolve group member DN '%s': %v", member, err)
|
slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)
|
username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)
|
||||||
if username == "" {
|
if username == "" {
|
||||||
log.Printf("Could not extract username from group member DN '%s'", member)
|
slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -1180,9 +1179,13 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
var deviceAuth model.OidcDeviceCode
|
var deviceAuth model.OidcDeviceCode
|
||||||
if err := tx.WithContext(ctx).Preload("Client.AllowedUserGroups").First(&deviceAuth, "user_code = ?", userCode).Error; err != nil {
|
err := tx.
|
||||||
log.Printf("Error finding device code with user_code %s: %v", userCode, err)
|
WithContext(ctx).
|
||||||
return err
|
Preload("Client.AllowedUserGroups").
|
||||||
|
First(&deviceAuth, "user_code = ?", userCode).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error finding device code: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
|
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
|
||||||
@@ -1191,17 +1194,26 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use
|
|||||||
|
|
||||||
// Check if the user group is allowed to authorize the client
|
// Check if the user group is allowed to authorize the client
|
||||||
var user model.User
|
var user model.User
|
||||||
if err := tx.WithContext(ctx).Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil {
|
err = tx.
|
||||||
return err
|
WithContext(ctx).
|
||||||
|
Preload("UserGroups").
|
||||||
|
First(&user, "id = ?", userID).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error finding user groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.IsUserGroupAllowedToAuthorize(user, deviceAuth.Client) {
|
if !s.IsUserGroupAllowedToAuthorize(user, deviceAuth.Client) {
|
||||||
return &common.OidcAccessDeniedError{}
|
return &common.OidcAccessDeniedError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.WithContext(ctx).Preload("Client").First(&deviceAuth, "user_code = ?", userCode).Error; err != nil {
|
err = tx.
|
||||||
log.Printf("Error finding device code with user_code %s: %v", userCode, err)
|
WithContext(ctx).
|
||||||
return err
|
Preload("Client").
|
||||||
|
First(&deviceAuth, "user_code = ?", userCode).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error finding device code: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
|
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
|
||||||
@@ -1211,16 +1223,12 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use
|
|||||||
deviceAuth.UserID = &userID
|
deviceAuth.UserID = &userID
|
||||||
deviceAuth.IsAuthorized = true
|
deviceAuth.IsAuthorized = true
|
||||||
|
|
||||||
if err := tx.WithContext(ctx).Save(&deviceAuth).Error; err != nil {
|
err = tx.
|
||||||
log.Printf("Error saving device auth: %v", err)
|
WithContext(ctx).
|
||||||
return err
|
Save(&deviceAuth).
|
||||||
}
|
Error
|
||||||
|
if err != nil {
|
||||||
// Verify the update was successful
|
return fmt.Errorf("error saving device auth: %w", err)
|
||||||
var verifiedAuth model.OidcDeviceCode
|
|
||||||
if err := tx.WithContext(ctx).First(&verifiedAuth, "device_code = ?", deviceAuth.DeviceCode).Error; err != nil {
|
|
||||||
log.Printf("Error verifying update: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create user authorization if needed
|
// Create user authorization if needed
|
||||||
@@ -1229,15 +1237,16 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auditLogData := model.AuditLogData{"clientName": deviceAuth.Client.Name}
|
||||||
if !hasAuthorizedClient {
|
if !hasAuthorizedClient {
|
||||||
err := s.createAuthorizedClientInternal(ctx, userID, deviceAuth.ClientID, deviceAuth.Scope, tx)
|
err = s.createAuthorizedClientInternal(ctx, userID, deviceAuth.ClientID, deviceAuth.Scope, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.auditLogService.Create(ctx, model.AuditLogEventNewDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx)
|
s.auditLogService.Create(ctx, model.AuditLogEventNewDeviceCodeAuthorization, ipAddress, userAgent, userID, auditLogData, tx)
|
||||||
} else {
|
} else {
|
||||||
s.auditLogService.Create(ctx, model.AuditLogEventDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx)
|
s.auditLogService.Create(ctx, model.AuditLogEventDeviceCodeAuthorization, ipAddress, userAgent, userID, auditLogData, tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Commit().Error
|
return tx.Commit().Error
|
||||||
@@ -1428,7 +1437,7 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *g
|
|||||||
case isClientAssertion:
|
case isClientAssertion:
|
||||||
err = s.verifyClientAssertionFromFederatedIdentities(ctx, client, input)
|
err = s.verifyClientAssertionFromFederatedIdentities(ctx, client, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
|
slog.WarnContext(ctx, "Invalid assertion for client", slog.String("client", client.ID), slog.Any("error", err))
|
||||||
return nil, &common.OidcClientAssertionInvalidError{}
|
return nil, &common.OidcClientAssertionInvalidError{}
|
||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -35,7 +34,13 @@ type UserService struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService, appConfigService *AppConfigService) *UserService {
|
func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService, appConfigService *AppConfigService) *UserService {
|
||||||
return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService, appConfigService: appConfigService}
|
return &UserService{
|
||||||
|
db: db,
|
||||||
|
jwtService: jwtService,
|
||||||
|
auditLogService: auditLogService,
|
||||||
|
emailService: emailService,
|
||||||
|
appConfigService: appConfigService,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
|
func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
|
||||||
@@ -47,7 +52,8 @@ func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPa
|
|||||||
|
|
||||||
if searchTerm != "" {
|
if searchTerm != "" {
|
||||||
searchPattern := "%" + searchTerm + "%"
|
searchPattern := "%" + searchTerm + "%"
|
||||||
query = query.Where("email LIKE ? OR first_name LIKE ? OR last_name LIKE ? OR username LIKE ?",
|
query = query.Where(
|
||||||
|
"email LIKE ? OR first_name LIKE ? OR last_name LIKE ? OR username LIKE ?",
|
||||||
searchPattern, searchPattern, searchPattern, searchPattern)
|
searchPattern, searchPattern, searchPattern, searchPattern)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,13 +126,14 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
|
|||||||
defaultPictureBytes := defaultPicture.Bytes()
|
defaultPictureBytes := defaultPicture.Bytes()
|
||||||
go func() {
|
go func() {
|
||||||
// Ensure the directory exists
|
// Ensure the directory exists
|
||||||
err = os.MkdirAll(defaultProfilePicturesDir, os.ModePerm)
|
errInternal := os.MkdirAll(defaultProfilePicturesDir, os.ModePerm)
|
||||||
if err != nil {
|
if errInternal != nil {
|
||||||
log.Printf("Failed to create directory for default profile picture: %v", err)
|
slog.Error("Failed to create directory for default profile picture", slog.Any("error", errInternal))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := utils.SaveFileStream(bytes.NewReader(defaultPictureBytes), defaultPicturePath); err != nil {
|
errInternal = utils.SaveFileStream(bytes.NewReader(defaultPictureBytes), defaultPicturePath)
|
||||||
log.Printf("Failed to cache default profile picture for initials %s: %v", user.Initials(), err)
|
if errInternal != nil {
|
||||||
|
slog.Error("Failed to cache default profile picture for initials", slog.String("initials", user.Initials()), slog.Any("error", errInternal))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ type WebAuthnService struct {
|
|||||||
appConfigService *AppConfigService
|
appConfigService *AppConfigService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService) *WebAuthnService {
|
func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService) (*WebAuthnService, error) {
|
||||||
webauthnConfig := &webauthn.Config{
|
wa, err := webauthn.New(&webauthn.Config{
|
||||||
RPDisplayName: appConfigService.GetDbConfig().AppName.Value,
|
RPDisplayName: appConfigService.GetDbConfig().AppName.Value,
|
||||||
RPID: utils.GetHostnameFromURL(common.EnvConfig.AppURL),
|
RPID: utils.GetHostnameFromURL(common.EnvConfig.AppURL),
|
||||||
RPOrigins: []string{common.EnvConfig.AppURL},
|
RPOrigins: []string{common.EnvConfig.AppURL},
|
||||||
@@ -45,15 +45,18 @@ func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *Au
|
|||||||
TimeoutUVD: time.Second * 60,
|
TimeoutUVD: time.Second * 60,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to init webauthn object: %w", err)
|
||||||
}
|
}
|
||||||
wa, _ := webauthn.New(webauthnConfig)
|
|
||||||
return &WebAuthnService{
|
return &WebAuthnService{
|
||||||
db: db,
|
db: db,
|
||||||
webAuthn: wa,
|
webAuthn: wa,
|
||||||
jwtService: jwtService,
|
jwtService: jwtService,
|
||||||
auditLogService: auditLogService,
|
auditLogService: auditLogService,
|
||||||
appConfigService: appConfigService,
|
appConfigService: appConfigService,
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) {
|
func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log/slog"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/resources"
|
"github.com/pocket-id/pocket-id/backend/resources"
|
||||||
@@ -57,12 +57,13 @@ func loadAAGUIDsFromFile() {
|
|||||||
// Read from embedded file system
|
// Read from embedded file system
|
||||||
data, err := resources.FS.ReadFile("aaguids.json")
|
data, err := resources.FS.ReadFile("aaguids.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error reading embedded AAGUID file: %v", err)
|
slog.Error("Error reading embedded AAGUID file", slog.Any("error", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(data, &aaguidMap); err != nil {
|
err = json.Unmarshal(data, &aaguidMap)
|
||||||
log.Printf("Error unmarshalling AAGUID data: %v", err)
|
if err != nil {
|
||||||
|
slog.Error("Error unmarshalling AAGUID data", slog.Any("error", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,9 +29,9 @@ func CreateProfilePicture(file io.Reader) (io.Reader, error) {
|
|||||||
|
|
||||||
pr, pw := io.Pipe()
|
pr, pw := io.Pipe()
|
||||||
go func() {
|
go func() {
|
||||||
err = imaging.Encode(pw, img, imaging.PNG)
|
innerErr := imaging.Encode(pw, img, imaging.PNG)
|
||||||
if err != nil {
|
if innerErr != nil {
|
||||||
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %w", err))
|
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %w", innerErr))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pw.Close()
|
pw.Close()
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package signals
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -28,11 +28,11 @@ func SignalContext(parentCtx context.Context) context.Context {
|
|||||||
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
||||||
go func() {
|
go func() {
|
||||||
<-sigCh
|
<-sigCh
|
||||||
log.Println("Received interrupt signal. Shutting down…")
|
slog.Info("Received interrupt signal. Shutting down…")
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
<-sigCh
|
<-sigCh
|
||||||
log.Println("Received a second interrupt signal. Forcing an immediate shutdown.")
|
slog.Warn("Received a second interrupt signal. Forcing an immediate shutdown.")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user