From 6bdf5fa37ac4f47e32b3f727961b18f80e1b4259 Mon Sep 17 00:00:00 2001 From: Kyle Mendell Date: Mon, 29 Sep 2025 10:07:55 -0500 Subject: [PATCH] feat: support for url based icons (#840) Co-authored-by: Elias Schneider --- .../internal/bootstrap/services_bootstrap.go | 2 +- backend/internal/common/env_config.go | 20 ++ backend/internal/dto/oidc_dto.go | 2 + backend/internal/model/oidc.go | 13 +- backend/internal/service/geolite_service.go | 105 +-------- .../internal/service/geolite_service_test.go | 220 ------------------ backend/internal/service/oidc_service.go | 198 +++++++++++----- backend/internal/utils/file_util.go | 29 +++ backend/internal/utils/ip_util.go | 87 +++++++ backend/internal/utils/ip_util_test.go | 159 +++++++++++++ frontend/messages/en.json | 4 +- .../lib/components/form/url-file-input.svelte | 85 +++++++ frontend/src/lib/components/image-box.svelte | 17 +- frontend/src/lib/types/oidc.type.ts | 3 +- .../forms/app-config-general-form.svelte | 3 +- .../oidc-clients/oidc-client-form.svelte | 74 +++--- .../oidc-client-image-input.svelte | 44 ++++ .../oidc-clients/oidc-client-list.svelte | 25 +- .../apps/authorized-oidc-client-card.svelte | 2 +- 19 files changed, 650 insertions(+), 442 deletions(-) delete mode 100644 backend/internal/service/geolite_service_test.go create mode 100644 backend/internal/utils/ip_util.go create mode 100644 backend/internal/utils/ip_util_test.go create mode 100644 frontend/src/lib/components/form/url-file-input.svelte create mode 100644 frontend/src/routes/settings/admin/oidc-clients/oidc-client-image-input.svelte diff --git a/backend/internal/bootstrap/services_bootstrap.go b/backend/internal/bootstrap/services_bootstrap.go index fe2c5f61..c7f085a0 100644 --- a/backend/internal/bootstrap/services_bootstrap.go +++ b/backend/internal/bootstrap/services_bootstrap.go @@ -56,7 +56,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima return nil, fmt.Errorf("failed to create WebAuthn service: %w", err) } - svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService) + svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService, httpClient) if err != nil { return nil, fmt.Errorf("failed to create OIDC service: %w", err) } diff --git a/backend/internal/common/env_config.go b/backend/internal/common/env_config.go index ce213fb1..d279114a 100644 --- a/backend/internal/common/env_config.go +++ b/backend/internal/common/env_config.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log/slog" + "net" "net/url" "os" "reflect" @@ -180,6 +181,25 @@ func validateEnvConfig(config *EnvConfigSchema) error { return fmt.Errorf("invalid value for KEYS_STORAGE: %s", config.KeysStorage) } + // Validate LOCAL_IPV6_RANGES + ranges := strings.Split(config.LocalIPv6Ranges, ",") + for _, rangeStr := range ranges { + rangeStr = strings.TrimSpace(rangeStr) + if rangeStr == "" { + continue + } + + _, ipNet, err := net.ParseCIDR(rangeStr) + if err != nil { + return fmt.Errorf("invalid LOCAL_IPV6_RANGES '%s': %w", rangeStr, err) + } + + if ipNet.IP.To4() != nil { + return fmt.Errorf("range '%s' is not a valid IPv6 range", rangeStr) + } + + } + return nil } diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go index 9afcb6f3..110cc585 100644 --- a/backend/internal/dto/oidc_dto.go +++ b/backend/internal/dto/oidc_dto.go @@ -38,6 +38,8 @@ type OidcClientUpdateDto struct { RequiresReauthentication bool `json:"requiresReauthentication"` Credentials OidcClientCredentialsDto `json:"credentials"` LaunchURL *string `json:"launchURL" binding:"omitempty,url"` + HasLogo bool `json:"hasLogo"` + LogoURL *string `json:"logoUrl"` } type OidcClientCreateDto struct { diff --git a/backend/internal/model/oidc.go b/backend/internal/model/oidc.go index 3521e7cf..d118ee15 100644 --- a/backend/internal/model/oidc.go +++ b/backend/internal/model/oidc.go @@ -6,8 +6,6 @@ import ( "fmt" "strings" - "gorm.io/gorm" - datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" ) @@ -54,7 +52,6 @@ type OidcClient struct { CallbackURLs UrlList LogoutCallbackURLs UrlList ImageType *string - HasLogo bool `gorm:"-"` IsPublic bool PkceEnabled bool RequiresReauthentication bool @@ -67,6 +64,10 @@ type OidcClient struct { UserAuthorizedOidcClients []UserAuthorizedOidcClient `gorm:"foreignKey:ClientID;references:ID"` } +func (c OidcClient) HasLogo() bool { + return c.ImageType != nil && *c.ImageType != "" +} + type OidcRefreshToken struct { Base @@ -89,12 +90,6 @@ func (c OidcRefreshToken) Scopes() []string { return strings.Split(c.Scope, " ") } -func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) { - // Compute HasLogo field - c.HasLogo = c.ImageType != nil && *c.ImageType != "" - return nil -} - type OidcClientCredentials struct { //nolint:recvcheck FederatedIdentities []OidcClientFederatedIdentity `json:"federatedIdentities,omitempty"` } diff --git a/backend/internal/service/geolite_service.go b/backend/internal/service/geolite_service.go index 3ce847e4..0d5baa58 100644 --- a/backend/internal/service/geolite_service.go +++ b/backend/internal/service/geolite_service.go @@ -13,35 +13,19 @@ import ( "net/netip" "os" "path/filepath" - "strings" "sync" "time" "github.com/oschwald/maxminddb-golang/v2" + "github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/common" ) type GeoLiteService struct { - httpClient *http.Client - disableUpdater bool - mutex sync.RWMutex - localIPv6Ranges []*net.IPNet -} - -var localhostIPNets = []*net.IPNet{ - {IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8 - {IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)}, // ::1/128 -} - -var privateLanIPNets = []*net.IPNet{ - {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8 - {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12 - {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16 -} - -var tailscaleIPNets = []*net.IPNet{ - {IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)}, // 100.64.0.0/10 + httpClient *http.Client + disableUpdater bool + mutex sync.RWMutex } // NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database. @@ -56,67 +40,9 @@ func NewGeoLiteService(httpClient *http.Client) *GeoLiteService { service.disableUpdater = true } - // Initialize IPv6 local ranges - err := service.initializeIPv6LocalRanges() - if err != nil { - slog.Warn("Failed to initialize IPv6 local ranges", slog.Any("error", err)) - } - return service } -// initializeIPv6LocalRanges parses the LOCAL_IPV6_RANGES environment variable -func (s *GeoLiteService) initializeIPv6LocalRanges() error { - rangesEnv := common.EnvConfig.LocalIPv6Ranges - if rangesEnv == "" { - return nil // No local IPv6 ranges configured - } - - ranges := strings.Split(rangesEnv, ",") - localRanges := make([]*net.IPNet, 0, len(ranges)) - - for _, rangeStr := range ranges { - rangeStr = strings.TrimSpace(rangeStr) - if rangeStr == "" { - continue - } - - _, ipNet, err := net.ParseCIDR(rangeStr) - if err != nil { - return fmt.Errorf("invalid IPv6 range '%s': %w", rangeStr, err) - } - - // Ensure it's an IPv6 range - if ipNet.IP.To4() != nil { - return fmt.Errorf("range '%s' is not a valid IPv6 range", rangeStr) - } - - localRanges = append(localRanges, ipNet) - } - - s.localIPv6Ranges = localRanges - - if len(localRanges) > 0 { - slog.Info("Initialized IPv6 local ranges", slog.Int("count", len(localRanges))) - } - return nil -} - -// isLocalIPv6 checks if the given IPv6 address is within any of the configured local ranges -func (s *GeoLiteService) isLocalIPv6(ip net.IP) bool { - if ip.To4() != nil { - return false // Not an IPv6 address - } - - for _, localRange := range s.localIPv6Ranges { - if localRange.Contains(ip) { - return true - } - } - - return false -} - func (s *GeoLiteService) DisableUpdater() bool { return s.disableUpdater } @@ -129,26 +55,17 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string // Check the IP address against known private IP ranges if ip := net.ParseIP(ipAddress); ip != nil { - // Check IPv6 local ranges first - if s.isLocalIPv6(ip) { + if utils.IsLocalIPv6(ip) { return "Internal Network", "LAN", nil } - - // Check existing IPv4 ranges - for _, ipNet := range tailscaleIPNets { - if ipNet.Contains(ip) { - return "Internal Network", "Tailscale", nil - } + if utils.IsTailscaleIP(ip) { + return "Internal Network", "Tailscale", nil } - for _, ipNet := range privateLanIPNets { - if ipNet.Contains(ip) { - return "Internal Network", "LAN", nil - } + if utils.IsPrivateIP(ip) { + return "Internal Network", "LAN", nil } - for _, ipNet := range localhostIPNets { - if ipNet.Contains(ip) { - return "Internal Network", "localhost", nil - } + if utils.IsLocalhostIP(ip) { + return "Internal Network", "localhost", nil } } diff --git a/backend/internal/service/geolite_service_test.go b/backend/internal/service/geolite_service_test.go deleted file mode 100644 index 638c7721..00000000 --- a/backend/internal/service/geolite_service_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package service - -import ( - "net" - "net/http" - "testing" - - "github.com/pocket-id/pocket-id/backend/internal/common" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestGeoLiteService_IPv6LocalRanges(t *testing.T) { - tests := []struct { - name string - localRanges string - testIP string - expectedCountry string - expectedCity string - expectError bool - }{ - { - name: "IPv6 in local range", - localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56", - testIP: "2001:0db8:abcd:000::1", - expectedCountry: "Internal Network", - expectedCity: "LAN", - expectError: false, - }, - { - name: "IPv6 not in local range", - localRanges: "2001:0db8:abcd:000::/56", - testIP: "2001:0db8:ffff:000::1", - expectError: true, - }, - { - name: "Multiple ranges - second range match", - localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56", - testIP: "2001:0db8:abcd:001::1", - expectedCountry: "Internal Network", - expectedCity: "LAN", - expectError: false, - }, - { - name: "Empty local ranges", - localRanges: "", - testIP: "2001:0db8:abcd:000::1", - expectError: true, - }, - { - name: "IPv4 private address still works", - localRanges: "2001:0db8:abcd:000::/56", - testIP: "192.168.1.1", - expectedCountry: "Internal Network", - expectedCity: "LAN", - expectError: false, - }, - { - name: "IPv6 loopback", - localRanges: "2001:0db8:abcd:000::/56", - testIP: "::1", - expectedCountry: "Internal Network", - expectedCity: "localhost", - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - originalConfig := common.EnvConfig.LocalIPv6Ranges - common.EnvConfig.LocalIPv6Ranges = tt.localRanges - defer func() { - common.EnvConfig.LocalIPv6Ranges = originalConfig - }() - - service := NewGeoLiteService(&http.Client{}) - - country, city, err := service.GetLocationByIP(tt.testIP) - - if tt.expectError { - if err == nil && country != "Internal Network" { - t.Errorf("Expected error or internal network classification for external IP") - } - } else { - require.NoError(t, err) - assert.Equal(t, tt.expectedCountry, country) - assert.Equal(t, tt.expectedCity, city) - } - }) - } -} - -func TestGeoLiteService_isLocalIPv6(t *testing.T) { - tests := []struct { - name string - localRanges string - testIP string - expected bool - }{ - { - name: "Valid IPv6 in range", - localRanges: "2001:0db8:abcd:000::/56", - testIP: "2001:0db8:abcd:000::1", - expected: true, - }, - { - name: "Valid IPv6 not in range", - localRanges: "2001:0db8:abcd:000::/56", - testIP: "2001:0db8:ffff:000::1", - expected: false, - }, - { - name: "IPv4 address should return false", - localRanges: "2001:0db8:abcd:000::/56", - testIP: "192.168.1.1", - expected: false, - }, - { - name: "No ranges configured", - localRanges: "", - testIP: "2001:0db8:abcd:000::1", - expected: false, - }, - { - name: "Edge of range", - localRanges: "2001:0db8:abcd:000::/56", - testIP: "2001:0db8:abcd:00ff:ffff:ffff:ffff:ffff", - expected: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - originalConfig := common.EnvConfig.LocalIPv6Ranges - common.EnvConfig.LocalIPv6Ranges = tt.localRanges - defer func() { - common.EnvConfig.LocalIPv6Ranges = originalConfig - }() - - service := NewGeoLiteService(&http.Client{}) - ip := net.ParseIP(tt.testIP) - if ip == nil { - t.Fatalf("Invalid test IP: %s", tt.testIP) - } - - result := service.isLocalIPv6(ip) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestGeoLiteService_initializeIPv6LocalRanges(t *testing.T) { - tests := []struct { - name string - envValue string - expectError bool - expectCount int - }{ - { - name: "Valid IPv6 ranges", - envValue: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56", - expectError: false, - expectCount: 2, - }, - { - name: "Empty environment variable", - envValue: "", - expectError: false, - expectCount: 0, - }, - { - name: "Invalid CIDR notation", - envValue: "2001:0db8:abcd:000::/999", - expectError: true, - expectCount: 0, - }, - { - name: "IPv4 range in IPv6 env var", - envValue: "192.168.1.0/24", - expectError: true, - expectCount: 0, - }, - { - name: "Mixed valid and invalid ranges", - envValue: "2001:0db8:abcd:000::/56,invalid-range", - expectError: true, - expectCount: 0, - }, - { - name: "Whitespace handling", - envValue: " 2001:0db8:abcd:000::/56 , 2001:0db8:abcd:001::/56 ", - expectError: false, - expectCount: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - originalConfig := common.EnvConfig.LocalIPv6Ranges - common.EnvConfig.LocalIPv6Ranges = tt.envValue - defer func() { - common.EnvConfig.LocalIPv6Ranges = originalConfig - }() - - service := &GeoLiteService{ - httpClient: &http.Client{}, - } - - err := service.initializeIPv6LocalRanges() - - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - } - - assert.Len(t, service.localIPv6Ranges, tt.expectCount) - }) - } -} diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index d085edfd..5d21b016 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -8,9 +8,12 @@ import ( "encoding/json" "errors" "fmt" + "io" "log/slog" "mime/multipart" + "net" "net/http" + "net/url" "os" "regexp" "slices" @@ -66,6 +69,7 @@ func NewOidcService( auditLogService *AuditLogService, customClaimService *CustomClaimService, webAuthnService *WebAuthnService, + httpClient *http.Client, ) (s *OidcService, err error) { s = &OidcService{ db: db, @@ -74,6 +78,7 @@ func NewOidcService( auditLogService: auditLogService, customClaimService: customClaimService, webAuthnService: webAuthnService, + httpClient: httpClient, } // Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace @@ -714,6 +719,11 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina } func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) { + tx := s.db.Begin() + defer func() { + tx.Rollback() + }() + client := model.OidcClient{ Base: model.Base{ ID: input.ID, @@ -722,7 +732,7 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea } updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto) - err := s.db. + err := tx. WithContext(ctx). Create(&client). Error @@ -733,33 +743,11 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea return model.OidcClient{}, err } - return client, nil -} - -func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) { - tx := s.db.Begin() - defer func() { - tx.Rollback() - }() - - var client model.OidcClient - err := tx. - WithContext(ctx). - Preload("CreatedBy"). - First(&client, "id = ?", clientID). - Error - if err != nil { - return model.OidcClient{}, err - } - - updateOIDCClientModelFromDto(&client, &input) - - err = tx. - WithContext(ctx). - Save(&client). - Error - if err != nil { - return model.OidcClient{}, err + if input.LogoURL != nil { + err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL) + if err != nil { + return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err) + } } err = tx.Commit().Error @@ -770,6 +758,36 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d return client, nil } +func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) { + tx := s.db.Begin() + defer func() { tx.Rollback() }() + + var client model.OidcClient + if err := tx.WithContext(ctx). + Preload("CreatedBy"). + First(&client, "id = ?", clientID).Error; err != nil { + return model.OidcClient{}, err + } + + updateOIDCClientModelFromDto(&client, &input) + + if err := tx.WithContext(ctx).Save(&client).Error; err != nil { + return model.OidcClient{}, err + } + + if input.LogoURL != nil { + err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL) + if err != nil { + return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err) + } + } + + if err := tx.Commit().Error; err != nil { + return model.OidcClient{}, err + } + return client, nil +} + func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientUpdateDto) { // Base fields client.Name = input.Name @@ -883,41 +901,14 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil } tx := s.db.Begin() - defer func() { + + err = s.updateClientLogoType(ctx, tx, clientID, fileType) + if err != nil { tx.Rollback() - }() - - var client model.OidcClient - err = tx. - WithContext(ctx). - First(&client, "id = ?", clientID). - Error - if err != nil { return err } - if client.ImageType != nil && fileType != *client.ImageType { - oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType) - if err := os.Remove(oldImagePath); err != nil { - return err - } - } - - client.ImageType = &fileType - err = tx. - WithContext(ctx). - Save(&client). - Error - if err != nil { - return err - } - - err = tx.Commit().Error - if err != nil { - return err - } - - return nil + return tx.Commit().Error } func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error { @@ -941,6 +932,7 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err oldImageType := *client.ImageType client.ImageType = nil + err = tx. WithContext(ctx). Save(&client). @@ -1333,7 +1325,7 @@ func (s *OidcService) GetDeviceCodeInfo(ctx context.Context, userCode string, us Client: dto.OidcClientMetaDataDto{ ID: deviceAuth.Client.ID, Name: deviceAuth.Client.Name, - HasLogo: deviceAuth.Client.HasLogo, + HasLogo: deviceAuth.Client.HasLogo(), }, Scope: deviceAuth.Scope, AuthorizationRequired: !hasAuthorizedClient, @@ -1468,7 +1460,7 @@ func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID stri ID: client.ID, Name: client.Name, LaunchURL: client.LaunchURL, - HasLogo: client.HasLogo, + HasLogo: client.HasLogo(), }, LastUsedAt: lastUsedAt, } @@ -1889,3 +1881,87 @@ func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID str return s.IsUserGroupAllowedToAuthorize(user, client), nil } + +func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *gorm.DB, clientID string, raw string) error { + u, err := url.Parse(raw) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second) + defer cancel() + + r := net.Resolver{} + ips, err := r.LookupIPAddr(ctx, u.Hostname()) + if err != nil || len(ips) == 0 { + return fmt.Errorf("cannot resolve hostname") + } + + // Prevents SSRF by allowing only public IPs + for _, addr := range ips { + if utils.IsPrivateIP(addr.IP) { + return fmt.Errorf("private IP addresses are not allowed") + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil) + if err != nil { + return err + } + req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher") + req.Header.Set("Accept", "image/*") + + resp, err := s.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to fetch logo: %s", resp.Status) + } + + const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB + if resp.ContentLength > maxLogoSize { + return fmt.Errorf("logo is too large") + } + + // Prefer extension in path if supported + ext := utils.GetFileExtension(u.Path) + if ext == "" || utils.GetImageMimeType(ext) == "" { + // Otherwise, try to detect from content type + ext = utils.GetImageExtensionFromMimeType(resp.Header.Get("Content-Type")) + } + + if ext == "" { + return &common.FileTypeNotSupportedError{} + } + + imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + "." + ext + err = utils.SaveFileStream(io.LimitReader(resp.Body, maxLogoSize+1), imagePath) + if err != nil { + return err + } + + if err := s.updateClientLogoType(ctx, tx, clientID, ext); err != nil { + return err + } + + return nil +} + +func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string) error { + uploadsDir := common.EnvConfig.UploadPath + "/oidc-client-images" + + var client model.OidcClient + if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil { + return err + } + if client.ImageType != nil && *client.ImageType != ext { + old := fmt.Sprintf("%s/%s.%s", uploadsDir, client.ID, *client.ImageType) + _ = os.Remove(old) + } + client.ImageType = &ext + return tx.WithContext(ctx).Save(&client).Error + +} diff --git a/backend/internal/utils/file_util.go b/backend/internal/utils/file_util.go index 8216e012..89bcc9a0 100644 --- a/backend/internal/utils/file_util.go +++ b/backend/internal/utils/file_util.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "mime" "mime/multipart" "os" "path/filepath" @@ -57,6 +58,34 @@ func GetImageMimeType(ext string) string { } } +func GetImageExtensionFromMimeType(mimeType string) string { + // Normalize and strip parameters like `; charset=utf-8` + mt := strings.TrimSpace(strings.ToLower(mimeType)) + if v, _, err := mime.ParseMediaType(mt); err == nil { + mt = v + } + switch mt { + case "image/jpeg", "image/jpg": + return "jpg" + case "image/png": + return "png" + case "image/svg+xml": + return "svg" + case "image/x-icon", "image/vnd.microsoft.icon": + return "ico" + case "image/gif": + return "gif" + case "image/webp": + return "webp" + case "image/avif": + return "avif" + case "image/heic", "image/heif": + return "heic" + default: + return "" + } +} + func CopyEmbeddedFileToDisk(srcFilePath, destFilePath string) error { srcFile, err := resources.FS.Open(srcFilePath) if err != nil { diff --git a/backend/internal/utils/ip_util.go b/backend/internal/utils/ip_util.go new file mode 100644 index 00000000..9832046b --- /dev/null +++ b/backend/internal/utils/ip_util.go @@ -0,0 +1,87 @@ +package utils + +import ( + "net" + "strings" + + "github.com/pocket-id/pocket-id/backend/internal/common" +) + +var localIPv6Ranges []*net.IPNet + +var localhostIPNets = []*net.IPNet{ + {IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8 + {IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)}, // ::1/128 +} + +var privateLanIPNets = []*net.IPNet{ + {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8 + {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12 + {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16 +} + +var tailscaleIPNets = []*net.IPNet{ + {IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)}, // 100.64.0.0/10 +} + +func IsLocalIPv6(ip net.IP) bool { + if ip.To4() != nil { + return false + } + + return listContainsIP(localIPv6Ranges, ip) +} + +func IsLocalhostIP(ip net.IP) bool { + return listContainsIP(localhostIPNets, ip) +} + +func IsPrivateLanIP(ip net.IP) bool { + if ip.To4() == nil { + return false + } + + return listContainsIP(privateLanIPNets, ip) +} + +func IsTailscaleIP(ip net.IP) bool { + if ip.To4() == nil { + return false + } + + return listContainsIP(tailscaleIPNets, ip) +} + +func IsPrivateIP(ip net.IP) bool { + return IsLocalhostIP(ip) || IsPrivateLanIP(ip) || IsTailscaleIP(ip) || IsLocalIPv6(ip) +} + +func listContainsIP(ipNets []*net.IPNet, ip net.IP) bool { + for _, ipNet := range ipNets { + if ipNet.Contains(ip) { + return true + } + } + return false +} + +func loadLocalIPv6Ranges() { + localIPv6Ranges = nil + ranges := strings.Split(common.EnvConfig.LocalIPv6Ranges, ",") + + for _, rangeStr := range ranges { + rangeStr = strings.TrimSpace(rangeStr) + if rangeStr == "" { + continue + } + + _, ipNet, err := net.ParseCIDR(rangeStr) + if err == nil { + localIPv6Ranges = append(localIPv6Ranges, ipNet) + } + } +} + +func init() { + loadLocalIPv6Ranges() +} diff --git a/backend/internal/utils/ip_util_test.go b/backend/internal/utils/ip_util_test.go new file mode 100644 index 00000000..01c7bf68 --- /dev/null +++ b/backend/internal/utils/ip_util_test.go @@ -0,0 +1,159 @@ +package utils + +import ( + "net" + "testing" + + "github.com/pocket-id/pocket-id/backend/internal/common" +) + +func TestIsLocalhostIP(t *testing.T) { + tests := []struct { + ip string + expected bool + }{ + {"127.0.0.1", true}, + {"127.255.255.255", true}, + {"::1", true}, + {"192.168.1.1", false}, + } + + for _, tt := range tests { + ip := net.ParseIP(tt.ip) + if got := IsLocalhostIP(ip); got != tt.expected { + t.Errorf("IsLocalhostIP(%s) = %v, want %v", tt.ip, got, tt.expected) + } + } +} + +func TestIsPrivateLanIP(t *testing.T) { + tests := []struct { + ip string + expected bool + }{ + {"10.0.0.1", true}, + {"172.16.5.4", true}, + {"192.168.100.200", true}, + {"8.8.8.8", false}, + {"::1", false}, // IPv6 should return false + } + + for _, tt := range tests { + ip := net.ParseIP(tt.ip) + if got := IsPrivateLanIP(ip); got != tt.expected { + t.Errorf("IsPrivateLanIP(%s) = %v, want %v", tt.ip, got, tt.expected) + } + } +} + +func TestIsTailscaleIP(t *testing.T) { + tests := []struct { + ip string + expected bool + }{ + {"100.64.0.1", true}, + {"100.127.255.254", true}, + {"8.8.8.8", false}, + {"::1", false}, // IPv6 should return false + } + + for _, tt := range tests { + ip := net.ParseIP(tt.ip) + if got := IsTailscaleIP(ip); got != tt.expected { + t.Errorf("IsTailscaleIP(%s) = %v, want %v", tt.ip, got, tt.expected) + } + } +} + +func TestIsLocalIPv6(t *testing.T) { + // Save and restore env config + origRanges := common.EnvConfig.LocalIPv6Ranges + defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }() + + common.EnvConfig.LocalIPv6Ranges = "fd00::/8,fc00::/7" + localIPv6Ranges = nil // reset + loadLocalIPv6Ranges() + + tests := []struct { + ip string + expected bool + }{ + {"fd00::1", true}, + {"fc00::abcd", true}, + {"::1", false}, // loopback handled separately + {"192.168.1.1", false}, // IPv4 should return false + } + + for _, tt := range tests { + ip := net.ParseIP(tt.ip) + if got := IsLocalIPv6(ip); got != tt.expected { + t.Errorf("IsLocalIPv6(%s) = %v, want %v", tt.ip, got, tt.expected) + } + } +} + +func TestIsPrivateIP(t *testing.T) { + // Save and restore env config + origRanges := common.EnvConfig.LocalIPv6Ranges + defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }() + + common.EnvConfig.LocalIPv6Ranges = "fd00::/8" + localIPv6Ranges = nil // reset + loadLocalIPv6Ranges() + + tests := []struct { + ip string + expected bool + }{ + {"127.0.0.1", true}, // localhost + {"192.168.1.1", true}, // private LAN + {"100.64.0.1", true}, // Tailscale + {"fd00::1", true}, // local IPv6 + {"8.8.8.8", false}, // public IPv4 + {"2001:4860:4860::8888", false}, // public IPv6 + } + + for _, tt := range tests { + ip := net.ParseIP(tt.ip) + if got := IsPrivateIP(ip); got != tt.expected { + t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, got, tt.expected) + } + } +} + +func TestListContainsIP(t *testing.T) { + _, ipNet1, _ := net.ParseCIDR("10.0.0.0/8") + _, ipNet2, _ := net.ParseCIDR("192.168.0.0/16") + + list := []*net.IPNet{ipNet1, ipNet2} + + tests := []struct { + ip string + expected bool + }{ + {"10.1.1.1", true}, + {"192.168.5.5", true}, + {"172.16.0.1", false}, + } + + for _, tt := range tests { + ip := net.ParseIP(tt.ip) + if got := listContainsIP(list, ip); got != tt.expected { + t.Errorf("listContainsIP(%s) = %v, want %v", tt.ip, got, tt.expected) + } + } +} + +func TestInit_LocalIPv6Ranges(t *testing.T) { + // Save and restore env config + origRanges := common.EnvConfig.LocalIPv6Ranges + defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }() + + common.EnvConfig.LocalIPv6Ranges = "fd00::/8, invalidCIDR ,fc00::/7" + localIPv6Ranges = nil + loadLocalIPv6Ranges() + + if len(localIPv6Ranges) != 2 { + t.Errorf("expected 2 valid IPv6 ranges, got %d", len(localIPv6Ranges)) + } +} diff --git a/frontend/messages/en.json b/frontend/messages/en.json index 8dec4a83..f4a8d2df 100644 --- a/frontend/messages/en.json +++ b/frontend/messages/en.json @@ -450,5 +450,7 @@ "display_name": "Display Name", "configure_application_images": "Configure Application Images", "ui_config_disabled_info_title": "UI Configuration Disabled", - "ui_config_disabled_info_description": "The UI configuration is disabled because the application configuration settings are managed through environment variables. Some settings may not be editable." + "ui_config_disabled_info_description": "The UI configuration is disabled because the application configuration settings are managed through environment variables. Some settings may not be editable.", + "logo_from_url_description": "Paste a direct image URL (svg, png, webp). Find icons at Selfh.st Icons or Dashboard Icons.", + "invalid_url": "Invalid URL" } diff --git a/frontend/src/lib/components/form/url-file-input.svelte b/frontend/src/lib/components/form/url-file-input.svelte new file mode 100644 index 00000000..05aed37b --- /dev/null +++ b/frontend/src/lib/components/form/url-file-input.svelte @@ -0,0 +1,85 @@ + + +
+ + + + + + + (url = e.currentTarget.value)} + onfocusout={handleUrlChange} + aria-invalid={hasError} + /> + {#if hasError} +

{m.invalid_url()}

+ {/if} + +

+ +

+
+
+
diff --git a/frontend/src/lib/components/image-box.svelte b/frontend/src/lib/components/image-box.svelte index 29cae2d3..862ac6fd 100644 --- a/frontend/src/lib/components/image-box.svelte +++ b/frontend/src/lib/components/image-box.svelte @@ -1,10 +1,25 @@
- + {#if error} + + {:else} + (error = true)} + /> + {/if}
diff --git a/frontend/src/lib/types/oidc.type.ts b/frontend/src/lib/types/oidc.type.ts index f46a1f36..18087d88 100644 --- a/frontend/src/lib/types/oidc.type.ts +++ b/frontend/src/lib/types/oidc.type.ts @@ -46,7 +46,8 @@ export type OidcClientUpdateWithLogo = OidcClientUpdate & { }; export type OidcClientCreateWithLogo = OidcClientCreate & { - logo: File | null | undefined; + logo?: File | null; + logoUrl?: string; }; export type OidcDeviceCodeInfo = { diff --git a/frontend/src/routes/settings/admin/application-configuration/forms/app-config-general-form.svelte b/frontend/src/routes/settings/admin/application-configuration/forms/app-config-general-form.svelte index 060aebec..8f0738a4 100644 --- a/frontend/src/routes/settings/admin/application-configuration/forms/app-config-general-form.svelte +++ b/frontend/src/routes/settings/admin/application-configuration/forms/app-config-general-form.svelte @@ -41,7 +41,7 @@ accentColor: z.string() }); - let { inputs, ...form } = $derived(createForm(formSchema, appConfig)); + let { inputs, ...form } = $derived(createForm(formSchema, updatedAppConfig)); async function onSubmit() { const data = form.validate(); @@ -69,7 +69,6 @@ description={m.whether_the_users_should_be_able_to_edit_their_own_account_details()} bind:checked={$inputs.allowOwnAccountEdit.value} /> - - import FileInput from '$lib/components/form/file-input.svelte'; import FormInput from '$lib/components/form/form-input.svelte'; import SwitchWithLabel from '$lib/components/form/switch-with-label.svelte'; - import ImageBox from '$lib/components/image-box.svelte'; import { Button } from '$lib/components/ui/button'; - import Label from '$lib/components/ui/label/label.svelte'; import { m } from '$lib/paraglide/messages'; import type { OidcClient, @@ -21,6 +18,7 @@ import { z } from 'zod/v4'; import FederatedIdentitiesInput from './federated-identities-input.svelte'; import OidcCallbackUrlInput from './oidc-callback-url-input.svelte'; + import OidcClientImageInput from './oidc-client-image-input.svelte'; let { callback, @@ -31,7 +29,6 @@ callback: (client: OidcClientCreateWithLogo | OidcClientUpdateWithLogo) => Promise; mode: 'create' | 'update'; } = $props(); - let isLoading = $state(false); let showAdvancedOptions = $state(false); let logo = $state(); @@ -50,7 +47,8 @@ launchURL: existingClient?.launchURL || '', credentials: { federatedIdentities: existingClient?.credentials?.federatedIdentities || [] - } + }, + logoUrl: '' }; const formSchema = z.object({ @@ -71,6 +69,7 @@ pkceEnabled: z.boolean(), requiresReauthentication: z.boolean(), launchURL: optionalUrl, + logoUrl: optionalUrl, credentials: z.object({ federatedIdentities: z.array( z.object({ @@ -90,30 +89,42 @@ const data = form.validate(); if (!data) return; isLoading = true; + const success = await callback({ ...data, - logo + logo: $inputs.logoUrl?.value ? null : logo, + logoUrl: $inputs.logoUrl?.value }); - // Reset form if client was successfully created + + const hasLogo = logo != null || !!$inputs.logoUrl?.value; + if (success && existingClient && hasLogo) { + logoDataURL = cachedOidcClientLogo.getUrl(existingClient.id); + } + if (success && !existingClient) form.reset(); isLoading = false; } - function onLogoChange(e: Event) { - const file = (e.target as HTMLInputElement).files?.[0] || null; - if (file) { - logo = file; + function onLogoChange(input: File | string | null) { + if (input == null) return; + + if (typeof input === 'string') { + logo = null; + logoDataURL = input || null; + $inputs.logoUrl!.value = input; + } else { + logo = input; + $inputs.logoUrl && ($inputs.logoUrl.value = ''); const reader = new FileReader(); - reader.onload = (event) => { - logoDataURL = event.target?.result as string; - }; - reader.readAsDataURL(file); + reader.onload = (event) => (logoDataURL = event.target?.result as string); + reader.readAsDataURL(input); } } function resetLogo() { logo = null; logoDataURL = null; + $inputs.logoUrl && ($inputs.logoUrl.value = ''); } function getFederatedIdentityErrors(errors: z.ZodError | undefined) { @@ -173,32 +184,13 @@ bind:checked={$inputs.requiresReauthentication.value} /> -
- -
- {#if logoDataURL} - - {/if} -
- - {#if logoDataURL} - - {/if} -
-
+
+
{#if showAdvancedOptions} diff --git a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-image-input.svelte b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-image-input.svelte new file mode 100644 index 00000000..61201d3a --- /dev/null +++ b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-image-input.svelte @@ -0,0 +1,44 @@ + + + +
+ {#if logoDataURL} +
+
+ + +
+
+ {/if} +
+
+ +
+
+
diff --git a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-list.svelte b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-list.svelte index 45bc2468..59353cc3 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-list.svelte +++ b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-list.svelte @@ -71,16 +71,21 @@ ? item.allowedUserGroupsCount : m.unrestricted()} - - - + +
+ + +
{/snippet} diff --git a/frontend/src/routes/settings/apps/authorized-oidc-client-card.svelte b/frontend/src/routes/settings/apps/authorized-oidc-client-card.svelte index fdc7483f..198156a8 100644 --- a/frontend/src/routes/settings/apps/authorized-oidc-client-card.svelte +++ b/frontend/src/routes/settings/apps/authorized-oidc-client-card.svelte @@ -38,7 +38,7 @@