feat: support for url based icons (#840)

Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Kyle Mendell
2025-09-29 10:07:55 -05:00
committed by GitHub
parent 47bd5ba1ba
commit 6bdf5fa37a
19 changed files with 650 additions and 442 deletions

View File

@@ -56,7 +56,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
return nil, fmt.Errorf("failed to create WebAuthn service: %w", err)
}
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService)
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService, httpClient)
if err != nil {
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"log/slog"
"net"
"net/url"
"os"
"reflect"
@@ -180,6 +181,25 @@ func validateEnvConfig(config *EnvConfigSchema) error {
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", config.KeysStorage)
}
// Validate LOCAL_IPV6_RANGES
ranges := strings.Split(config.LocalIPv6Ranges, ",")
for _, rangeStr := range ranges {
rangeStr = strings.TrimSpace(rangeStr)
if rangeStr == "" {
continue
}
_, ipNet, err := net.ParseCIDR(rangeStr)
if err != nil {
return fmt.Errorf("invalid LOCAL_IPV6_RANGES '%s': %w", rangeStr, err)
}
if ipNet.IP.To4() != nil {
return fmt.Errorf("range '%s' is not a valid IPv6 range", rangeStr)
}
}
return nil
}

View File

@@ -38,6 +38,8 @@ type OidcClientUpdateDto struct {
RequiresReauthentication bool `json:"requiresReauthentication"`
Credentials OidcClientCredentialsDto `json:"credentials"`
LaunchURL *string `json:"launchURL" binding:"omitempty,url"`
HasLogo bool `json:"hasLogo"`
LogoURL *string `json:"logoUrl"`
}
type OidcClientCreateDto struct {

View File

@@ -6,8 +6,6 @@ import (
"fmt"
"strings"
"gorm.io/gorm"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
@@ -54,7 +52,6 @@ type OidcClient struct {
CallbackURLs UrlList
LogoutCallbackURLs UrlList
ImageType *string
HasLogo bool `gorm:"-"`
IsPublic bool
PkceEnabled bool
RequiresReauthentication bool
@@ -67,6 +64,10 @@ type OidcClient struct {
UserAuthorizedOidcClients []UserAuthorizedOidcClient `gorm:"foreignKey:ClientID;references:ID"`
}
func (c OidcClient) HasLogo() bool {
return c.ImageType != nil && *c.ImageType != ""
}
type OidcRefreshToken struct {
Base
@@ -89,12 +90,6 @@ func (c OidcRefreshToken) Scopes() []string {
return strings.Split(c.Scope, " ")
}
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
// Compute HasLogo field
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
return nil
}
type OidcClientCredentials struct { //nolint:recvcheck
FederatedIdentities []OidcClientFederatedIdentity `json:"federatedIdentities,omitempty"`
}

View File

@@ -13,11 +13,11 @@ import (
"net/netip"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/oschwald/maxminddb-golang/v2"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
@@ -26,22 +26,6 @@ type GeoLiteService struct {
httpClient *http.Client
disableUpdater bool
mutex sync.RWMutex
localIPv6Ranges []*net.IPNet
}
var localhostIPNets = []*net.IPNet{
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
{IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)}, // ::1/128
}
var privateLanIPNets = []*net.IPNet{
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
}
var tailscaleIPNets = []*net.IPNet{
{IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)}, // 100.64.0.0/10
}
// NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database.
@@ -56,67 +40,9 @@ func NewGeoLiteService(httpClient *http.Client) *GeoLiteService {
service.disableUpdater = true
}
// Initialize IPv6 local ranges
err := service.initializeIPv6LocalRanges()
if err != nil {
slog.Warn("Failed to initialize IPv6 local ranges", slog.Any("error", err))
}
return service
}
// initializeIPv6LocalRanges parses the LOCAL_IPV6_RANGES environment variable
func (s *GeoLiteService) initializeIPv6LocalRanges() error {
rangesEnv := common.EnvConfig.LocalIPv6Ranges
if rangesEnv == "" {
return nil // No local IPv6 ranges configured
}
ranges := strings.Split(rangesEnv, ",")
localRanges := make([]*net.IPNet, 0, len(ranges))
for _, rangeStr := range ranges {
rangeStr = strings.TrimSpace(rangeStr)
if rangeStr == "" {
continue
}
_, ipNet, err := net.ParseCIDR(rangeStr)
if err != nil {
return fmt.Errorf("invalid IPv6 range '%s': %w", rangeStr, err)
}
// Ensure it's an IPv6 range
if ipNet.IP.To4() != nil {
return fmt.Errorf("range '%s' is not a valid IPv6 range", rangeStr)
}
localRanges = append(localRanges, ipNet)
}
s.localIPv6Ranges = localRanges
if len(localRanges) > 0 {
slog.Info("Initialized IPv6 local ranges", slog.Int("count", len(localRanges)))
}
return nil
}
// isLocalIPv6 checks if the given IPv6 address is within any of the configured local ranges
func (s *GeoLiteService) isLocalIPv6(ip net.IP) bool {
if ip.To4() != nil {
return false // Not an IPv6 address
}
for _, localRange := range s.localIPv6Ranges {
if localRange.Contains(ip) {
return true
}
}
return false
}
func (s *GeoLiteService) DisableUpdater() bool {
return s.disableUpdater
}
@@ -129,28 +55,19 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
// Check the IP address against known private IP ranges
if ip := net.ParseIP(ipAddress); ip != nil {
// Check IPv6 local ranges first
if s.isLocalIPv6(ip) {
if utils.IsLocalIPv6(ip) {
return "Internal Network", "LAN", nil
}
// Check existing IPv4 ranges
for _, ipNet := range tailscaleIPNets {
if ipNet.Contains(ip) {
if utils.IsTailscaleIP(ip) {
return "Internal Network", "Tailscale", nil
}
}
for _, ipNet := range privateLanIPNets {
if ipNet.Contains(ip) {
if utils.IsPrivateIP(ip) {
return "Internal Network", "LAN", nil
}
}
for _, ipNet := range localhostIPNets {
if ipNet.Contains(ip) {
if utils.IsLocalhostIP(ip) {
return "Internal Network", "localhost", nil
}
}
}
addr, err := netip.ParseAddr(ipAddress)
if err != nil {

View File

@@ -1,220 +0,0 @@
package service
import (
"net"
"net/http"
"testing"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGeoLiteService_IPv6LocalRanges(t *testing.T) {
tests := []struct {
name string
localRanges string
testIP string
expectedCountry string
expectedCity string
expectError bool
}{
{
name: "IPv6 in local range",
localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
testIP: "2001:0db8:abcd:000::1",
expectedCountry: "Internal Network",
expectedCity: "LAN",
expectError: false,
},
{
name: "IPv6 not in local range",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "2001:0db8:ffff:000::1",
expectError: true,
},
{
name: "Multiple ranges - second range match",
localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
testIP: "2001:0db8:abcd:001::1",
expectedCountry: "Internal Network",
expectedCity: "LAN",
expectError: false,
},
{
name: "Empty local ranges",
localRanges: "",
testIP: "2001:0db8:abcd:000::1",
expectError: true,
},
{
name: "IPv4 private address still works",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "192.168.1.1",
expectedCountry: "Internal Network",
expectedCity: "LAN",
expectError: false,
},
{
name: "IPv6 loopback",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "::1",
expectedCountry: "Internal Network",
expectedCity: "localhost",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
originalConfig := common.EnvConfig.LocalIPv6Ranges
common.EnvConfig.LocalIPv6Ranges = tt.localRanges
defer func() {
common.EnvConfig.LocalIPv6Ranges = originalConfig
}()
service := NewGeoLiteService(&http.Client{})
country, city, err := service.GetLocationByIP(tt.testIP)
if tt.expectError {
if err == nil && country != "Internal Network" {
t.Errorf("Expected error or internal network classification for external IP")
}
} else {
require.NoError(t, err)
assert.Equal(t, tt.expectedCountry, country)
assert.Equal(t, tt.expectedCity, city)
}
})
}
}
func TestGeoLiteService_isLocalIPv6(t *testing.T) {
tests := []struct {
name string
localRanges string
testIP string
expected bool
}{
{
name: "Valid IPv6 in range",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "2001:0db8:abcd:000::1",
expected: true,
},
{
name: "Valid IPv6 not in range",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "2001:0db8:ffff:000::1",
expected: false,
},
{
name: "IPv4 address should return false",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "192.168.1.1",
expected: false,
},
{
name: "No ranges configured",
localRanges: "",
testIP: "2001:0db8:abcd:000::1",
expected: false,
},
{
name: "Edge of range",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "2001:0db8:abcd:00ff:ffff:ffff:ffff:ffff",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
originalConfig := common.EnvConfig.LocalIPv6Ranges
common.EnvConfig.LocalIPv6Ranges = tt.localRanges
defer func() {
common.EnvConfig.LocalIPv6Ranges = originalConfig
}()
service := NewGeoLiteService(&http.Client{})
ip := net.ParseIP(tt.testIP)
if ip == nil {
t.Fatalf("Invalid test IP: %s", tt.testIP)
}
result := service.isLocalIPv6(ip)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGeoLiteService_initializeIPv6LocalRanges(t *testing.T) {
tests := []struct {
name string
envValue string
expectError bool
expectCount int
}{
{
name: "Valid IPv6 ranges",
envValue: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
expectError: false,
expectCount: 2,
},
{
name: "Empty environment variable",
envValue: "",
expectError: false,
expectCount: 0,
},
{
name: "Invalid CIDR notation",
envValue: "2001:0db8:abcd:000::/999",
expectError: true,
expectCount: 0,
},
{
name: "IPv4 range in IPv6 env var",
envValue: "192.168.1.0/24",
expectError: true,
expectCount: 0,
},
{
name: "Mixed valid and invalid ranges",
envValue: "2001:0db8:abcd:000::/56,invalid-range",
expectError: true,
expectCount: 0,
},
{
name: "Whitespace handling",
envValue: " 2001:0db8:abcd:000::/56 , 2001:0db8:abcd:001::/56 ",
expectError: false,
expectCount: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
originalConfig := common.EnvConfig.LocalIPv6Ranges
common.EnvConfig.LocalIPv6Ranges = tt.envValue
defer func() {
common.EnvConfig.LocalIPv6Ranges = originalConfig
}()
service := &GeoLiteService{
httpClient: &http.Client{},
}
err := service.initializeIPv6LocalRanges()
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Len(t, service.localIPv6Ranges, tt.expectCount)
})
}
}

View File

@@ -8,9 +8,12 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"mime/multipart"
"net"
"net/http"
"net/url"
"os"
"regexp"
"slices"
@@ -66,6 +69,7 @@ func NewOidcService(
auditLogService *AuditLogService,
customClaimService *CustomClaimService,
webAuthnService *WebAuthnService,
httpClient *http.Client,
) (s *OidcService, err error) {
s = &OidcService{
db: db,
@@ -74,6 +78,7 @@ func NewOidcService(
auditLogService: auditLogService,
customClaimService: customClaimService,
webAuthnService: webAuthnService,
httpClient: httpClient,
}
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
@@ -714,6 +719,11 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
}
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
client := model.OidcClient{
Base: model.Base{
ID: input.ID,
@@ -722,7 +732,7 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
}
updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto)
err := s.db.
err := tx.
WithContext(ctx).
Create(&client).
Error
@@ -733,33 +743,11 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
return model.OidcClient{}, err
}
return client, nil
}
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var client model.OidcClient
err := tx.
WithContext(ctx).
Preload("CreatedBy").
First(&client, "id = ?", clientID).
Error
if input.LogoURL != nil {
err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL)
if err != nil {
return model.OidcClient{}, err
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
}
updateOIDCClientModelFromDto(&client, &input)
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return model.OidcClient{}, err
}
err = tx.Commit().Error
@@ -770,6 +758,36 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
return client, nil
}
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
tx := s.db.Begin()
defer func() { tx.Rollback() }()
var client model.OidcClient
if err := tx.WithContext(ctx).
Preload("CreatedBy").
First(&client, "id = ?", clientID).Error; err != nil {
return model.OidcClient{}, err
}
updateOIDCClientModelFromDto(&client, &input)
if err := tx.WithContext(ctx).Save(&client).Error; err != nil {
return model.OidcClient{}, err
}
if input.LogoURL != nil {
err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL)
if err != nil {
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
}
}
if err := tx.Commit().Error; err != nil {
return model.OidcClient{}, err
}
return client, nil
}
func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientUpdateDto) {
// Base fields
client.Name = input.Name
@@ -883,41 +901,14 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil
}
tx := s.db.Begin()
defer func() {
err = s.updateClientLogoType(ctx, tx, clientID, fileType)
if err != nil {
tx.Rollback()
}()
var client model.OidcClient
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return err
}
if client.ImageType != nil && fileType != *client.ImageType {
oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
if err := os.Remove(oldImagePath); err != nil {
return err
}
}
client.ImageType = &fileType
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return err
}
err = tx.Commit().Error
if err != nil {
return err
}
return nil
return tx.Commit().Error
}
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
@@ -941,6 +932,7 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
oldImageType := *client.ImageType
client.ImageType = nil
err = tx.
WithContext(ctx).
Save(&client).
@@ -1333,7 +1325,7 @@ func (s *OidcService) GetDeviceCodeInfo(ctx context.Context, userCode string, us
Client: dto.OidcClientMetaDataDto{
ID: deviceAuth.Client.ID,
Name: deviceAuth.Client.Name,
HasLogo: deviceAuth.Client.HasLogo,
HasLogo: deviceAuth.Client.HasLogo(),
},
Scope: deviceAuth.Scope,
AuthorizationRequired: !hasAuthorizedClient,
@@ -1468,7 +1460,7 @@ func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID stri
ID: client.ID,
Name: client.Name,
LaunchURL: client.LaunchURL,
HasLogo: client.HasLogo,
HasLogo: client.HasLogo(),
},
LastUsedAt: lastUsedAt,
}
@@ -1889,3 +1881,87 @@ func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID str
return s.IsUserGroupAllowedToAuthorize(user, client), nil
}
func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *gorm.DB, clientID string, raw string) error {
u, err := url.Parse(raw)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
defer cancel()
r := net.Resolver{}
ips, err := r.LookupIPAddr(ctx, u.Hostname())
if err != nil || len(ips) == 0 {
return fmt.Errorf("cannot resolve hostname")
}
// Prevents SSRF by allowing only public IPs
for _, addr := range ips {
if utils.IsPrivateIP(addr.IP) {
return fmt.Errorf("private IP addresses are not allowed")
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil)
if err != nil {
return err
}
req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher")
req.Header.Set("Accept", "image/*")
resp, err := s.httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to fetch logo: %s", resp.Status)
}
const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB
if resp.ContentLength > maxLogoSize {
return fmt.Errorf("logo is too large")
}
// Prefer extension in path if supported
ext := utils.GetFileExtension(u.Path)
if ext == "" || utils.GetImageMimeType(ext) == "" {
// Otherwise, try to detect from content type
ext = utils.GetImageExtensionFromMimeType(resp.Header.Get("Content-Type"))
}
if ext == "" {
return &common.FileTypeNotSupportedError{}
}
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + "." + ext
err = utils.SaveFileStream(io.LimitReader(resp.Body, maxLogoSize+1), imagePath)
if err != nil {
return err
}
if err := s.updateClientLogoType(ctx, tx, clientID, ext); err != nil {
return err
}
return nil
}
func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string) error {
uploadsDir := common.EnvConfig.UploadPath + "/oidc-client-images"
var client model.OidcClient
if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil {
return err
}
if client.ImageType != nil && *client.ImageType != ext {
old := fmt.Sprintf("%s/%s.%s", uploadsDir, client.ID, *client.ImageType)
_ = os.Remove(old)
}
client.ImageType = &ext
return tx.WithContext(ctx).Save(&client).Error
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"os"
"path/filepath"
@@ -57,6 +58,34 @@ func GetImageMimeType(ext string) string {
}
}
func GetImageExtensionFromMimeType(mimeType string) string {
// Normalize and strip parameters like `; charset=utf-8`
mt := strings.TrimSpace(strings.ToLower(mimeType))
if v, _, err := mime.ParseMediaType(mt); err == nil {
mt = v
}
switch mt {
case "image/jpeg", "image/jpg":
return "jpg"
case "image/png":
return "png"
case "image/svg+xml":
return "svg"
case "image/x-icon", "image/vnd.microsoft.icon":
return "ico"
case "image/gif":
return "gif"
case "image/webp":
return "webp"
case "image/avif":
return "avif"
case "image/heic", "image/heif":
return "heic"
default:
return ""
}
}
func CopyEmbeddedFileToDisk(srcFilePath, destFilePath string) error {
srcFile, err := resources.FS.Open(srcFilePath)
if err != nil {

View File

@@ -0,0 +1,87 @@
package utils
import (
"net"
"strings"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
var localIPv6Ranges []*net.IPNet
var localhostIPNets = []*net.IPNet{
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
{IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)}, // ::1/128
}
var privateLanIPNets = []*net.IPNet{
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
}
var tailscaleIPNets = []*net.IPNet{
{IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)}, // 100.64.0.0/10
}
func IsLocalIPv6(ip net.IP) bool {
if ip.To4() != nil {
return false
}
return listContainsIP(localIPv6Ranges, ip)
}
func IsLocalhostIP(ip net.IP) bool {
return listContainsIP(localhostIPNets, ip)
}
func IsPrivateLanIP(ip net.IP) bool {
if ip.To4() == nil {
return false
}
return listContainsIP(privateLanIPNets, ip)
}
func IsTailscaleIP(ip net.IP) bool {
if ip.To4() == nil {
return false
}
return listContainsIP(tailscaleIPNets, ip)
}
func IsPrivateIP(ip net.IP) bool {
return IsLocalhostIP(ip) || IsPrivateLanIP(ip) || IsTailscaleIP(ip) || IsLocalIPv6(ip)
}
func listContainsIP(ipNets []*net.IPNet, ip net.IP) bool {
for _, ipNet := range ipNets {
if ipNet.Contains(ip) {
return true
}
}
return false
}
func loadLocalIPv6Ranges() {
localIPv6Ranges = nil
ranges := strings.Split(common.EnvConfig.LocalIPv6Ranges, ",")
for _, rangeStr := range ranges {
rangeStr = strings.TrimSpace(rangeStr)
if rangeStr == "" {
continue
}
_, ipNet, err := net.ParseCIDR(rangeStr)
if err == nil {
localIPv6Ranges = append(localIPv6Ranges, ipNet)
}
}
}
func init() {
loadLocalIPv6Ranges()
}

View File

@@ -0,0 +1,159 @@
package utils
import (
"net"
"testing"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
func TestIsLocalhostIP(t *testing.T) {
tests := []struct {
ip string
expected bool
}{
{"127.0.0.1", true},
{"127.255.255.255", true},
{"::1", true},
{"192.168.1.1", false},
}
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
if got := IsLocalhostIP(ip); got != tt.expected {
t.Errorf("IsLocalhostIP(%s) = %v, want %v", tt.ip, got, tt.expected)
}
}
}
func TestIsPrivateLanIP(t *testing.T) {
tests := []struct {
ip string
expected bool
}{
{"10.0.0.1", true},
{"172.16.5.4", true},
{"192.168.100.200", true},
{"8.8.8.8", false},
{"::1", false}, // IPv6 should return false
}
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
if got := IsPrivateLanIP(ip); got != tt.expected {
t.Errorf("IsPrivateLanIP(%s) = %v, want %v", tt.ip, got, tt.expected)
}
}
}
func TestIsTailscaleIP(t *testing.T) {
tests := []struct {
ip string
expected bool
}{
{"100.64.0.1", true},
{"100.127.255.254", true},
{"8.8.8.8", false},
{"::1", false}, // IPv6 should return false
}
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
if got := IsTailscaleIP(ip); got != tt.expected {
t.Errorf("IsTailscaleIP(%s) = %v, want %v", tt.ip, got, tt.expected)
}
}
}
func TestIsLocalIPv6(t *testing.T) {
// Save and restore env config
origRanges := common.EnvConfig.LocalIPv6Ranges
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
common.EnvConfig.LocalIPv6Ranges = "fd00::/8,fc00::/7"
localIPv6Ranges = nil // reset
loadLocalIPv6Ranges()
tests := []struct {
ip string
expected bool
}{
{"fd00::1", true},
{"fc00::abcd", true},
{"::1", false}, // loopback handled separately
{"192.168.1.1", false}, // IPv4 should return false
}
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
if got := IsLocalIPv6(ip); got != tt.expected {
t.Errorf("IsLocalIPv6(%s) = %v, want %v", tt.ip, got, tt.expected)
}
}
}
func TestIsPrivateIP(t *testing.T) {
// Save and restore env config
origRanges := common.EnvConfig.LocalIPv6Ranges
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
common.EnvConfig.LocalIPv6Ranges = "fd00::/8"
localIPv6Ranges = nil // reset
loadLocalIPv6Ranges()
tests := []struct {
ip string
expected bool
}{
{"127.0.0.1", true}, // localhost
{"192.168.1.1", true}, // private LAN
{"100.64.0.1", true}, // Tailscale
{"fd00::1", true}, // local IPv6
{"8.8.8.8", false}, // public IPv4
{"2001:4860:4860::8888", false}, // public IPv6
}
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
if got := IsPrivateIP(ip); got != tt.expected {
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, got, tt.expected)
}
}
}
func TestListContainsIP(t *testing.T) {
_, ipNet1, _ := net.ParseCIDR("10.0.0.0/8")
_, ipNet2, _ := net.ParseCIDR("192.168.0.0/16")
list := []*net.IPNet{ipNet1, ipNet2}
tests := []struct {
ip string
expected bool
}{
{"10.1.1.1", true},
{"192.168.5.5", true},
{"172.16.0.1", false},
}
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
if got := listContainsIP(list, ip); got != tt.expected {
t.Errorf("listContainsIP(%s) = %v, want %v", tt.ip, got, tt.expected)
}
}
}
func TestInit_LocalIPv6Ranges(t *testing.T) {
// Save and restore env config
origRanges := common.EnvConfig.LocalIPv6Ranges
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
common.EnvConfig.LocalIPv6Ranges = "fd00::/8, invalidCIDR ,fc00::/7"
localIPv6Ranges = nil
loadLocalIPv6Ranges()
if len(localIPv6Ranges) != 2 {
t.Errorf("expected 2 valid IPv6 ranges, got %d", len(localIPv6Ranges))
}
}

View File

@@ -450,5 +450,7 @@
"display_name": "Display Name",
"configure_application_images": "Configure Application Images",
"ui_config_disabled_info_title": "UI Configuration Disabled",
"ui_config_disabled_info_description": "The UI configuration is disabled because the application configuration settings are managed through environment variables. Some settings may not be editable."
"ui_config_disabled_info_description": "The UI configuration is disabled because the application configuration settings are managed through environment variables. Some settings may not be editable.",
"logo_from_url_description": "Paste a direct image URL (svg, png, webp). Find icons at <link href=\"https://selfh.st/icons\">Selfh.st Icons</link> or <link href=\"https://dashboardicons.com\">Dashboard Icons</link>.",
"invalid_url": "Invalid URL"
}

View File

@@ -0,0 +1,85 @@
<script lang="ts">
import FileInput from '$lib/components/form/file-input.svelte';
import FormattedMessage from '$lib/components/formatted-message.svelte';
import { Button, buttonVariants } from '$lib/components/ui/button';
import { Input } from '$lib/components/ui/input';
import { Label } from '$lib/components/ui/label';
import * as Popover from '$lib/components/ui/popover';
import { m } from '$lib/paraglide/messages';
import { cn } from '$lib/utils/style';
import { LucideChevronDown } from '@lucide/svelte';
let {
label,
accept,
onchange
}: {
label: string;
accept?: string;
onchange: (file: File | string | null) => void;
} = $props();
let url = $state('');
let hasError = $state(false);
async function handleFileChange(e: Event) {
const file = (e.target as HTMLInputElement).files?.[0] || null;
url = '';
hasError = false;
onchange(file);
}
async function handleUrlChange(e: Event) {
const url = (e.target as HTMLInputElement).value.trim();
if (!url) return;
try {
new URL(url);
hasError = false;
} catch {
hasError = true;
return;
}
onchange(url);
}
</script>
<div class="flex">
<FileInput
id="logo"
variant="secondary"
{accept}
onchange={handleFileChange}
onclick={(e: any) => (e.target.value = '')}
>
<Button variant="secondary" class="rounded-r-none">
{label}
</Button>
</FileInput>
<Popover.Root>
<Popover.Trigger
class={cn(buttonVariants({ variant: 'secondary' }), 'rounded-l-none border-l')}
>
<LucideChevronDown class="size-4" /></Popover.Trigger
>
<Popover.Content class="w-80">
<Label for="file-url" class="text-xs">URL</Label>
<Input
id="file-url"
placeholder=""
value={url}
oninput={(e) => (url = e.currentTarget.value)}
onfocusout={handleUrlChange}
aria-invalid={hasError}
/>
{#if hasError}
<p class="text-destructive mt-1 text-start text-xs">{m.invalid_url()}</p>
{/if}
<p class="text-muted-foreground mt-2 text-xs">
<FormattedMessage m={m.logo_from_url_description()} />
</p>
</Popover.Content>
</Popover.Root>
</div>

View File

@@ -1,10 +1,25 @@
<script lang="ts">
import { cn } from '$lib/utils/style';
import { LucideImageOff } from '@lucide/svelte';
import type { HTMLImgAttributes } from 'svelte/elements';
let props: HTMLImgAttributes & {} = $props();
let error = $state(false);
$effect(() => {
props.src;
error = false;
});
</script>
<div class={'bg-muted flex items-center justify-center rounded-2xl p-3'}>
<img class={cn('size-24 object-contain', props.class)} {...props} />
{#if error}
<LucideImageOff class={cn('text-muted-foreground p-5', props.class)} />
{:else}
<img
{...props}
class={cn('object-contain', props.class)}
onerror={() => (error = true)}
/>
{/if}
</div>

View File

@@ -46,7 +46,8 @@ export type OidcClientUpdateWithLogo = OidcClientUpdate & {
};
export type OidcClientCreateWithLogo = OidcClientCreate & {
logo: File | null | undefined;
logo?: File | null;
logoUrl?: string;
};
export type OidcDeviceCodeInfo = {

View File

@@ -41,7 +41,7 @@
accentColor: z.string()
});
let { inputs, ...form } = $derived(createForm(formSchema, appConfig));
let { inputs, ...form } = $derived(createForm(formSchema, updatedAppConfig));
async function onSubmit() {
const data = form.validate();
@@ -69,7 +69,6 @@
description={m.whether_the_users_should_be_able_to_edit_their_own_account_details()}
bind:checked={$inputs.allowOwnAccountEdit.value}
/>
<SwitchWithLabel
id="emails-verified"
label={m.emails_verified()}

View File

@@ -1,10 +1,7 @@
<script lang="ts">
import FileInput from '$lib/components/form/file-input.svelte';
import FormInput from '$lib/components/form/form-input.svelte';
import SwitchWithLabel from '$lib/components/form/switch-with-label.svelte';
import ImageBox from '$lib/components/image-box.svelte';
import { Button } from '$lib/components/ui/button';
import Label from '$lib/components/ui/label/label.svelte';
import { m } from '$lib/paraglide/messages';
import type {
OidcClient,
@@ -21,6 +18,7 @@
import { z } from 'zod/v4';
import FederatedIdentitiesInput from './federated-identities-input.svelte';
import OidcCallbackUrlInput from './oidc-callback-url-input.svelte';
import OidcClientImageInput from './oidc-client-image-input.svelte';
let {
callback,
@@ -31,7 +29,6 @@
callback: (client: OidcClientCreateWithLogo | OidcClientUpdateWithLogo) => Promise<boolean>;
mode: 'create' | 'update';
} = $props();
let isLoading = $state(false);
let showAdvancedOptions = $state(false);
let logo = $state<File | null | undefined>();
@@ -50,7 +47,8 @@
launchURL: existingClient?.launchURL || '',
credentials: {
federatedIdentities: existingClient?.credentials?.federatedIdentities || []
}
},
logoUrl: ''
};
const formSchema = z.object({
@@ -71,6 +69,7 @@
pkceEnabled: z.boolean(),
requiresReauthentication: z.boolean(),
launchURL: optionalUrl,
logoUrl: optionalUrl,
credentials: z.object({
federatedIdentities: z.array(
z.object({
@@ -90,30 +89,42 @@
const data = form.validate();
if (!data) return;
isLoading = true;
const success = await callback({
...data,
logo
logo: $inputs.logoUrl?.value ? null : logo,
logoUrl: $inputs.logoUrl?.value
});
// Reset form if client was successfully created
const hasLogo = logo != null || !!$inputs.logoUrl?.value;
if (success && existingClient && hasLogo) {
logoDataURL = cachedOidcClientLogo.getUrl(existingClient.id);
}
if (success && !existingClient) form.reset();
isLoading = false;
}
function onLogoChange(e: Event) {
const file = (e.target as HTMLInputElement).files?.[0] || null;
if (file) {
logo = file;
function onLogoChange(input: File | string | null) {
if (input == null) return;
if (typeof input === 'string') {
logo = null;
logoDataURL = input || null;
$inputs.logoUrl!.value = input;
} else {
logo = input;
$inputs.logoUrl && ($inputs.logoUrl.value = '');
const reader = new FileReader();
reader.onload = (event) => {
logoDataURL = event.target?.result as string;
};
reader.readAsDataURL(file);
reader.onload = (event) => (logoDataURL = event.target?.result as string);
reader.readAsDataURL(input);
}
}
function resetLogo() {
logo = null;
logoDataURL = null;
$inputs.logoUrl && ($inputs.logoUrl.value = '');
}
function getFederatedIdentityErrors(errors: z.ZodError<any> | undefined) {
@@ -173,32 +184,13 @@
bind:checked={$inputs.requiresReauthentication.value}
/>
</div>
<div class="mt-8">
<Label for="logo">{m.logo()}</Label>
<div class="mt-2 flex items-end gap-3">
{#if logoDataURL}
<ImageBox
class="size-24"
src={logoDataURL}
alt={m.name_logo({ name: $inputs.name.value })}
<div class="mt-7">
<OidcClientImageInput
{logoDataURL}
{resetLogo}
clientName={$inputs.name.value}
{onLogoChange}
/>
{/if}
<div class="flex flex-col gap-2">
<FileInput
id="logo"
variant="secondary"
accept="image/png, image/jpeg, image/svg+xml, image/webp, image/avif, image/heic"
onchange={onLogoChange}
>
<Button variant="secondary">
{logoDataURL ? m.change_logo() : m.upload_logo()}
</Button>
</FileInput>
{#if logoDataURL}
<Button variant="outline" onclick={resetLogo}>{m.remove_logo()}</Button>
{/if}
</div>
</div>
</div>
{#if showAdvancedOptions}

View File

@@ -0,0 +1,44 @@
<script lang="ts">
import UrlFileInput from '$lib/components/form/url-file-input.svelte';
import ImageBox from '$lib/components/image-box.svelte';
import { Button } from '$lib/components/ui/button';
import { Label } from '$lib/components/ui/label';
import { m } from '$lib/paraglide/messages';
import { LucideX } from '@lucide/svelte';
let {
logoDataURL,
clientName,
resetLogo,
onLogoChange
}: {
logoDataURL: string | null;
clientName: string;
resetLogo: () => void;
onLogoChange: (file: File | string | null) => void;
} = $props();
</script>
<Label for="logo">{m.logo()}</Label>
<div class="flex items-end gap-4">
{#if logoDataURL}
<div class="flex items-start gap-4">
<div class="relative shrink-0">
<ImageBox class="size-24" src={logoDataURL} alt={m.name_logo({ name: clientName })} />
<Button
variant="destructive"
size="icon"
onclick={resetLogo}
class="absolute -top-2 -right-2 size-6 rounded-full shadow-md"
>
<LucideX class="size-3" />
</Button>
</div>
</div>
{/if}
<div class="flex flex-col gap-3">
<div class="flex flex-wrap items-center gap-2">
<UrlFileInput label={m.upload_logo()} accept="image/*" onchange={onLogoChange} />
</div>
</div>
</div>

View File

@@ -71,16 +71,21 @@
? item.allowedUserGroupsCount
: m.unrestricted()}</Table.Cell
>
<Table.Cell class="flex justify-end gap-1">
<Table.Cell class="align-middle">
<div class="flex justify-end gap-1">
<Button
href="/settings/admin/oidc-clients/{item.id}"
size="sm"
variant="outline"
aria-label={m.edit()}><LucidePencil class="size-3 " /></Button
>
<Button onclick={() => deleteClient(item)} size="sm" variant="outline" aria-label={m.delete()}
><LucideTrash class="size-3 text-red-500" /></Button
<Button
onclick={() => deleteClient(item)}
size="sm"
variant="outline"
aria-label={m.delete()}><LucideTrash class="size-3 text-red-500" /></Button
>
</div>
</Table.Cell>
{/snippet}
</AdvancedTable>

View File

@@ -38,7 +38,7 @@
<div class="flex gap-3">
<div class="aspect-square h-[56px]">
<ImageBox
class="grow rounded-lg object-contain"
class="h-8 w-8 grow rounded-lg object-contain"
src={client.hasLogo
? cachedOidcClientLogo.getUrl(client.id)
: cachedApplicationLogo.getUrl(isLightMode)}