mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-20 17:25:43 +03:00
feat: support for url based icons (#840)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
@@ -56,7 +56,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
|
||||
return nil, fmt.Errorf("failed to create WebAuthn service: %w", err)
|
||||
}
|
||||
|
||||
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService)
|
||||
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService, httpClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
@@ -180,6 +181,25 @@ func validateEnvConfig(config *EnvConfigSchema) error {
|
||||
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", config.KeysStorage)
|
||||
}
|
||||
|
||||
// Validate LOCAL_IPV6_RANGES
|
||||
ranges := strings.Split(config.LocalIPv6Ranges, ",")
|
||||
for _, rangeStr := range ranges {
|
||||
rangeStr = strings.TrimSpace(rangeStr)
|
||||
if rangeStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
_, ipNet, err := net.ParseCIDR(rangeStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCAL_IPV6_RANGES '%s': %w", rangeStr, err)
|
||||
}
|
||||
|
||||
if ipNet.IP.To4() != nil {
|
||||
return fmt.Errorf("range '%s' is not a valid IPv6 range", rangeStr)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
@@ -38,6 +38,8 @@ type OidcClientUpdateDto struct {
|
||||
RequiresReauthentication bool `json:"requiresReauthentication"`
|
||||
Credentials OidcClientCredentialsDto `json:"credentials"`
|
||||
LaunchURL *string `json:"launchURL" binding:"omitempty,url"`
|
||||
HasLogo bool `json:"hasLogo"`
|
||||
LogoURL *string `json:"logoUrl"`
|
||||
}
|
||||
|
||||
type OidcClientCreateDto struct {
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
)
|
||||
|
||||
@@ -54,7 +52,6 @@ type OidcClient struct {
|
||||
CallbackURLs UrlList
|
||||
LogoutCallbackURLs UrlList
|
||||
ImageType *string
|
||||
HasLogo bool `gorm:"-"`
|
||||
IsPublic bool
|
||||
PkceEnabled bool
|
||||
RequiresReauthentication bool
|
||||
@@ -67,6 +64,10 @@ type OidcClient struct {
|
||||
UserAuthorizedOidcClients []UserAuthorizedOidcClient `gorm:"foreignKey:ClientID;references:ID"`
|
||||
}
|
||||
|
||||
func (c OidcClient) HasLogo() bool {
|
||||
return c.ImageType != nil && *c.ImageType != ""
|
||||
}
|
||||
|
||||
type OidcRefreshToken struct {
|
||||
Base
|
||||
|
||||
@@ -89,12 +90,6 @@ func (c OidcRefreshToken) Scopes() []string {
|
||||
return strings.Split(c.Scope, " ")
|
||||
}
|
||||
|
||||
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
|
||||
// Compute HasLogo field
|
||||
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
|
||||
return nil
|
||||
}
|
||||
|
||||
type OidcClientCredentials struct { //nolint:recvcheck
|
||||
FederatedIdentities []OidcClientFederatedIdentity `json:"federatedIdentities,omitempty"`
|
||||
}
|
||||
|
||||
@@ -13,11 +13,11 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang/v2"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
@@ -26,22 +26,6 @@ type GeoLiteService struct {
|
||||
httpClient *http.Client
|
||||
disableUpdater bool
|
||||
mutex sync.RWMutex
|
||||
localIPv6Ranges []*net.IPNet
|
||||
}
|
||||
|
||||
var localhostIPNets = []*net.IPNet{
|
||||
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
|
||||
{IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)}, // ::1/128
|
||||
}
|
||||
|
||||
var privateLanIPNets = []*net.IPNet{
|
||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
|
||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
|
||||
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
|
||||
}
|
||||
|
||||
var tailscaleIPNets = []*net.IPNet{
|
||||
{IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)}, // 100.64.0.0/10
|
||||
}
|
||||
|
||||
// NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database.
|
||||
@@ -56,67 +40,9 @@ func NewGeoLiteService(httpClient *http.Client) *GeoLiteService {
|
||||
service.disableUpdater = true
|
||||
}
|
||||
|
||||
// Initialize IPv6 local ranges
|
||||
err := service.initializeIPv6LocalRanges()
|
||||
if err != nil {
|
||||
slog.Warn("Failed to initialize IPv6 local ranges", slog.Any("error", err))
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
// initializeIPv6LocalRanges parses the LOCAL_IPV6_RANGES environment variable
|
||||
func (s *GeoLiteService) initializeIPv6LocalRanges() error {
|
||||
rangesEnv := common.EnvConfig.LocalIPv6Ranges
|
||||
if rangesEnv == "" {
|
||||
return nil // No local IPv6 ranges configured
|
||||
}
|
||||
|
||||
ranges := strings.Split(rangesEnv, ",")
|
||||
localRanges := make([]*net.IPNet, 0, len(ranges))
|
||||
|
||||
for _, rangeStr := range ranges {
|
||||
rangeStr = strings.TrimSpace(rangeStr)
|
||||
if rangeStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
_, ipNet, err := net.ParseCIDR(rangeStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid IPv6 range '%s': %w", rangeStr, err)
|
||||
}
|
||||
|
||||
// Ensure it's an IPv6 range
|
||||
if ipNet.IP.To4() != nil {
|
||||
return fmt.Errorf("range '%s' is not a valid IPv6 range", rangeStr)
|
||||
}
|
||||
|
||||
localRanges = append(localRanges, ipNet)
|
||||
}
|
||||
|
||||
s.localIPv6Ranges = localRanges
|
||||
|
||||
if len(localRanges) > 0 {
|
||||
slog.Info("Initialized IPv6 local ranges", slog.Int("count", len(localRanges)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isLocalIPv6 checks if the given IPv6 address is within any of the configured local ranges
|
||||
func (s *GeoLiteService) isLocalIPv6(ip net.IP) bool {
|
||||
if ip.To4() != nil {
|
||||
return false // Not an IPv6 address
|
||||
}
|
||||
|
||||
for _, localRange := range s.localIPv6Ranges {
|
||||
if localRange.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *GeoLiteService) DisableUpdater() bool {
|
||||
return s.disableUpdater
|
||||
}
|
||||
@@ -129,28 +55,19 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
|
||||
|
||||
// Check the IP address against known private IP ranges
|
||||
if ip := net.ParseIP(ipAddress); ip != nil {
|
||||
// Check IPv6 local ranges first
|
||||
if s.isLocalIPv6(ip) {
|
||||
if utils.IsLocalIPv6(ip) {
|
||||
return "Internal Network", "LAN", nil
|
||||
}
|
||||
|
||||
// Check existing IPv4 ranges
|
||||
for _, ipNet := range tailscaleIPNets {
|
||||
if ipNet.Contains(ip) {
|
||||
if utils.IsTailscaleIP(ip) {
|
||||
return "Internal Network", "Tailscale", nil
|
||||
}
|
||||
}
|
||||
for _, ipNet := range privateLanIPNets {
|
||||
if ipNet.Contains(ip) {
|
||||
if utils.IsPrivateIP(ip) {
|
||||
return "Internal Network", "LAN", nil
|
||||
}
|
||||
}
|
||||
for _, ipNet := range localhostIPNets {
|
||||
if ipNet.Contains(ip) {
|
||||
if utils.IsLocalhostIP(ip) {
|
||||
return "Internal Network", "localhost", nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(ipAddress)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,220 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeoLiteService_IPv6LocalRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
localRanges string
|
||||
testIP string
|
||||
expectedCountry string
|
||||
expectedCity string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "IPv6 in local range",
|
||||
localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
|
||||
testIP: "2001:0db8:abcd:000::1",
|
||||
expectedCountry: "Internal Network",
|
||||
expectedCity: "LAN",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 not in local range",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "2001:0db8:ffff:000::1",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple ranges - second range match",
|
||||
localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
|
||||
testIP: "2001:0db8:abcd:001::1",
|
||||
expectedCountry: "Internal Network",
|
||||
expectedCity: "LAN",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty local ranges",
|
||||
localRanges: "",
|
||||
testIP: "2001:0db8:abcd:000::1",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 private address still works",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "192.168.1.1",
|
||||
expectedCountry: "Internal Network",
|
||||
expectedCity: "LAN",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "::1",
|
||||
expectedCountry: "Internal Network",
|
||||
expectedCity: "localhost",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
originalConfig := common.EnvConfig.LocalIPv6Ranges
|
||||
common.EnvConfig.LocalIPv6Ranges = tt.localRanges
|
||||
defer func() {
|
||||
common.EnvConfig.LocalIPv6Ranges = originalConfig
|
||||
}()
|
||||
|
||||
service := NewGeoLiteService(&http.Client{})
|
||||
|
||||
country, city, err := service.GetLocationByIP(tt.testIP)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil && country != "Internal Network" {
|
||||
t.Errorf("Expected error or internal network classification for external IP")
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedCountry, country)
|
||||
assert.Equal(t, tt.expectedCity, city)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoLiteService_isLocalIPv6(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
localRanges string
|
||||
testIP string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Valid IPv6 in range",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "2001:0db8:abcd:000::1",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Valid IPv6 not in range",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "2001:0db8:ffff:000::1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv4 address should return false",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "192.168.1.1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No ranges configured",
|
||||
localRanges: "",
|
||||
testIP: "2001:0db8:abcd:000::1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Edge of range",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "2001:0db8:abcd:00ff:ffff:ffff:ffff:ffff",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
originalConfig := common.EnvConfig.LocalIPv6Ranges
|
||||
common.EnvConfig.LocalIPv6Ranges = tt.localRanges
|
||||
defer func() {
|
||||
common.EnvConfig.LocalIPv6Ranges = originalConfig
|
||||
}()
|
||||
|
||||
service := NewGeoLiteService(&http.Client{})
|
||||
ip := net.ParseIP(tt.testIP)
|
||||
if ip == nil {
|
||||
t.Fatalf("Invalid test IP: %s", tt.testIP)
|
||||
}
|
||||
|
||||
result := service.isLocalIPv6(ip)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoLiteService_initializeIPv6LocalRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
expectError bool
|
||||
expectCount int
|
||||
}{
|
||||
{
|
||||
name: "Valid IPv6 ranges",
|
||||
envValue: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
|
||||
expectError: false,
|
||||
expectCount: 2,
|
||||
},
|
||||
{
|
||||
name: "Empty environment variable",
|
||||
envValue: "",
|
||||
expectError: false,
|
||||
expectCount: 0,
|
||||
},
|
||||
{
|
||||
name: "Invalid CIDR notation",
|
||||
envValue: "2001:0db8:abcd:000::/999",
|
||||
expectError: true,
|
||||
expectCount: 0,
|
||||
},
|
||||
{
|
||||
name: "IPv4 range in IPv6 env var",
|
||||
envValue: "192.168.1.0/24",
|
||||
expectError: true,
|
||||
expectCount: 0,
|
||||
},
|
||||
{
|
||||
name: "Mixed valid and invalid ranges",
|
||||
envValue: "2001:0db8:abcd:000::/56,invalid-range",
|
||||
expectError: true,
|
||||
expectCount: 0,
|
||||
},
|
||||
{
|
||||
name: "Whitespace handling",
|
||||
envValue: " 2001:0db8:abcd:000::/56 , 2001:0db8:abcd:001::/56 ",
|
||||
expectError: false,
|
||||
expectCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
originalConfig := common.EnvConfig.LocalIPv6Ranges
|
||||
common.EnvConfig.LocalIPv6Ranges = tt.envValue
|
||||
defer func() {
|
||||
common.EnvConfig.LocalIPv6Ranges = originalConfig
|
||||
}()
|
||||
|
||||
service := &GeoLiteService{
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
err := service.initializeIPv6LocalRanges()
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Len(t, service.localIPv6Ranges, tt.expectCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,9 +8,12 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
@@ -66,6 +69,7 @@ func NewOidcService(
|
||||
auditLogService *AuditLogService,
|
||||
customClaimService *CustomClaimService,
|
||||
webAuthnService *WebAuthnService,
|
||||
httpClient *http.Client,
|
||||
) (s *OidcService, err error) {
|
||||
s = &OidcService{
|
||||
db: db,
|
||||
@@ -74,6 +78,7 @@ func NewOidcService(
|
||||
auditLogService: auditLogService,
|
||||
customClaimService: customClaimService,
|
||||
webAuthnService: webAuthnService,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
|
||||
@@ -714,6 +719,11 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client := model.OidcClient{
|
||||
Base: model.Base{
|
||||
ID: input.ID,
|
||||
@@ -722,7 +732,7 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
|
||||
}
|
||||
updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto)
|
||||
|
||||
err := s.db.
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Create(&client).
|
||||
Error
|
||||
@@ -733,33 +743,11 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("CreatedBy").
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if input.LogoURL != nil {
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
||||
}
|
||||
|
||||
updateOIDCClientModelFromDto(&client, &input)
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
@@ -770,6 +758,36 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() { tx.Rollback() }()
|
||||
|
||||
var client model.OidcClient
|
||||
if err := tx.WithContext(ctx).
|
||||
Preload("CreatedBy").
|
||||
First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
updateOIDCClientModelFromDto(&client, &input)
|
||||
|
||||
if err := tx.WithContext(ctx).Save(&client).Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
if input.LogoURL != nil {
|
||||
err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientUpdateDto) {
|
||||
// Base fields
|
||||
client.Name = input.Name
|
||||
@@ -883,41 +901,14 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
|
||||
err = s.updateClientLogoType(ctx, tx, clientID, fileType)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if client.ImageType != nil && fileType != *client.ImageType {
|
||||
oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
|
||||
if err := os.Remove(oldImagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
client.ImageType = &fileType
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
|
||||
@@ -941,6 +932,7 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
|
||||
|
||||
oldImageType := *client.ImageType
|
||||
client.ImageType = nil
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
@@ -1333,7 +1325,7 @@ func (s *OidcService) GetDeviceCodeInfo(ctx context.Context, userCode string, us
|
||||
Client: dto.OidcClientMetaDataDto{
|
||||
ID: deviceAuth.Client.ID,
|
||||
Name: deviceAuth.Client.Name,
|
||||
HasLogo: deviceAuth.Client.HasLogo,
|
||||
HasLogo: deviceAuth.Client.HasLogo(),
|
||||
},
|
||||
Scope: deviceAuth.Scope,
|
||||
AuthorizationRequired: !hasAuthorizedClient,
|
||||
@@ -1468,7 +1460,7 @@ func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID stri
|
||||
ID: client.ID,
|
||||
Name: client.Name,
|
||||
LaunchURL: client.LaunchURL,
|
||||
HasLogo: client.HasLogo,
|
||||
HasLogo: client.HasLogo(),
|
||||
},
|
||||
LastUsedAt: lastUsedAt,
|
||||
}
|
||||
@@ -1889,3 +1881,87 @@ func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID str
|
||||
|
||||
return s.IsUserGroupAllowedToAuthorize(user, client), nil
|
||||
}
|
||||
|
||||
func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *gorm.DB, clientID string, raw string) error {
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
r := net.Resolver{}
|
||||
ips, err := r.LookupIPAddr(ctx, u.Hostname())
|
||||
if err != nil || len(ips) == 0 {
|
||||
return fmt.Errorf("cannot resolve hostname")
|
||||
}
|
||||
|
||||
// Prevents SSRF by allowing only public IPs
|
||||
for _, addr := range ips {
|
||||
if utils.IsPrivateIP(addr.IP) {
|
||||
return fmt.Errorf("private IP addresses are not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher")
|
||||
req.Header.Set("Accept", "image/*")
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("failed to fetch logo: %s", resp.Status)
|
||||
}
|
||||
|
||||
const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB
|
||||
if resp.ContentLength > maxLogoSize {
|
||||
return fmt.Errorf("logo is too large")
|
||||
}
|
||||
|
||||
// Prefer extension in path if supported
|
||||
ext := utils.GetFileExtension(u.Path)
|
||||
if ext == "" || utils.GetImageMimeType(ext) == "" {
|
||||
// Otherwise, try to detect from content type
|
||||
ext = utils.GetImageExtensionFromMimeType(resp.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
if ext == "" {
|
||||
return &common.FileTypeNotSupportedError{}
|
||||
}
|
||||
|
||||
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + "." + ext
|
||||
err = utils.SaveFileStream(io.LimitReader(resp.Body, maxLogoSize+1), imagePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.updateClientLogoType(ctx, tx, clientID, ext); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string) error {
|
||||
uploadsDir := common.EnvConfig.UploadPath + "/oidc-client-images"
|
||||
|
||||
var client model.OidcClient
|
||||
if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if client.ImageType != nil && *client.ImageType != ext {
|
||||
old := fmt.Sprintf("%s/%s.%s", uploadsDir, client.ID, *client.ImageType)
|
||||
_ = os.Remove(old)
|
||||
}
|
||||
client.ImageType = &ext
|
||||
return tx.WithContext(ctx).Save(&client).Error
|
||||
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -57,6 +58,34 @@ func GetImageMimeType(ext string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func GetImageExtensionFromMimeType(mimeType string) string {
|
||||
// Normalize and strip parameters like `; charset=utf-8`
|
||||
mt := strings.TrimSpace(strings.ToLower(mimeType))
|
||||
if v, _, err := mime.ParseMediaType(mt); err == nil {
|
||||
mt = v
|
||||
}
|
||||
switch mt {
|
||||
case "image/jpeg", "image/jpg":
|
||||
return "jpg"
|
||||
case "image/png":
|
||||
return "png"
|
||||
case "image/svg+xml":
|
||||
return "svg"
|
||||
case "image/x-icon", "image/vnd.microsoft.icon":
|
||||
return "ico"
|
||||
case "image/gif":
|
||||
return "gif"
|
||||
case "image/webp":
|
||||
return "webp"
|
||||
case "image/avif":
|
||||
return "avif"
|
||||
case "image/heic", "image/heif":
|
||||
return "heic"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func CopyEmbeddedFileToDisk(srcFilePath, destFilePath string) error {
|
||||
srcFile, err := resources.FS.Open(srcFilePath)
|
||||
if err != nil {
|
||||
|
||||
87
backend/internal/utils/ip_util.go
Normal file
87
backend/internal/utils/ip_util.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
var localIPv6Ranges []*net.IPNet
|
||||
|
||||
var localhostIPNets = []*net.IPNet{
|
||||
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
|
||||
{IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)}, // ::1/128
|
||||
}
|
||||
|
||||
var privateLanIPNets = []*net.IPNet{
|
||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
|
||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
|
||||
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
|
||||
}
|
||||
|
||||
var tailscaleIPNets = []*net.IPNet{
|
||||
{IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)}, // 100.64.0.0/10
|
||||
}
|
||||
|
||||
func IsLocalIPv6(ip net.IP) bool {
|
||||
if ip.To4() != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return listContainsIP(localIPv6Ranges, ip)
|
||||
}
|
||||
|
||||
func IsLocalhostIP(ip net.IP) bool {
|
||||
return listContainsIP(localhostIPNets, ip)
|
||||
}
|
||||
|
||||
func IsPrivateLanIP(ip net.IP) bool {
|
||||
if ip.To4() == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return listContainsIP(privateLanIPNets, ip)
|
||||
}
|
||||
|
||||
func IsTailscaleIP(ip net.IP) bool {
|
||||
if ip.To4() == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return listContainsIP(tailscaleIPNets, ip)
|
||||
}
|
||||
|
||||
func IsPrivateIP(ip net.IP) bool {
|
||||
return IsLocalhostIP(ip) || IsPrivateLanIP(ip) || IsTailscaleIP(ip) || IsLocalIPv6(ip)
|
||||
}
|
||||
|
||||
func listContainsIP(ipNets []*net.IPNet, ip net.IP) bool {
|
||||
for _, ipNet := range ipNets {
|
||||
if ipNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func loadLocalIPv6Ranges() {
|
||||
localIPv6Ranges = nil
|
||||
ranges := strings.Split(common.EnvConfig.LocalIPv6Ranges, ",")
|
||||
|
||||
for _, rangeStr := range ranges {
|
||||
rangeStr = strings.TrimSpace(rangeStr)
|
||||
if rangeStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
_, ipNet, err := net.ParseCIDR(rangeStr)
|
||||
if err == nil {
|
||||
localIPv6Ranges = append(localIPv6Ranges, ipNet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
loadLocalIPv6Ranges()
|
||||
}
|
||||
159
backend/internal/utils/ip_util_test.go
Normal file
159
backend/internal/utils/ip_util_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
func TestIsLocalhostIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
{"127.0.0.1", true},
|
||||
{"127.255.255.255", true},
|
||||
{"::1", true},
|
||||
{"192.168.1.1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsLocalhostIP(ip); got != tt.expected {
|
||||
t.Errorf("IsLocalhostIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPrivateLanIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
{"10.0.0.1", true},
|
||||
{"172.16.5.4", true},
|
||||
{"192.168.100.200", true},
|
||||
{"8.8.8.8", false},
|
||||
{"::1", false}, // IPv6 should return false
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsPrivateLanIP(ip); got != tt.expected {
|
||||
t.Errorf("IsPrivateLanIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTailscaleIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
{"100.64.0.1", true},
|
||||
{"100.127.255.254", true},
|
||||
{"8.8.8.8", false},
|
||||
{"::1", false}, // IPv6 should return false
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsTailscaleIP(ip); got != tt.expected {
|
||||
t.Errorf("IsTailscaleIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalIPv6(t *testing.T) {
|
||||
// Save and restore env config
|
||||
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
||||
|
||||
common.EnvConfig.LocalIPv6Ranges = "fd00::/8,fc00::/7"
|
||||
localIPv6Ranges = nil // reset
|
||||
loadLocalIPv6Ranges()
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
{"fd00::1", true},
|
||||
{"fc00::abcd", true},
|
||||
{"::1", false}, // loopback handled separately
|
||||
{"192.168.1.1", false}, // IPv4 should return false
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsLocalIPv6(ip); got != tt.expected {
|
||||
t.Errorf("IsLocalIPv6(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPrivateIP(t *testing.T) {
|
||||
// Save and restore env config
|
||||
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
||||
|
||||
common.EnvConfig.LocalIPv6Ranges = "fd00::/8"
|
||||
localIPv6Ranges = nil // reset
|
||||
loadLocalIPv6Ranges()
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
{"127.0.0.1", true}, // localhost
|
||||
{"192.168.1.1", true}, // private LAN
|
||||
{"100.64.0.1", true}, // Tailscale
|
||||
{"fd00::1", true}, // local IPv6
|
||||
{"8.8.8.8", false}, // public IPv4
|
||||
{"2001:4860:4860::8888", false}, // public IPv6
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsPrivateIP(ip); got != tt.expected {
|
||||
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListContainsIP(t *testing.T) {
|
||||
_, ipNet1, _ := net.ParseCIDR("10.0.0.0/8")
|
||||
_, ipNet2, _ := net.ParseCIDR("192.168.0.0/16")
|
||||
|
||||
list := []*net.IPNet{ipNet1, ipNet2}
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
{"10.1.1.1", true},
|
||||
{"192.168.5.5", true},
|
||||
{"172.16.0.1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := listContainsIP(list, ip); got != tt.expected {
|
||||
t.Errorf("listContainsIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit_LocalIPv6Ranges(t *testing.T) {
|
||||
// Save and restore env config
|
||||
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
||||
|
||||
common.EnvConfig.LocalIPv6Ranges = "fd00::/8, invalidCIDR ,fc00::/7"
|
||||
localIPv6Ranges = nil
|
||||
loadLocalIPv6Ranges()
|
||||
|
||||
if len(localIPv6Ranges) != 2 {
|
||||
t.Errorf("expected 2 valid IPv6 ranges, got %d", len(localIPv6Ranges))
|
||||
}
|
||||
}
|
||||
@@ -450,5 +450,7 @@
|
||||
"display_name": "Display Name",
|
||||
"configure_application_images": "Configure Application Images",
|
||||
"ui_config_disabled_info_title": "UI Configuration Disabled",
|
||||
"ui_config_disabled_info_description": "The UI configuration is disabled because the application configuration settings are managed through environment variables. Some settings may not be editable."
|
||||
"ui_config_disabled_info_description": "The UI configuration is disabled because the application configuration settings are managed through environment variables. Some settings may not be editable.",
|
||||
"logo_from_url_description": "Paste a direct image URL (svg, png, webp). Find icons at <link href=\"https://selfh.st/icons\">Selfh.st Icons</link> or <link href=\"https://dashboardicons.com\">Dashboard Icons</link>.",
|
||||
"invalid_url": "Invalid URL"
|
||||
}
|
||||
|
||||
85
frontend/src/lib/components/form/url-file-input.svelte
Normal file
85
frontend/src/lib/components/form/url-file-input.svelte
Normal file
@@ -0,0 +1,85 @@
|
||||
<script lang="ts">
|
||||
import FileInput from '$lib/components/form/file-input.svelte';
|
||||
import FormattedMessage from '$lib/components/formatted-message.svelte';
|
||||
import { Button, buttonVariants } from '$lib/components/ui/button';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import { Label } from '$lib/components/ui/label';
|
||||
import * as Popover from '$lib/components/ui/popover';
|
||||
import { m } from '$lib/paraglide/messages';
|
||||
import { cn } from '$lib/utils/style';
|
||||
import { LucideChevronDown } from '@lucide/svelte';
|
||||
|
||||
let {
|
||||
label,
|
||||
accept,
|
||||
onchange
|
||||
}: {
|
||||
label: string;
|
||||
accept?: string;
|
||||
onchange: (file: File | string | null) => void;
|
||||
} = $props();
|
||||
|
||||
let url = $state('');
|
||||
let hasError = $state(false);
|
||||
|
||||
async function handleFileChange(e: Event) {
|
||||
const file = (e.target as HTMLInputElement).files?.[0] || null;
|
||||
url = '';
|
||||
hasError = false;
|
||||
onchange(file);
|
||||
}
|
||||
|
||||
async function handleUrlChange(e: Event) {
|
||||
const url = (e.target as HTMLInputElement).value.trim();
|
||||
if (!url) return;
|
||||
|
||||
try {
|
||||
new URL(url);
|
||||
hasError = false;
|
||||
} catch {
|
||||
hasError = true;
|
||||
return;
|
||||
}
|
||||
|
||||
onchange(url);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex">
|
||||
<FileInput
|
||||
id="logo"
|
||||
variant="secondary"
|
||||
{accept}
|
||||
onchange={handleFileChange}
|
||||
onclick={(e: any) => (e.target.value = '')}
|
||||
>
|
||||
<Button variant="secondary" class="rounded-r-none">
|
||||
{label}
|
||||
</Button>
|
||||
</FileInput>
|
||||
<Popover.Root>
|
||||
<Popover.Trigger
|
||||
class={cn(buttonVariants({ variant: 'secondary' }), 'rounded-l-none border-l')}
|
||||
>
|
||||
<LucideChevronDown class="size-4" /></Popover.Trigger
|
||||
>
|
||||
<Popover.Content class="w-80">
|
||||
<Label for="file-url" class="text-xs">URL</Label>
|
||||
<Input
|
||||
id="file-url"
|
||||
placeholder=""
|
||||
value={url}
|
||||
oninput={(e) => (url = e.currentTarget.value)}
|
||||
onfocusout={handleUrlChange}
|
||||
aria-invalid={hasError}
|
||||
/>
|
||||
{#if hasError}
|
||||
<p class="text-destructive mt-1 text-start text-xs">{m.invalid_url()}</p>
|
||||
{/if}
|
||||
|
||||
<p class="text-muted-foreground mt-2 text-xs">
|
||||
<FormattedMessage m={m.logo_from_url_description()} />
|
||||
</p>
|
||||
</Popover.Content>
|
||||
</Popover.Root>
|
||||
</div>
|
||||
@@ -1,10 +1,25 @@
|
||||
<script lang="ts">
|
||||
import { cn } from '$lib/utils/style';
|
||||
import { LucideImageOff } from '@lucide/svelte';
|
||||
import type { HTMLImgAttributes } from 'svelte/elements';
|
||||
|
||||
let props: HTMLImgAttributes & {} = $props();
|
||||
let error = $state(false);
|
||||
|
||||
$effect(() => {
|
||||
props.src;
|
||||
error = false;
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class={'bg-muted flex items-center justify-center rounded-2xl p-3'}>
|
||||
<img class={cn('size-24 object-contain', props.class)} {...props} />
|
||||
{#if error}
|
||||
<LucideImageOff class={cn('text-muted-foreground p-5', props.class)} />
|
||||
{:else}
|
||||
<img
|
||||
{...props}
|
||||
class={cn('object-contain', props.class)}
|
||||
onerror={() => (error = true)}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -46,7 +46,8 @@ export type OidcClientUpdateWithLogo = OidcClientUpdate & {
|
||||
};
|
||||
|
||||
export type OidcClientCreateWithLogo = OidcClientCreate & {
|
||||
logo: File | null | undefined;
|
||||
logo?: File | null;
|
||||
logoUrl?: string;
|
||||
};
|
||||
|
||||
export type OidcDeviceCodeInfo = {
|
||||
|
||||
@@ -41,7 +41,7 @@
|
||||
accentColor: z.string()
|
||||
});
|
||||
|
||||
let { inputs, ...form } = $derived(createForm(formSchema, appConfig));
|
||||
let { inputs, ...form } = $derived(createForm(formSchema, updatedAppConfig));
|
||||
|
||||
async function onSubmit() {
|
||||
const data = form.validate();
|
||||
@@ -69,7 +69,6 @@
|
||||
description={m.whether_the_users_should_be_able_to_edit_their_own_account_details()}
|
||||
bind:checked={$inputs.allowOwnAccountEdit.value}
|
||||
/>
|
||||
|
||||
<SwitchWithLabel
|
||||
id="emails-verified"
|
||||
label={m.emails_verified()}
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
<script lang="ts">
|
||||
import FileInput from '$lib/components/form/file-input.svelte';
|
||||
import FormInput from '$lib/components/form/form-input.svelte';
|
||||
import SwitchWithLabel from '$lib/components/form/switch-with-label.svelte';
|
||||
import ImageBox from '$lib/components/image-box.svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import { m } from '$lib/paraglide/messages';
|
||||
import type {
|
||||
OidcClient,
|
||||
@@ -21,6 +18,7 @@
|
||||
import { z } from 'zod/v4';
|
||||
import FederatedIdentitiesInput from './federated-identities-input.svelte';
|
||||
import OidcCallbackUrlInput from './oidc-callback-url-input.svelte';
|
||||
import OidcClientImageInput from './oidc-client-image-input.svelte';
|
||||
|
||||
let {
|
||||
callback,
|
||||
@@ -31,7 +29,6 @@
|
||||
callback: (client: OidcClientCreateWithLogo | OidcClientUpdateWithLogo) => Promise<boolean>;
|
||||
mode: 'create' | 'update';
|
||||
} = $props();
|
||||
|
||||
let isLoading = $state(false);
|
||||
let showAdvancedOptions = $state(false);
|
||||
let logo = $state<File | null | undefined>();
|
||||
@@ -50,7 +47,8 @@
|
||||
launchURL: existingClient?.launchURL || '',
|
||||
credentials: {
|
||||
federatedIdentities: existingClient?.credentials?.federatedIdentities || []
|
||||
}
|
||||
},
|
||||
logoUrl: ''
|
||||
};
|
||||
|
||||
const formSchema = z.object({
|
||||
@@ -71,6 +69,7 @@
|
||||
pkceEnabled: z.boolean(),
|
||||
requiresReauthentication: z.boolean(),
|
||||
launchURL: optionalUrl,
|
||||
logoUrl: optionalUrl,
|
||||
credentials: z.object({
|
||||
federatedIdentities: z.array(
|
||||
z.object({
|
||||
@@ -90,30 +89,42 @@
|
||||
const data = form.validate();
|
||||
if (!data) return;
|
||||
isLoading = true;
|
||||
|
||||
const success = await callback({
|
||||
...data,
|
||||
logo
|
||||
logo: $inputs.logoUrl?.value ? null : logo,
|
||||
logoUrl: $inputs.logoUrl?.value
|
||||
});
|
||||
// Reset form if client was successfully created
|
||||
|
||||
const hasLogo = logo != null || !!$inputs.logoUrl?.value;
|
||||
if (success && existingClient && hasLogo) {
|
||||
logoDataURL = cachedOidcClientLogo.getUrl(existingClient.id);
|
||||
}
|
||||
|
||||
if (success && !existingClient) form.reset();
|
||||
isLoading = false;
|
||||
}
|
||||
|
||||
function onLogoChange(e: Event) {
|
||||
const file = (e.target as HTMLInputElement).files?.[0] || null;
|
||||
if (file) {
|
||||
logo = file;
|
||||
function onLogoChange(input: File | string | null) {
|
||||
if (input == null) return;
|
||||
|
||||
if (typeof input === 'string') {
|
||||
logo = null;
|
||||
logoDataURL = input || null;
|
||||
$inputs.logoUrl!.value = input;
|
||||
} else {
|
||||
logo = input;
|
||||
$inputs.logoUrl && ($inputs.logoUrl.value = '');
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
logoDataURL = event.target?.result as string;
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
reader.onload = (event) => (logoDataURL = event.target?.result as string);
|
||||
reader.readAsDataURL(input);
|
||||
}
|
||||
}
|
||||
|
||||
function resetLogo() {
|
||||
logo = null;
|
||||
logoDataURL = null;
|
||||
$inputs.logoUrl && ($inputs.logoUrl.value = '');
|
||||
}
|
||||
|
||||
function getFederatedIdentityErrors(errors: z.ZodError<any> | undefined) {
|
||||
@@ -173,32 +184,13 @@
|
||||
bind:checked={$inputs.requiresReauthentication.value}
|
||||
/>
|
||||
</div>
|
||||
<div class="mt-8">
|
||||
<Label for="logo">{m.logo()}</Label>
|
||||
<div class="mt-2 flex items-end gap-3">
|
||||
{#if logoDataURL}
|
||||
<ImageBox
|
||||
class="size-24"
|
||||
src={logoDataURL}
|
||||
alt={m.name_logo({ name: $inputs.name.value })}
|
||||
<div class="mt-7">
|
||||
<OidcClientImageInput
|
||||
{logoDataURL}
|
||||
{resetLogo}
|
||||
clientName={$inputs.name.value}
|
||||
{onLogoChange}
|
||||
/>
|
||||
{/if}
|
||||
<div class="flex flex-col gap-2">
|
||||
<FileInput
|
||||
id="logo"
|
||||
variant="secondary"
|
||||
accept="image/png, image/jpeg, image/svg+xml, image/webp, image/avif, image/heic"
|
||||
onchange={onLogoChange}
|
||||
>
|
||||
<Button variant="secondary">
|
||||
{logoDataURL ? m.change_logo() : m.upload_logo()}
|
||||
</Button>
|
||||
</FileInput>
|
||||
{#if logoDataURL}
|
||||
<Button variant="outline" onclick={resetLogo}>{m.remove_logo()}</Button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if showAdvancedOptions}
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
<script lang="ts">
|
||||
import UrlFileInput from '$lib/components/form/url-file-input.svelte';
|
||||
import ImageBox from '$lib/components/image-box.svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Label } from '$lib/components/ui/label';
|
||||
import { m } from '$lib/paraglide/messages';
|
||||
import { LucideX } from '@lucide/svelte';
|
||||
|
||||
let {
|
||||
logoDataURL,
|
||||
clientName,
|
||||
resetLogo,
|
||||
onLogoChange
|
||||
}: {
|
||||
logoDataURL: string | null;
|
||||
clientName: string;
|
||||
resetLogo: () => void;
|
||||
onLogoChange: (file: File | string | null) => void;
|
||||
} = $props();
|
||||
</script>
|
||||
|
||||
<Label for="logo">{m.logo()}</Label>
|
||||
<div class="flex items-end gap-4">
|
||||
{#if logoDataURL}
|
||||
<div class="flex items-start gap-4">
|
||||
<div class="relative shrink-0">
|
||||
<ImageBox class="size-24" src={logoDataURL} alt={m.name_logo({ name: clientName })} />
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="icon"
|
||||
onclick={resetLogo}
|
||||
class="absolute -top-2 -right-2 size-6 rounded-full shadow-md"
|
||||
>
|
||||
<LucideX class="size-3" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
<div class="flex flex-col gap-3">
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
<UrlFileInput label={m.upload_logo()} accept="image/*" onchange={onLogoChange} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -71,16 +71,21 @@
|
||||
? item.allowedUserGroupsCount
|
||||
: m.unrestricted()}</Table.Cell
|
||||
>
|
||||
<Table.Cell class="flex justify-end gap-1">
|
||||
<Table.Cell class="align-middle">
|
||||
<div class="flex justify-end gap-1">
|
||||
<Button
|
||||
href="/settings/admin/oidc-clients/{item.id}"
|
||||
size="sm"
|
||||
variant="outline"
|
||||
aria-label={m.edit()}><LucidePencil class="size-3 " /></Button
|
||||
>
|
||||
<Button onclick={() => deleteClient(item)} size="sm" variant="outline" aria-label={m.delete()}
|
||||
><LucideTrash class="size-3 text-red-500" /></Button
|
||||
<Button
|
||||
onclick={() => deleteClient(item)}
|
||||
size="sm"
|
||||
variant="outline"
|
||||
aria-label={m.delete()}><LucideTrash class="size-3 text-red-500" /></Button
|
||||
>
|
||||
</div>
|
||||
</Table.Cell>
|
||||
{/snippet}
|
||||
</AdvancedTable>
|
||||
|
||||
@@ -38,7 +38,7 @@
|
||||
<div class="flex gap-3">
|
||||
<div class="aspect-square h-[56px]">
|
||||
<ImageBox
|
||||
class="grow rounded-lg object-contain"
|
||||
class="h-8 w-8 grow rounded-lg object-contain"
|
||||
src={client.hasLogo
|
||||
? cachedOidcClientLogo.getUrl(client.id)
|
||||
: cachedApplicationLogo.getUrl(isLightMode)}
|
||||
|
||||
Reference in New Issue
Block a user