mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-17 01:11:38 +03:00
refactor: fix code smells
This commit is contained in:
@@ -16,6 +16,10 @@ linters:
|
|||||||
presets:
|
presets:
|
||||||
- bugs
|
- bugs
|
||||||
- sql
|
- sql
|
||||||
|
exclusions:
|
||||||
|
paths:
|
||||||
|
- internal/service/test_service.go
|
||||||
|
|
||||||
run:
|
run:
|
||||||
timeout: "5m"
|
timeout: "5m"
|
||||||
tests: true
|
tests: true
|
||||||
|
|||||||
@@ -92,7 +92,10 @@ func loadKeyPEM(path string) (jwk.Key, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||||
}
|
}
|
||||||
key.Set(jwk.KeyIDKey, keyId)
|
err = key.Set(jwk.KeyIDKey, keyId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set key ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Populate other required fields
|
// Populate other required fields
|
||||||
_ = key.Set(jwk.KeyUsageKey, service.KeyUsageSigning)
|
_ = key.Set(jwk.KeyUsageKey, service.KeyUsageSigning)
|
||||||
|
|||||||
@@ -101,25 +101,25 @@ func TestLoadKeyPEM(t *testing.T) {
|
|||||||
// Check key ID is set
|
// Check key ID is set
|
||||||
var keyID string
|
var keyID string
|
||||||
err = key.Get(jwk.KeyIDKey, &keyID)
|
err = key.Get(jwk.KeyIDKey, &keyID)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, keyID)
|
assert.NotEmpty(t, keyID)
|
||||||
|
|
||||||
// Check algorithm is set
|
// Check algorithm is set
|
||||||
var alg jwa.SignatureAlgorithm
|
var alg jwa.SignatureAlgorithm
|
||||||
err = key.Get(jwk.AlgorithmKey, &alg)
|
err = key.Get(jwk.AlgorithmKey, &alg)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, alg)
|
assert.NotEmpty(t, alg)
|
||||||
|
|
||||||
// Check key usage is set
|
// Check key usage is set
|
||||||
var keyUsage string
|
var keyUsage string
|
||||||
err = key.Get(jwk.KeyUsageKey, &keyUsage)
|
err = key.Get(jwk.KeyUsageKey, &keyUsage)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, service.KeyUsageSigning, keyUsage)
|
assert.Equal(t, service.KeyUsageSigning, keyUsage)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("file not found", func(t *testing.T) {
|
t.Run("file not found", func(t *testing.T) {
|
||||||
key, err := loadKeyPEM(filepath.Join(tempDir, "nonexistent.pem"))
|
key, err := loadKeyPEM(filepath.Join(tempDir, "nonexistent.pem"))
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Nil(t, key)
|
assert.Nil(t, key)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ func TestLoadKeyPEM(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
key, err := loadKeyPEM(invalidPath)
|
key, err := loadKeyPEM(invalidPath)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Nil(t, key)
|
assert.Nil(t, key)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,19 +49,19 @@ func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
|
|||||||
|
|
||||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||||
if err := ctx.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
if err := ctx.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||||
ctx.Error(err)
|
_ = ctx.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(userID, sortedPaginationRequest)
|
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(userID, sortedPaginationRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Error(err)
|
_ = ctx.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiKeysDto []dto.ApiKeyDto
|
var apiKeysDto []dto.ApiKeyDto
|
||||||
if err := dto.MapStructList(apiKeys, &apiKeysDto); err != nil {
|
if err := dto.MapStructList(apiKeys, &apiKeysDto); err != nil {
|
||||||
ctx.Error(err)
|
_ = ctx.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,19 +83,19 @@ func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
|
|||||||
|
|
||||||
var input dto.ApiKeyCreateDto
|
var input dto.ApiKeyCreateDto
|
||||||
if err := ctx.ShouldBindJSON(&input); err != nil {
|
if err := ctx.ShouldBindJSON(&input); err != nil {
|
||||||
ctx.Error(err)
|
_ = ctx.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
apiKey, token, err := c.apiKeyService.CreateApiKey(userID, input)
|
apiKey, token, err := c.apiKeyService.CreateApiKey(userID, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Error(err)
|
_ = ctx.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiKeyDto dto.ApiKeyDto
|
var apiKeyDto dto.ApiKeyDto
|
||||||
if err := dto.MapStruct(apiKey, &apiKeyDto); err != nil {
|
if err := dto.MapStruct(apiKey, &apiKeyDto); err != nil {
|
||||||
ctx.Error(err)
|
_ = ctx.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -117,7 +117,7 @@ func (c *ApiKeyController) revokeApiKeyHandler(ctx *gin.Context) {
|
|||||||
apiKeyID := ctx.Param("id")
|
apiKeyID := ctx.Param("id")
|
||||||
|
|
||||||
if err := c.apiKeyService.RevokeApiKey(userID, apiKeyID); err != nil {
|
if err := c.apiKeyService.RevokeApiKey(userID, apiKeyID); err != nil {
|
||||||
ctx.Error(err)
|
_ = ctx.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -219,19 +219,6 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, claims)
|
c.JSON(http.StatusOK, claims)
|
||||||
}
|
}
|
||||||
|
|
||||||
// userInfoHandler godoc (POST method)
|
|
||||||
// @Summary Get user information (POST method)
|
|
||||||
// @Description Get user information based on the access token using POST
|
|
||||||
// @Tags OIDC
|
|
||||||
// @Accept json
|
|
||||||
// @Produce json
|
|
||||||
// @Success 200 {object} object "User claims based on requested scopes"
|
|
||||||
// @Security OAuth2AccessToken
|
|
||||||
// @Router /api/oidc/userinfo [post]
|
|
||||||
func (oc *OidcController) userInfoHandlerPost(c *gin.Context) {
|
|
||||||
// Implementation is the same as GET
|
|
||||||
}
|
|
||||||
|
|
||||||
// EndSessionHandler godoc
|
// EndSessionHandler godoc
|
||||||
// @Summary End OIDC session
|
// @Summary End OIDC session
|
||||||
// @Description End user session and handle OIDC logout
|
// @Description End user session and handle OIDC logout
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ func RegisterDbCleanupJobs(db *gorm.DB) {
|
|||||||
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
|
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
|
||||||
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
|
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
|
||||||
registerJob(scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
|
registerJob(scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
|
||||||
|
registerJob(scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs)
|
||||||
scheduler.Start()
|
scheduler.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// DateTime custom type for time.Time to store date as unix timestamp for sqlite and as date for postgres
|
// DateTime custom type for time.Time to store date as unix timestamp for sqlite and as date for postgres
|
||||||
type DateTime time.Time
|
type DateTime time.Time //nolint:recvcheck
|
||||||
|
|
||||||
func (date *DateTime) Scan(value interface{}) (err error) {
|
func (date *DateTime) Scan(value interface{}) (err error) {
|
||||||
*date = DateTime(value.(time.Time))
|
*date = DateTime(value.(time.Time))
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/emersion/go-sasl"
|
"github.com/emersion/go-sasl"
|
||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
|
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
|
||||||
@@ -16,10 +17,9 @@ import (
|
|||||||
"mime/quotedprintable"
|
"mime/quotedprintable"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
ttemplate "text/template"
|
ttemplate "text/template"
|
||||||
"time"
|
"time"
|
||||||
"github.com/google/uuid"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type EmailService struct {
|
type EmailService struct {
|
||||||
@@ -107,7 +107,7 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
|
|||||||
domain = hostname
|
domain = hostname
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.AddHeader("Message-ID", "<" + uuid.New().String() + "@" + domain + ">")
|
c.AddHeader("Message-ID", "<"+uuid.New().String()+"@"+domain+">")
|
||||||
|
|
||||||
c.Body(body)
|
c.Body(body)
|
||||||
|
|
||||||
@@ -131,7 +131,7 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
|
|||||||
smtpAddress := srv.appConfigService.DbConfig.SmtpHost.Value + ":" + port
|
smtpAddress := srv.appConfigService.DbConfig.SmtpHost.Value + ":" + port
|
||||||
|
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.Value == "true",
|
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec
|
||||||
ServerName: srv.appConfigService.DbConfig.SmtpHost.Value,
|
ServerName: srv.appConfigService.DbConfig.SmtpHost.Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"archive/tar"
|
"archive/tar"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -124,8 +125,15 @@ func (s *GeoLiteService) updateDatabase() error {
|
|||||||
log.Println("Updating GeoLite2 City database...")
|
log.Println("Updating GeoLite2 City database...")
|
||||||
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
|
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
|
||||||
|
|
||||||
// Download the database tar.gz file nolint:gosec
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||||
resp, err := http.Get(downloadUrl)
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to download database: %w", err)
|
return fmt.Errorf("failed to download database: %w", err)
|
||||||
}
|
}
|
||||||
@@ -164,6 +172,9 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
|
|||||||
|
|
||||||
tarReader := tar.NewReader(gzr)
|
tarReader := tar.NewReader(gzr)
|
||||||
|
|
||||||
|
var totalSize int64
|
||||||
|
const maxTotalSize = 300 * 1024 * 1024 // 300 MB limit for total decompressed size
|
||||||
|
|
||||||
// Iterate over the files in the tar archive
|
// Iterate over the files in the tar archive
|
||||||
for {
|
for {
|
||||||
header, err := tarReader.Next()
|
header, err := tarReader.Next()
|
||||||
@@ -176,6 +187,11 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
|
|||||||
|
|
||||||
// Check if the file is the GeoLite2-City.mmdb file
|
// Check if the file is the GeoLite2-City.mmdb file
|
||||||
if header.Typeflag == tar.TypeReg && filepath.Base(header.Name) == "GeoLite2-City.mmdb" {
|
if header.Typeflag == tar.TypeReg && filepath.Base(header.Name) == "GeoLite2-City.mmdb" {
|
||||||
|
totalSize += header.Size
|
||||||
|
if totalSize > maxTotalSize {
|
||||||
|
return errors.New("total decompressed size exceeds maximum allowed limit")
|
||||||
|
}
|
||||||
|
|
||||||
// extract to a temporary file to avoid having a corrupted db in case of write failure.
|
// extract to a temporary file to avoid having a corrupted db in case of write failure.
|
||||||
baseDir := filepath.Dir(common.EnvConfig.GeoLiteDBPath)
|
baseDir := filepath.Dir(common.EnvConfig.GeoLiteDBPath)
|
||||||
tmpFile, err := os.CreateTemp(baseDir, "geolite.*.mmdb.tmp")
|
tmpFile, err := os.CreateTemp(baseDir, "geolite.*.mmdb.tmp")
|
||||||
@@ -185,7 +201,7 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
|
|||||||
tempName := tmpFile.Name()
|
tempName := tmpFile.Name()
|
||||||
|
|
||||||
// Write the file contents directly to the target location
|
// Write the file contents directly to the target location
|
||||||
if _, err := io.Copy(tmpFile, tarReader); err != nil {
|
if _, err := io.Copy(tmpFile, tarReader); err != nil { //nolint:gosec
|
||||||
// if fails to write, then cleanup and throw an error
|
// if fails to write, then cleanup and throw an error
|
||||||
tmpFile.Close()
|
tmpFile.Close()
|
||||||
os.Remove(tempName)
|
os.Remove(tempName)
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func TestJwtService_Init(t *testing.T) {
|
|||||||
// Verify the key has been saved to disk as JWK
|
// Verify the key has been saved to disk as JWK
|
||||||
jwkPath := filepath.Join(tempDir, PrivateKeyFile)
|
jwkPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||||
_, err = os.Stat(jwkPath)
|
_, err = os.Stat(jwkPath)
|
||||||
assert.NoError(t, err, "JWK file should exist")
|
require.NoError(t, err, "JWK file should exist")
|
||||||
|
|
||||||
// Verify the generated key is valid
|
// Verify the generated key is valid
|
||||||
keyData, err := os.ReadFile(jwkPath)
|
keyData, err := os.ReadFile(jwkPath)
|
||||||
@@ -229,7 +229,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
// Check the claims
|
// Check the claims
|
||||||
assert.Equal(t, user.ID, claims.Subject, "Token subject should match user ID")
|
assert.Equal(t, user.ID, claims.Subject, "Token subject should match user ID")
|
||||||
assert.Equal(t, false, claims.IsAdmin, "IsAdmin should be false")
|
assert.False(t, claims.IsAdmin, "IsAdmin should be false")
|
||||||
assert.Contains(t, claims.Audience, "https://test.example.com", "Audience should contain the app URL")
|
assert.Contains(t, claims.Audience, "https://test.example.com", "Audience should contain the app URL")
|
||||||
|
|
||||||
// Check token expiration time is approximately 60 minutes from now
|
// Check token expiration time is approximately 60 minutes from now
|
||||||
@@ -263,7 +263,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "Failed to verify generated token")
|
require.NoError(t, err, "Failed to verify generated token")
|
||||||
|
|
||||||
// Check the IsAdmin claim is true
|
// Check the IsAdmin claim is true
|
||||||
assert.Equal(t, true, claims.IsAdmin, "IsAdmin should be true for admin users")
|
assert.True(t, claims.IsAdmin, "IsAdmin should be true for admin users")
|
||||||
assert.Equal(t, adminUser.ID, claims.Subject, "Token subject should match admin ID")
|
assert.Equal(t, adminUser.ID, claims.Subject, "Token subject should match admin ID")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -404,7 +404,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
|
|
||||||
// Verify should fail due to issuer mismatch
|
// Verify should fail due to issuer mismatch
|
||||||
_, err = service.VerifyIdToken(tokenString)
|
_, err = service.VerifyIdToken(tokenString)
|
||||||
assert.Error(t, err, "Verification should fail with incorrect issuer")
|
require.Error(t, err, "Verification should fail with incorrect issuer")
|
||||||
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -492,7 +492,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
// Verify should fail due to expiration
|
// Verify should fail due to expiration
|
||||||
_, err = service.VerifyOauthAccessToken(string(signed))
|
_, err = service.VerifyOauthAccessToken(string(signed))
|
||||||
assert.Error(t, err, "Verification should fail with expired token")
|
require.Error(t, err, "Verification should fail with expired token")
|
||||||
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -520,7 +520,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
// Verify with the second service should fail due to different keys
|
// Verify with the second service should fail due to different keys
|
||||||
_, err = service2.VerifyOauthAccessToken(tokenString)
|
_, err = service2.VerifyOauthAccessToken(tokenString)
|
||||||
assert.Error(t, err, "Verification should fail with invalid signature")
|
require.Error(t, err, "Verification should fail with invalid signature")
|
||||||
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-ldap/ldap/v3"
|
"github.com/go-ldap/ldap/v3"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||||
@@ -36,7 +38,7 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
|||||||
// Setup LDAP connection
|
// Setup LDAP connection
|
||||||
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
|
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
|
||||||
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.Value == "true"
|
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.Value == "true"
|
||||||
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify}))
|
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
|
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
|
||||||
}
|
}
|
||||||
@@ -65,6 +67,7 @@ func (s *LdapService) SyncAll() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gocognit
|
||||||
func (s *LdapService) SyncGroups() error {
|
func (s *LdapService) SyncGroups() error {
|
||||||
// Setup LDAP connection
|
// Setup LDAP connection
|
||||||
client, err := s.createClient()
|
client, err := s.createClient()
|
||||||
@@ -150,6 +153,9 @@ func (s *LdapService) SyncGroups() error {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
|
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||||
|
}
|
||||||
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
|
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||||
@@ -180,6 +186,7 @@ func (s *LdapService) SyncGroups() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gocognit
|
||||||
func (s *LdapService) SyncUsers() error {
|
func (s *LdapService) SyncUsers() error {
|
||||||
// Setup LDAP connection
|
// Setup LDAP connection
|
||||||
client, err := s.createClient()
|
client, err := s.createClient()
|
||||||
@@ -296,8 +303,15 @@ func (s *LdapService) SaveProfilePicture(userId string, pictureString string) er
|
|||||||
var reader io.Reader
|
var reader io.Reader
|
||||||
|
|
||||||
if _, err := url.ParseRequestURI(pictureString); err == nil {
|
if _, err := url.ParseRequestURI(pictureString); err == nil {
|
||||||
// If the photo is a URL, download it
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
response, err := http.Get(pictureString)
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to download profile picture: %w", err)
|
return fmt.Errorf("failed to download profile picture: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -209,6 +209,9 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
|
|||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
|
accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
s.db.Delete(&authorizationCodeMetaData)
|
s.db.Delete(&authorizationCodeMetaData)
|
||||||
|
|
||||||
|
|||||||
@@ -47,16 +47,14 @@ func TestFormatAAGUID(t *testing.T) {
|
|||||||
func TestGetAuthenticatorName(t *testing.T) {
|
func TestGetAuthenticatorName(t *testing.T) {
|
||||||
// Reset the aaguidMap for testing
|
// Reset the aaguidMap for testing
|
||||||
originalMap := aaguidMap
|
originalMap := aaguidMap
|
||||||
originalOnce := aaguidMapOnce
|
|
||||||
defer func() {
|
defer func() {
|
||||||
aaguidMap = originalMap
|
aaguidMap = originalMap
|
||||||
aaguidMapOnce = originalOnce
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Inject a test AAGUID map
|
// Inject a test AAGUID map
|
||||||
aaguidMap = map[string]string{
|
aaguidMap = map[string]string{
|
||||||
"adce0002-35bc-c60a-648b-0b25f1f05503": "Test Authenticator",
|
"adce0002-35bc-c60a-648b-m0b25f1f05503": "Test Authenticator",
|
||||||
"00000000-0000-0000-0000-000000000000": "Zero Authenticator",
|
"00000000-0000-0000-0000-000000000000": "Zero Authenticator",
|
||||||
}
|
}
|
||||||
aaguidMapOnce = sync.Once{}
|
aaguidMapOnce = sync.Once{}
|
||||||
aaguidMapOnce.Do(func() {}) // Mark as done to avoid loading from file
|
aaguidMapOnce.Do(func() {}) // Mark as done to avoid loading from file
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ func CreateProfilePicture(file io.Reader) (io.Reader, error) {
|
|||||||
go func() {
|
go func() {
|
||||||
err = imaging.Encode(pw, img, imaging.PNG)
|
err = imaging.Encode(pw, img, imaging.PNG)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %v", err))
|
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %w", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pw.Close()
|
pw.Close()
|
||||||
|
|||||||
BIN
backend/main
Executable file
BIN
backend/main
Executable file
Binary file not shown.
Reference in New Issue
Block a user