feat: support for url based icons (#840)

Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Kyle Mendell
2025-09-29 10:07:55 -05:00
committed by GitHub
parent 47bd5ba1ba
commit 6bdf5fa37a
19 changed files with 650 additions and 442 deletions

View File

@@ -8,9 +8,12 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"mime/multipart"
"net"
"net/http"
"net/url"
"os"
"regexp"
"slices"
@@ -66,6 +69,7 @@ func NewOidcService(
auditLogService *AuditLogService,
customClaimService *CustomClaimService,
webAuthnService *WebAuthnService,
httpClient *http.Client,
) (s *OidcService, err error) {
s = &OidcService{
db: db,
@@ -74,6 +78,7 @@ func NewOidcService(
auditLogService: auditLogService,
customClaimService: customClaimService,
webAuthnService: webAuthnService,
httpClient: httpClient,
}
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
@@ -714,6 +719,11 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
}
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
client := model.OidcClient{
Base: model.Base{
ID: input.ID,
@@ -722,7 +732,7 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
}
updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto)
err := s.db.
err := tx.
WithContext(ctx).
Create(&client).
Error
@@ -733,33 +743,11 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
return model.OidcClient{}, err
}
return client, nil
}
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var client model.OidcClient
err := tx.
WithContext(ctx).
Preload("CreatedBy").
First(&client, "id = ?", clientID).
Error
if err != nil {
return model.OidcClient{}, err
}
updateOIDCClientModelFromDto(&client, &input)
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return model.OidcClient{}, err
if input.LogoURL != nil {
err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL)
if err != nil {
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
}
}
err = tx.Commit().Error
@@ -770,6 +758,36 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
return client, nil
}
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
tx := s.db.Begin()
defer func() { tx.Rollback() }()
var client model.OidcClient
if err := tx.WithContext(ctx).
Preload("CreatedBy").
First(&client, "id = ?", clientID).Error; err != nil {
return model.OidcClient{}, err
}
updateOIDCClientModelFromDto(&client, &input)
if err := tx.WithContext(ctx).Save(&client).Error; err != nil {
return model.OidcClient{}, err
}
if input.LogoURL != nil {
err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL)
if err != nil {
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
}
}
if err := tx.Commit().Error; err != nil {
return model.OidcClient{}, err
}
return client, nil
}
func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientUpdateDto) {
// Base fields
client.Name = input.Name
@@ -883,41 +901,14 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil
}
tx := s.db.Begin()
defer func() {
err = s.updateClientLogoType(ctx, tx, clientID, fileType)
if err != nil {
tx.Rollback()
}()
var client model.OidcClient
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return err
}
if client.ImageType != nil && fileType != *client.ImageType {
oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
if err := os.Remove(oldImagePath); err != nil {
return err
}
}
client.ImageType = &fileType
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return err
}
err = tx.Commit().Error
if err != nil {
return err
}
return nil
return tx.Commit().Error
}
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
@@ -941,6 +932,7 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
oldImageType := *client.ImageType
client.ImageType = nil
err = tx.
WithContext(ctx).
Save(&client).
@@ -1333,7 +1325,7 @@ func (s *OidcService) GetDeviceCodeInfo(ctx context.Context, userCode string, us
Client: dto.OidcClientMetaDataDto{
ID: deviceAuth.Client.ID,
Name: deviceAuth.Client.Name,
HasLogo: deviceAuth.Client.HasLogo,
HasLogo: deviceAuth.Client.HasLogo(),
},
Scope: deviceAuth.Scope,
AuthorizationRequired: !hasAuthorizedClient,
@@ -1468,7 +1460,7 @@ func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID stri
ID: client.ID,
Name: client.Name,
LaunchURL: client.LaunchURL,
HasLogo: client.HasLogo,
HasLogo: client.HasLogo(),
},
LastUsedAt: lastUsedAt,
}
@@ -1889,3 +1881,87 @@ func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID str
return s.IsUserGroupAllowedToAuthorize(user, client), nil
}
func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *gorm.DB, clientID string, raw string) error {
u, err := url.Parse(raw)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
defer cancel()
r := net.Resolver{}
ips, err := r.LookupIPAddr(ctx, u.Hostname())
if err != nil || len(ips) == 0 {
return fmt.Errorf("cannot resolve hostname")
}
// Prevents SSRF by allowing only public IPs
for _, addr := range ips {
if utils.IsPrivateIP(addr.IP) {
return fmt.Errorf("private IP addresses are not allowed")
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil)
if err != nil {
return err
}
req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher")
req.Header.Set("Accept", "image/*")
resp, err := s.httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to fetch logo: %s", resp.Status)
}
const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB
if resp.ContentLength > maxLogoSize {
return fmt.Errorf("logo is too large")
}
// Prefer extension in path if supported
ext := utils.GetFileExtension(u.Path)
if ext == "" || utils.GetImageMimeType(ext) == "" {
// Otherwise, try to detect from content type
ext = utils.GetImageExtensionFromMimeType(resp.Header.Get("Content-Type"))
}
if ext == "" {
return &common.FileTypeNotSupportedError{}
}
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + "." + ext
err = utils.SaveFileStream(io.LimitReader(resp.Body, maxLogoSize+1), imagePath)
if err != nil {
return err
}
if err := s.updateClientLogoType(ctx, tx, clientID, ext); err != nil {
return err
}
return nil
}
func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string) error {
uploadsDir := common.EnvConfig.UploadPath + "/oidc-client-images"
var client model.OidcClient
if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil {
return err
}
if client.ImageType != nil && *client.ImageType != ext {
old := fmt.Sprintf("%s/%s.%s", uploadsDir, client.ID, *client.ImageType)
_ = os.Remove(old)
}
client.ImageType = &ext
return tx.WithContext(ctx).Save(&client).Error
}