refactor: fix code smells

This commit is contained in:
Elias Schneider
2025-03-27 17:46:10 +01:00
parent c9e0073b63
commit 5c198c280c
15 changed files with 74 additions and 48 deletions

View File

@@ -16,6 +16,10 @@ linters:
presets:
- bugs
- sql
exclusions:
paths:
- internal/service/test_service.go
run:
timeout: "5m"
tests: true

View File

@@ -92,7 +92,10 @@ func loadKeyPEM(path string) (jwk.Key, error) {
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
key.Set(jwk.KeyIDKey, keyId)
err = key.Set(jwk.KeyIDKey, keyId)
if err != nil {
return nil, fmt.Errorf("failed to set key ID: %w", err)
}
// Populate other required fields
_ = key.Set(jwk.KeyUsageKey, service.KeyUsageSigning)

View File

@@ -101,25 +101,25 @@ func TestLoadKeyPEM(t *testing.T) {
// Check key ID is set
var keyID string
err = key.Get(jwk.KeyIDKey, &keyID)
assert.NoError(t, err)
require.NoError(t, err)
assert.NotEmpty(t, keyID)
// Check algorithm is set
var alg jwa.SignatureAlgorithm
err = key.Get(jwk.AlgorithmKey, &alg)
assert.NoError(t, err)
require.NoError(t, err)
assert.NotEmpty(t, alg)
// Check key usage is set
var keyUsage string
err = key.Get(jwk.KeyUsageKey, &keyUsage)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, service.KeyUsageSigning, keyUsage)
})
t.Run("file not found", func(t *testing.T) {
key, err := loadKeyPEM(filepath.Join(tempDir, "nonexistent.pem"))
assert.Error(t, err)
require.Error(t, err)
assert.Nil(t, key)
})
@@ -129,7 +129,7 @@ func TestLoadKeyPEM(t *testing.T) {
require.NoError(t, err)
key, err := loadKeyPEM(invalidPath)
assert.Error(t, err)
require.Error(t, err)
assert.Nil(t, key)
})
}

View File

@@ -49,19 +49,19 @@ func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
var sortedPaginationRequest utils.SortedPaginationRequest
if err := ctx.ShouldBindQuery(&sortedPaginationRequest); err != nil {
ctx.Error(err)
_ = ctx.Error(err)
return
}
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(userID, sortedPaginationRequest)
if err != nil {
ctx.Error(err)
_ = ctx.Error(err)
return
}
var apiKeysDto []dto.ApiKeyDto
if err := dto.MapStructList(apiKeys, &apiKeysDto); err != nil {
ctx.Error(err)
_ = ctx.Error(err)
return
}
@@ -83,19 +83,19 @@ func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
var input dto.ApiKeyCreateDto
if err := ctx.ShouldBindJSON(&input); err != nil {
ctx.Error(err)
_ = ctx.Error(err)
return
}
apiKey, token, err := c.apiKeyService.CreateApiKey(userID, input)
if err != nil {
ctx.Error(err)
_ = ctx.Error(err)
return
}
var apiKeyDto dto.ApiKeyDto
if err := dto.MapStruct(apiKey, &apiKeyDto); err != nil {
ctx.Error(err)
_ = ctx.Error(err)
return
}
@@ -117,7 +117,7 @@ func (c *ApiKeyController) revokeApiKeyHandler(ctx *gin.Context) {
apiKeyID := ctx.Param("id")
if err := c.apiKeyService.RevokeApiKey(userID, apiKeyID); err != nil {
ctx.Error(err)
_ = ctx.Error(err)
return
}

View File

@@ -219,19 +219,6 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
c.JSON(http.StatusOK, claims)
}
// userInfoHandler godoc (POST method)
// @Summary Get user information (POST method)
// @Description Get user information based on the access token using POST
// @Tags OIDC
// @Accept json
// @Produce json
// @Success 200 {object} object "User claims based on requested scopes"
// @Security OAuth2AccessToken
// @Router /api/oidc/userinfo [post]
func (oc *OidcController) userInfoHandlerPost(c *gin.Context) {
// Implementation is the same as GET
}
// EndSessionHandler godoc
// @Summary End OIDC session
// @Description End user session and handle OIDC logout

View File

@@ -23,6 +23,7 @@ func RegisterDbCleanupJobs(db *gorm.DB) {
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
registerJob(scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
registerJob(scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs)
scheduler.Start()
}

View File

@@ -8,7 +8,7 @@ import (
)
// DateTime custom type for time.Time to store date as unix timestamp for sqlite and as date for postgres
type DateTime time.Time
type DateTime time.Time //nolint:recvcheck
func (date *DateTime) Scan(value interface{}) (err error) {
*date = DateTime(value.(time.Time))

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/google/uuid"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
@@ -16,10 +17,9 @@ import (
"mime/quotedprintable"
"net/textproto"
"os"
"strings"
ttemplate "text/template"
"time"
"github.com/google/uuid"
"strings"
)
type EmailService struct {
@@ -107,7 +107,7 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
domain = hostname
}
}
c.AddHeader("Message-ID", "<" + uuid.New().String() + "@" + domain + ">")
c.AddHeader("Message-ID", "<"+uuid.New().String()+"@"+domain+">")
c.Body(body)
@@ -131,7 +131,7 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
smtpAddress := srv.appConfigService.DbConfig.SmtpHost.Value + ":" + port
tlsConfig := &tls.Config{
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.Value == "true",
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec
ServerName: srv.appConfigService.DbConfig.SmtpHost.Value,
}

View File

@@ -3,6 +3,7 @@ package service
import (
"archive/tar"
"compress/gzip"
"context"
"errors"
"fmt"
"io"
@@ -124,8 +125,15 @@ func (s *GeoLiteService) updateDatabase() error {
log.Println("Updating GeoLite2 City database...")
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
// Download the database tar.gz file nolint:gosec
resp, err := http.Get(downloadUrl)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download database: %w", err)
}
@@ -164,6 +172,9 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
tarReader := tar.NewReader(gzr)
var totalSize int64
const maxTotalSize = 300 * 1024 * 1024 // 300 MB limit for total decompressed size
// Iterate over the files in the tar archive
for {
header, err := tarReader.Next()
@@ -176,6 +187,11 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
// Check if the file is the GeoLite2-City.mmdb file
if header.Typeflag == tar.TypeReg && filepath.Base(header.Name) == "GeoLite2-City.mmdb" {
totalSize += header.Size
if totalSize > maxTotalSize {
return errors.New("total decompressed size exceeds maximum allowed limit")
}
// extract to a temporary file to avoid having a corrupted db in case of write failure.
baseDir := filepath.Dir(common.EnvConfig.GeoLiteDBPath)
tmpFile, err := os.CreateTemp(baseDir, "geolite.*.mmdb.tmp")
@@ -185,7 +201,7 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
tempName := tmpFile.Name()
// Write the file contents directly to the target location
if _, err := io.Copy(tmpFile, tarReader); err != nil {
if _, err := io.Copy(tmpFile, tarReader); err != nil { //nolint:gosec
// if fails to write, then cleanup and throw an error
tmpFile.Close()
os.Remove(tempName)

View File

@@ -38,7 +38,7 @@ func TestJwtService_Init(t *testing.T) {
// Verify the key has been saved to disk as JWK
jwkPath := filepath.Join(tempDir, PrivateKeyFile)
_, err = os.Stat(jwkPath)
assert.NoError(t, err, "JWK file should exist")
require.NoError(t, err, "JWK file should exist")
// Verify the generated key is valid
keyData, err := os.ReadFile(jwkPath)
@@ -229,7 +229,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
// Check the claims
assert.Equal(t, user.ID, claims.Subject, "Token subject should match user ID")
assert.Equal(t, false, claims.IsAdmin, "IsAdmin should be false")
assert.False(t, claims.IsAdmin, "IsAdmin should be false")
assert.Contains(t, claims.Audience, "https://test.example.com", "Audience should contain the app URL")
// Check token expiration time is approximately 60 minutes from now
@@ -263,7 +263,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
require.NoError(t, err, "Failed to verify generated token")
// Check the IsAdmin claim is true
assert.Equal(t, true, claims.IsAdmin, "IsAdmin should be true for admin users")
assert.True(t, claims.IsAdmin, "IsAdmin should be true for admin users")
assert.Equal(t, adminUser.ID, claims.Subject, "Token subject should match admin ID")
})
@@ -404,7 +404,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Verify should fail due to issuer mismatch
_, err = service.VerifyIdToken(tokenString)
assert.Error(t, err, "Verification should fail with incorrect issuer")
require.Error(t, err, "Verification should fail with incorrect issuer")
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
})
}
@@ -492,7 +492,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
// Verify should fail due to expiration
_, err = service.VerifyOauthAccessToken(string(signed))
assert.Error(t, err, "Verification should fail with expired token")
require.Error(t, err, "Verification should fail with expired token")
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
})
@@ -520,7 +520,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
// Verify with the second service should fail due to different keys
_, err = service2.VerifyOauthAccessToken(tokenString)
assert.Error(t, err, "Verification should fail with invalid signature")
require.Error(t, err, "Verification should fail with invalid signature")
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
})
}

View File

@@ -2,6 +2,7 @@ package service
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
@@ -11,6 +12,7 @@ import (
"net/http"
"net/url"
"strings"
"time"
"github.com/go-ldap/ldap/v3"
"github.com/pocket-id/pocket-id/backend/internal/dto"
@@ -36,7 +38,7 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
// Setup LDAP connection
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.Value == "true"
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify}))
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
}
@@ -65,6 +67,7 @@ func (s *LdapService) SyncAll() error {
return nil
}
//nolint:gocognit
func (s *LdapService) SyncGroups() error {
// Setup LDAP connection
client, err := s.createClient()
@@ -150,6 +153,9 @@ func (s *LdapService) SyncGroups() error {
}
} else {
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
@@ -180,6 +186,7 @@ func (s *LdapService) SyncGroups() error {
return nil
}
//nolint:gocognit
func (s *LdapService) SyncUsers() error {
// Setup LDAP connection
client, err := s.createClient()
@@ -296,8 +303,15 @@ func (s *LdapService) SaveProfilePicture(userId string, pictureString string) er
var reader io.Reader
if _, err := url.ParseRequestURI(pictureString); err == nil {
// If the photo is a URL, download it
response, err := http.Get(pictureString)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
response, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download profile picture: %w", err)
}

View File

@@ -209,6 +209,9 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
}
accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
if err != nil {
return "", "", "", 0, err
}
s.db.Delete(&authorizationCodeMetaData)

View File

@@ -47,15 +47,13 @@ func TestFormatAAGUID(t *testing.T) {
func TestGetAuthenticatorName(t *testing.T) {
// Reset the aaguidMap for testing
originalMap := aaguidMap
originalOnce := aaguidMapOnce
defer func() {
aaguidMap = originalMap
aaguidMapOnce = originalOnce
}()
// Inject a test AAGUID map
aaguidMap = map[string]string{
"adce0002-35bc-c60a-648b-0b25f1f05503": "Test Authenticator",
"adce0002-35bc-c60a-648b-m0b25f1f05503": "Test Authenticator",
"00000000-0000-0000-0000-000000000000": "Zero Authenticator",
}
aaguidMapOnce = sync.Once{}

View File

@@ -32,7 +32,7 @@ func CreateProfilePicture(file io.Reader) (io.Reader, error) {
go func() {
err = imaging.Encode(pw, img, imaging.PNG)
if err != nil {
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %v", err))
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %w", err))
return
}
pw.Close()

BIN
backend/main Executable file

Binary file not shown.