feat: add CSP header (#908)

Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Elias Schneider
2025-09-07 20:45:06 +02:00
committed by GitHub
parent 74b39e16f9
commit 6215e1ac01
9 changed files with 343 additions and 102 deletions

View File

@@ -3,8 +3,10 @@
package frontend package frontend
import ( import (
"bytes"
"embed" "embed"
"fmt" "fmt"
"io"
"io/fs" "io/fs"
"net/http" "net/http"
"os" "os"
@@ -12,11 +14,55 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
) )
//go:embed all:dist/* //go:embed all:dist/*
var frontendFS embed.FS var frontendFS embed.FS
// This function, created by the init() method, writes to "w" the index.html page, populating the nonce
var writeIndexFn func(w io.Writer, nonce string) error
func init() {
const scriptTag = "<script>"
// Read the index.html from the bundle
index, iErr := fs.ReadFile(frontendFS, "dist/index.html")
if iErr != nil {
panic(fmt.Errorf("failed to read index.html: %w", iErr))
}
// Get the position of the first <script> tag
idx := bytes.Index(index, []byte(scriptTag))
// Create writeIndexFn, which adds the CSP tag to the script tag if needed
writeIndexFn = func(w io.Writer, nonce string) (err error) {
// If there's no nonce, write the index as-is
if nonce == "" {
_, err = w.Write(index)
return err
}
// We have a nonce, so first write the index until the <script> tag
// Then we write the modified script tag
// Finally, the rest of the index
_, err = w.Write(index[0:idx])
if err != nil {
return err
}
_, err = w.Write([]byte(`<script nonce="` + nonce + `">`))
if err != nil {
return err
}
_, err = w.Write(index[(idx + len(scriptTag)):])
if err != nil {
return err
}
return nil
}
}
func RegisterFrontend(router *gin.Engine) error { func RegisterFrontend(router *gin.Engine) error {
distFS, err := fs.Sub(frontendFS, "dist") distFS, err := fs.Sub(frontendFS, "dist")
if err != nil { if err != nil {
@@ -27,13 +73,39 @@ func RegisterFrontend(router *gin.Engine) error {
fileServer := NewFileServerWithCaching(http.FS(distFS), int(cacheMaxAge.Seconds())) fileServer := NewFileServerWithCaching(http.FS(distFS), int(cacheMaxAge.Seconds()))
router.NoRoute(func(c *gin.Context) { router.NoRoute(func(c *gin.Context) {
// Try to serve the requested file
path := strings.TrimPrefix(c.Request.URL.Path, "/") path := strings.TrimPrefix(c.Request.URL.Path, "/")
if _, err := fs.Stat(distFS, path); os.IsNotExist(err) {
// File doesn't exist, serve index.html instead if strings.HasPrefix(path, "api/") {
c.Request.URL.Path = "/" c.JSON(http.StatusNotFound, gin.H{"error": "API endpoint not found"})
return
} }
// If path is / or does not exist, serve index.html
if path == "" {
path = "index.html"
} else if _, err := fs.Stat(distFS, path); os.IsNotExist(err) {
path = "index.html"
}
if path == "index.html" {
nonce := middleware.GetCSPNonce(c)
// Do not cache the HTML shell, as it embeds a per-request nonce
c.Header("Content-Type", "text/html; charset=utf-8")
c.Header("Cache-Control", "no-store")
c.Status(http.StatusOK)
err = writeIndexFn(c.Writer, nonce)
if err != nil {
_ = c.Error(fmt.Errorf("failed to write index.html file: %w", err))
return
}
return
}
// Serve other static assets with caching
c.Request.URL.Path = "/" + path
fileServer.ServeHTTP(c.Writer, c.Request) fileServer.ServeHTTP(c.Writer, c.Request)
}) })

View File

@@ -86,6 +86,7 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
// Setup global middleware // Setup global middleware
r.Use(middleware.NewCorsMiddleware().Add()) r.Use(middleware.NewCorsMiddleware().Add())
r.Use(middleware.NewCspMiddleware().Add())
r.Use(middleware.NewErrorHandlerMiddleware().Add()) r.Use(middleware.NewErrorHandlerMiddleware().Add())
err := frontend.RegisterFrontend(r) err := frontend.RegisterFrontend(r)
@@ -109,6 +110,7 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
controller.NewAuditLogController(apiGroup, svc.auditLogService, authMiddleware) controller.NewAuditLogController(apiGroup, svc.auditLogService, authMiddleware)
controller.NewUserGroupController(apiGroup, authMiddleware, svc.userGroupService) controller.NewUserGroupController(apiGroup, authMiddleware, svc.userGroupService)
controller.NewCustomClaimController(apiGroup, authMiddleware, svc.customClaimService) controller.NewCustomClaimController(apiGroup, authMiddleware, svc.customClaimService)
controller.NewVersionController(apiGroup, svc.versionService)
// Add test controller in non-production environments // Add test controller in non-production environments
if common.EnvConfig.AppEnv != "production" { if common.EnvConfig.AppEnv != "production" {

View File

@@ -23,6 +23,7 @@ type services struct {
userGroupService *service.UserGroupService userGroupService *service.UserGroupService
ldapService *service.LdapService ldapService *service.LdapService
apiKeyService *service.ApiKeyService apiKeyService *service.ApiKeyService
versionService *service.VersionService
} }
// Initializes all services // Initializes all services
@@ -62,5 +63,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (sv
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService) svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService) svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
svc.versionService = service.NewVersionService(httpClient)
return svc, nil return svc, nil
} }

View File

@@ -0,0 +1,40 @@
package controller
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
// NewVersionController registers version-related routes.
func NewVersionController(group *gin.RouterGroup, versionService *service.VersionService) {
vc := &VersionController{versionService: versionService}
group.GET("/version/latest", vc.getLatestVersionHandler)
}
type VersionController struct {
versionService *service.VersionService
}
// getLatestVersionHandler godoc
// @Summary Get latest available version of Pocket ID
// @Tags Version
// @Produce json
// @Success 200 {object} map[string]string "Latest version information"
// @Router /api/version/latest [get]
func (vc *VersionController) getLatestVersionHandler(c *gin.Context) {
tag, err := vc.versionService.GetLatestVersion(c.Request.Context())
if err != nil {
_ = c.Error(err)
return
}
utils.SetCacheControlHeader(c, 5*time.Minute, 15*time.Minute)
c.JSON(http.StatusOK, gin.H{
"latestVersion": tag,
})
}

View File

@@ -0,0 +1,53 @@
package middleware
import (
"crypto/rand"
"encoding/base64"
"github.com/gin-gonic/gin"
)
// CspMiddleware sets a Content Security Policy header and, when possible,
// includes a per-request nonce for inline scripts.
type CspMiddleware struct{}
func NewCspMiddleware() *CspMiddleware { return &CspMiddleware{} }
// GetCSPNonce returns the CSP nonce generated for this request, if any.
func GetCSPNonce(c *gin.Context) string {
if v, ok := c.Get("csp_nonce"); ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func (m *CspMiddleware) Add() gin.HandlerFunc {
return func(c *gin.Context) {
// Generate a random base64 nonce for this request
nonce := generateNonce()
c.Set("csp_nonce", nonce)
csp := "default-src 'self'; " +
"base-uri 'self'; " +
"object-src 'none'; " +
"frame-ancestors 'none'; " +
"form-action 'self'; " +
"img-src 'self' data: blob:; " +
"font-src 'self'; " +
"style-src 'self' 'unsafe-inline'; " +
"script-src 'self' 'nonce-" + nonce + "'"
c.Writer.Header().Set("Content-Security-Policy", csp)
c.Next()
}
}
func generateNonce() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "" // if generation fails, return empty; policy will omit nonce
}
return base64.RawURLEncoding.EncodeToString(b)
}

View File

@@ -0,0 +1,74 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
const (
versionTTL = 15 * time.Minute
versionCheckURL = "https://api.github.com/repos/pocket-id/pocket-id/releases/latest"
)
type VersionService struct {
httpClient *http.Client
cache *utils.Cache[string]
}
func NewVersionService(httpClient *http.Client) *VersionService {
return &VersionService{
httpClient: httpClient,
cache: utils.New[string](versionTTL),
}
}
func (s *VersionService) GetLatestVersion(ctx context.Context) (string, error) {
version, err := s.cache.GetOrFetch(ctx, func(ctx context.Context) (string, error) {
reqCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, versionCheckURL, nil)
if err != nil {
return "", fmt.Errorf("create GitHub request: %w", err)
}
resp, err := s.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("get latest tag: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("GitHub API returned status %d", resp.StatusCode)
}
var payload struct {
TagName string `json:"tag_name"`
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return "", fmt.Errorf("decode payload: %w", err)
}
if payload.TagName == "" {
return "", fmt.Errorf("GitHub API returned empty tag name")
}
return strings.TrimPrefix(payload.TagName, "v"), nil
})
var staleErr *utils.ErrStale
if errors.As(err, &staleErr) {
slog.Warn("Failed to fetch latest version, returning stale cache", "error", staleErr.Err)
return version, nil
}
return version, err
}

View File

@@ -0,0 +1,78 @@
package utils
import (
"context"
"sync/atomic"
"time"
"golang.org/x/sync/singleflight"
)
type CacheEntry[T any] struct {
Value T
FetchedAt time.Time
}
type ErrStale struct {
Err error
}
func (e *ErrStale) Error() string { return "returned stale cache: " + e.Err.Error() }
func (e *ErrStale) Unwrap() error { return e.Err }
type Cache[T any] struct {
ttl time.Duration
entry atomic.Pointer[CacheEntry[T]]
sf singleflight.Group
}
func New[T any](ttl time.Duration) *Cache[T] {
return &Cache[T]{ttl: ttl}
}
// Get returns the cached value if it's still fresh.
func (c *Cache[T]) Get() (T, bool) {
entry := c.entry.Load()
if entry == nil {
var zero T
return zero, false
}
if time.Since(entry.FetchedAt) < c.ttl {
return entry.Value, true
}
var zero T
return zero, false
}
// GetOrFetch returns the cached value if it's still fresh, otherwise calls fetch to get a new value.
func (c *Cache[T]) GetOrFetch(ctx context.Context, fetch func(context.Context) (T, error)) (T, error) {
// If fresh, serve immediately
if v, ok := c.Get(); ok {
return v, nil
}
// Fetch with singleflight to prevent multiple concurrent fetches
vAny, err, _ := c.sf.Do("singleton", func() (any, error) {
if v2, ok := c.Get(); ok {
return v2, nil
}
val, fetchErr := fetch(ctx)
if fetchErr != nil {
return nil, fetchErr
}
c.entry.Store(&CacheEntry[T]{Value: val, FetchedAt: time.Now()})
return val, nil
})
if err == nil {
return vAny.(T), nil
}
// Fetch failed. Return stale if possible.
if e := c.entry.Load(); e != nil {
return e.Value, &ErrStale{Err: err}
}
var zero T
return zero, err
}

View File

@@ -1,109 +1,21 @@
import { version as currentVersion } from '$app/environment'; import { version as currentVersion } from '$app/environment';
import axios from 'axios'; import axios from 'axios';
const VERSION_CACHE_KEY = 'version_cache';
const CACHE_DURATION = 2 * 60 * 60 * 1000; // 2 hours
async function getNewestVersion() { async function getNewestVersion() {
const cachedData = await getVersionFromCache(); const response = await axios
.get('/api/version/latest', {
timeout: 2000
})
.then((res) => res.data);
// If we have valid cached data, return it return response.latestVersion;
if (cachedData) {
return cachedData;
}
// Otherwise fetch from API
try {
const response = await axios
.get('https://api.github.com/repos/pocket-id/pocket-id/releases/latest', {
timeout: 2000
})
.then((res) => res.data);
console.log('Fetched newest version:', response);
const newestVersion = response.tag_name.replace('v', '');
// Cache the result
cacheVersion(newestVersion);
return newestVersion;
} catch (error) {
console.error('Failed to fetch newest version:', error);
// If fetch fails but we have an expired cache, return that as fallback
const cache = getCacheObject();
return cache?.newestVersion || currentVersion;
}
} }
function getCurrentVersion() { function getCurrentVersion() {
return currentVersion; return currentVersion;
} }
async function isUpToDate() {
const newestVersion = await getNewestVersion();
const currentVersion = getCurrentVersion();
// If the current version changed, invalidate the cache
const cache = getCacheObject();
if (cache?.lastCurrentVersion && currentVersion !== cache.lastCurrentVersion) {
invalidateCache();
}
return newestVersion === currentVersion;
}
// Helper methods for caching
function getCacheObject() {
const cacheJson = localStorage.getItem(VERSION_CACHE_KEY);
if (!cacheJson) return null;
try {
return JSON.parse(cacheJson);
} catch (e) {
console.error('Failed to parse cache:', e);
return null;
}
}
async function getVersionFromCache() {
const cache = getCacheObject();
if (!cache || !cache.newestVersion || !cache.timestamp) {
return null;
}
const now = Date.now();
// Check if cache is still valid
if (now - cache.timestamp > CACHE_DURATION) {
invalidateCache();
return null;
}
// Check if current version matches what it was when we cached
if (cache.lastCurrentVersion && cache.lastCurrentVersion !== currentVersion) {
invalidateCache();
return null;
}
return cache.newestVersion;
}
async function cacheVersion(version: string) {
const cacheObject = {
newestVersion: version,
timestamp: Date.now(),
lastCurrentVersion: currentVersion
};
localStorage.setItem(VERSION_CACHE_KEY, JSON.stringify(cacheObject));
}
async function invalidateCache() {
localStorage.removeItem(VERSION_CACHE_KEY);
}
export default { export default {
getNewestVersion, getNewestVersion,
getCurrentVersion, getCurrentVersion,
isUpToDate
}; };

View File

@@ -2,13 +2,20 @@ import versionService from '$lib/services/version-service';
import type { AppVersionInformation } from '$lib/types/application-configuration'; import type { AppVersionInformation } from '$lib/types/application-configuration';
import type { LayoutLoad } from './$types'; import type { LayoutLoad } from './$types';
export const prerender = false;
export const load: LayoutLoad = async () => { export const load: LayoutLoad = async () => {
const currentVersion = versionService.getCurrentVersion();
let newestVersion = null;
let isUpToDate = true;
try {
newestVersion = await versionService.getNewestVersion();
isUpToDate = newestVersion === currentVersion;
} catch {}
const versionInformation: AppVersionInformation = { const versionInformation: AppVersionInformation = {
currentVersion: versionService.getCurrentVersion(), currentVersion: versionService.getCurrentVersion(),
newestVersion: await versionService.getNewestVersion(), newestVersion,
isUpToDate: await versionService.isUpToDate() isUpToDate
}; };
return { return {