mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-16 17:23:24 +03:00
refactor: fix code smells
This commit is contained in:
@@ -16,6 +16,10 @@ linters:
|
||||
presets:
|
||||
- bugs
|
||||
- sql
|
||||
exclusions:
|
||||
paths:
|
||||
- internal/service/test_service.go
|
||||
|
||||
run:
|
||||
timeout: "5m"
|
||||
tests: true
|
||||
|
||||
@@ -92,7 +92,10 @@ func loadKeyPEM(path string) (jwk.Key, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||
}
|
||||
key.Set(jwk.KeyIDKey, keyId)
|
||||
err = key.Set(jwk.KeyIDKey, keyId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set key ID: %w", err)
|
||||
}
|
||||
|
||||
// Populate other required fields
|
||||
_ = key.Set(jwk.KeyUsageKey, service.KeyUsageSigning)
|
||||
|
||||
@@ -101,25 +101,25 @@ func TestLoadKeyPEM(t *testing.T) {
|
||||
// Check key ID is set
|
||||
var keyID string
|
||||
err = key.Get(jwk.KeyIDKey, &keyID)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, keyID)
|
||||
|
||||
// Check algorithm is set
|
||||
var alg jwa.SignatureAlgorithm
|
||||
err = key.Get(jwk.AlgorithmKey, &alg)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, alg)
|
||||
|
||||
// Check key usage is set
|
||||
var keyUsage string
|
||||
err = key.Get(jwk.KeyUsageKey, &keyUsage)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, service.KeyUsageSigning, keyUsage)
|
||||
})
|
||||
|
||||
t.Run("file not found", func(t *testing.T) {
|
||||
key, err := loadKeyPEM(filepath.Join(tempDir, "nonexistent.pem"))
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, key)
|
||||
})
|
||||
|
||||
@@ -129,7 +129,7 @@ func TestLoadKeyPEM(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := loadKeyPEM(invalidPath)
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, key)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -49,19 +49,19 @@ func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
|
||||
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := ctx.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
ctx.Error(err)
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(userID, sortedPaginationRequest)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var apiKeysDto []dto.ApiKeyDto
|
||||
if err := dto.MapStructList(apiKeys, &apiKeysDto); err != nil {
|
||||
ctx.Error(err)
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -83,19 +83,19 @@ func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
|
||||
|
||||
var input dto.ApiKeyCreateDto
|
||||
if err := ctx.ShouldBindJSON(&input); err != nil {
|
||||
ctx.Error(err)
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
apiKey, token, err := c.apiKeyService.CreateApiKey(userID, input)
|
||||
if err != nil {
|
||||
ctx.Error(err)
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var apiKeyDto dto.ApiKeyDto
|
||||
if err := dto.MapStruct(apiKey, &apiKeyDto); err != nil {
|
||||
ctx.Error(err)
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -117,7 +117,7 @@ func (c *ApiKeyController) revokeApiKeyHandler(ctx *gin.Context) {
|
||||
apiKeyID := ctx.Param("id")
|
||||
|
||||
if err := c.apiKeyService.RevokeApiKey(userID, apiKeyID); err != nil {
|
||||
ctx.Error(err)
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -219,19 +219,6 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, claims)
|
||||
}
|
||||
|
||||
// userInfoHandler godoc (POST method)
|
||||
// @Summary Get user information (POST method)
|
||||
// @Description Get user information based on the access token using POST
|
||||
// @Tags OIDC
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} object "User claims based on requested scopes"
|
||||
// @Security OAuth2AccessToken
|
||||
// @Router /api/oidc/userinfo [post]
|
||||
func (oc *OidcController) userInfoHandlerPost(c *gin.Context) {
|
||||
// Implementation is the same as GET
|
||||
}
|
||||
|
||||
// EndSessionHandler godoc
|
||||
// @Summary End OIDC session
|
||||
// @Description End user session and handle OIDC logout
|
||||
|
||||
@@ -23,6 +23,7 @@ func RegisterDbCleanupJobs(db *gorm.DB) {
|
||||
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
|
||||
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
|
||||
registerJob(scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
|
||||
registerJob(scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs)
|
||||
scheduler.Start()
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
// DateTime custom type for time.Time to store date as unix timestamp for sqlite and as date for postgres
|
||||
type DateTime time.Time
|
||||
type DateTime time.Time //nolint:recvcheck
|
||||
|
||||
func (date *DateTime) Scan(value interface{}) (err error) {
|
||||
*date = DateTime(value.(time.Time))
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/emersion/go-sasl"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
|
||||
@@ -16,10 +17,9 @@ import (
|
||||
"mime/quotedprintable"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"strings"
|
||||
ttemplate "text/template"
|
||||
"time"
|
||||
"github.com/google/uuid"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type EmailService struct {
|
||||
@@ -107,7 +107,7 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
|
||||
domain = hostname
|
||||
}
|
||||
}
|
||||
c.AddHeader("Message-ID", "<" + uuid.New().String() + "@" + domain + ">")
|
||||
c.AddHeader("Message-ID", "<"+uuid.New().String()+"@"+domain+">")
|
||||
|
||||
c.Body(body)
|
||||
|
||||
@@ -131,7 +131,7 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
|
||||
smtpAddress := srv.appConfigService.DbConfig.SmtpHost.Value + ":" + port
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.Value == "true",
|
||||
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec
|
||||
ServerName: srv.appConfigService.DbConfig.SmtpHost.Value,
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -124,8 +125,15 @@ func (s *GeoLiteService) updateDatabase() error {
|
||||
log.Println("Updating GeoLite2 City database...")
|
||||
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
|
||||
|
||||
// Download the database tar.gz file nolint:gosec
|
||||
resp, err := http.Get(downloadUrl)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download database: %w", err)
|
||||
}
|
||||
@@ -164,6 +172,9 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
|
||||
|
||||
tarReader := tar.NewReader(gzr)
|
||||
|
||||
var totalSize int64
|
||||
const maxTotalSize = 300 * 1024 * 1024 // 300 MB limit for total decompressed size
|
||||
|
||||
// Iterate over the files in the tar archive
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
@@ -176,6 +187,11 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
|
||||
|
||||
// Check if the file is the GeoLite2-City.mmdb file
|
||||
if header.Typeflag == tar.TypeReg && filepath.Base(header.Name) == "GeoLite2-City.mmdb" {
|
||||
totalSize += header.Size
|
||||
if totalSize > maxTotalSize {
|
||||
return errors.New("total decompressed size exceeds maximum allowed limit")
|
||||
}
|
||||
|
||||
// extract to a temporary file to avoid having a corrupted db in case of write failure.
|
||||
baseDir := filepath.Dir(common.EnvConfig.GeoLiteDBPath)
|
||||
tmpFile, err := os.CreateTemp(baseDir, "geolite.*.mmdb.tmp")
|
||||
@@ -185,7 +201,7 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
|
||||
tempName := tmpFile.Name()
|
||||
|
||||
// Write the file contents directly to the target location
|
||||
if _, err := io.Copy(tmpFile, tarReader); err != nil {
|
||||
if _, err := io.Copy(tmpFile, tarReader); err != nil { //nolint:gosec
|
||||
// if fails to write, then cleanup and throw an error
|
||||
tmpFile.Close()
|
||||
os.Remove(tempName)
|
||||
|
||||
@@ -38,7 +38,7 @@ func TestJwtService_Init(t *testing.T) {
|
||||
// Verify the key has been saved to disk as JWK
|
||||
jwkPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
_, err = os.Stat(jwkPath)
|
||||
assert.NoError(t, err, "JWK file should exist")
|
||||
require.NoError(t, err, "JWK file should exist")
|
||||
|
||||
// Verify the generated key is valid
|
||||
keyData, err := os.ReadFile(jwkPath)
|
||||
@@ -229,7 +229,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
|
||||
// Check the claims
|
||||
assert.Equal(t, user.ID, claims.Subject, "Token subject should match user ID")
|
||||
assert.Equal(t, false, claims.IsAdmin, "IsAdmin should be false")
|
||||
assert.False(t, claims.IsAdmin, "IsAdmin should be false")
|
||||
assert.Contains(t, claims.Audience, "https://test.example.com", "Audience should contain the app URL")
|
||||
|
||||
// Check token expiration time is approximately 60 minutes from now
|
||||
@@ -263,7 +263,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to verify generated token")
|
||||
|
||||
// Check the IsAdmin claim is true
|
||||
assert.Equal(t, true, claims.IsAdmin, "IsAdmin should be true for admin users")
|
||||
assert.True(t, claims.IsAdmin, "IsAdmin should be true for admin users")
|
||||
assert.Equal(t, adminUser.ID, claims.Subject, "Token subject should match admin ID")
|
||||
})
|
||||
|
||||
@@ -404,7 +404,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
|
||||
// Verify should fail due to issuer mismatch
|
||||
_, err = service.VerifyIdToken(tokenString)
|
||||
assert.Error(t, err, "Verification should fail with incorrect issuer")
|
||||
require.Error(t, err, "Verification should fail with incorrect issuer")
|
||||
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
||||
})
|
||||
}
|
||||
@@ -492,7 +492,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
|
||||
// Verify should fail due to expiration
|
||||
_, err = service.VerifyOauthAccessToken(string(signed))
|
||||
assert.Error(t, err, "Verification should fail with expired token")
|
||||
require.Error(t, err, "Verification should fail with expired token")
|
||||
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
||||
})
|
||||
|
||||
@@ -520,7 +520,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
|
||||
// Verify with the second service should fail due to different keys
|
||||
_, err = service2.VerifyOauthAccessToken(tokenString)
|
||||
assert.Error(t, err, "Verification should fail with invalid signature")
|
||||
require.Error(t, err, "Verification should fail with invalid signature")
|
||||
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
@@ -36,7 +38,7 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
// Setup LDAP connection
|
||||
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
|
||||
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.Value == "true"
|
||||
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify}))
|
||||
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
|
||||
}
|
||||
@@ -65,6 +67,7 @@ func (s *LdapService) SyncAll() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncGroups() error {
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
@@ -150,6 +153,9 @@ func (s *LdapService) SyncGroups() error {
|
||||
}
|
||||
} else {
|
||||
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
}
|
||||
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
@@ -180,6 +186,7 @@ func (s *LdapService) SyncGroups() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncUsers() error {
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
@@ -296,8 +303,15 @@ func (s *LdapService) SaveProfilePicture(userId string, pictureString string) er
|
||||
var reader io.Reader
|
||||
|
||||
if _, err := url.ParseRequestURI(pictureString); err == nil {
|
||||
// If the photo is a URL, download it
|
||||
response, err := http.Get(pictureString)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
response, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download profile picture: %w", err)
|
||||
}
|
||||
|
||||
@@ -209,6 +209,9 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
|
||||
}
|
||||
|
||||
accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
s.db.Delete(&authorizationCodeMetaData)
|
||||
|
||||
|
||||
@@ -47,15 +47,13 @@ func TestFormatAAGUID(t *testing.T) {
|
||||
func TestGetAuthenticatorName(t *testing.T) {
|
||||
// Reset the aaguidMap for testing
|
||||
originalMap := aaguidMap
|
||||
originalOnce := aaguidMapOnce
|
||||
defer func() {
|
||||
aaguidMap = originalMap
|
||||
aaguidMapOnce = originalOnce
|
||||
}()
|
||||
|
||||
// Inject a test AAGUID map
|
||||
aaguidMap = map[string]string{
|
||||
"adce0002-35bc-c60a-648b-0b25f1f05503": "Test Authenticator",
|
||||
"adce0002-35bc-c60a-648b-m0b25f1f05503": "Test Authenticator",
|
||||
"00000000-0000-0000-0000-000000000000": "Zero Authenticator",
|
||||
}
|
||||
aaguidMapOnce = sync.Once{}
|
||||
|
||||
@@ -32,7 +32,7 @@ func CreateProfilePicture(file io.Reader) (io.Reader, error) {
|
||||
go func() {
|
||||
err = imaging.Encode(pw, img, imaging.PNG)
|
||||
if err != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %v", err))
|
||||
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %w", err))
|
||||
return
|
||||
}
|
||||
pw.Close()
|
||||
|
||||
BIN
backend/main
Executable file
BIN
backend/main
Executable file
Binary file not shown.
Reference in New Issue
Block a user