feat: store keys as JWK on disk (#339)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-03-18 13:08:33 -07:00
committed by GitHub
parent e9b2d981b7
commit a7c9741802
10 changed files with 1207 additions and 174 deletions

View File

@@ -18,9 +18,11 @@ require (
github.com/golang-migrate/migrate/v4 v4.18.2
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/lestrrat-go/jwx/v3 v3.0.0-alpha3
github.com/mileusna/useragent v1.3.5
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
golang.org/x/crypto v0.35.0
github.com/stretchr/testify v1.10.0
golang.org/x/crypto v0.36.0
golang.org/x/image v0.24.0
golang.org/x/time v0.9.0
gorm.io/driver/postgres v1.5.11
@@ -33,6 +35,8 @@ require (
github.com/bytedance/sonic v1.12.8 // indirect
github.com/bytedance/sonic/loader v0.2.3 // indirect
github.com/cloudwego/base64x v0.1.5 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect
github.com/disintegration/gift v1.1.2 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/gin-contrib/sse v1.0.0 // indirect
@@ -55,6 +59,10 @@ require (
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.24 // indirect
@@ -62,7 +70,9 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/x448/float16 v0.8.4 // indirect
@@ -70,9 +80,9 @@ require (
golang.org/x/arch v0.13.0 // indirect
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect
golang.org/x/net v0.36.0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
golang.org/x/sync v0.12.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.36.4 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -20,6 +20,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40=
github.com/dhui/dktest v0.4.4 h1:+I4s6JRE1yGuqflzwqG+aIaMdgXIorCf5P98JnaAWa8=
github.com/dhui/dktest v0.4.4/go.mod h1:4+22R4lgsdAXrDyaH4Nqx2JEz2hLp49MqQmm9HLCQhM=
github.com/disintegration/gift v1.1.2 h1:9ZyHJr+kPamiH10FX3Pynt1AxFUob812bU9Wt4GMzhs=
@@ -137,6 +139,16 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms=
github.com/lestrrat-go/jwx/v3 v3.0.0-alpha3 h1:HHT8iW+UcPBgBr5A3soZQQsL5cBor/u6BkLB+wzY/R0=
github.com/lestrrat-go/jwx/v3 v3.0.0-alpha3/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@@ -176,12 +188,15 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
@@ -217,8 +232,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA=
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU=
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -249,8 +264,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -263,8 +278,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -283,8 +298,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -11,5 +11,7 @@ func Bootstrap() {
db := newDatabase()
appConfigService := service.NewAppConfigService(db)
migrateKey()
initRouter(db, appConfigService)
}

View File

@@ -0,0 +1,133 @@
package bootstrap
import (
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"fmt"
"log"
"os"
"path/filepath"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
const (
privateKeyFilePem = "jwt_private_key.pem"
)
func migrateKey() {
err := migrateKeyInternal(common.EnvConfig.KeysPath)
if err != nil {
log.Fatalf("failed to perform migration of keys: %v", err)
}
}
func migrateKeyInternal(basePath string) error {
// First, check if there's already a JWK stored
jwkPath := filepath.Join(basePath, service.PrivateKeyFile)
ok, err := utils.FileExists(jwkPath)
if err != nil {
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
}
if ok {
// There's already a key as JWK, so we don't do anything else here
return nil
}
// Check if there's a PEM file
pemPath := filepath.Join(basePath, privateKeyFilePem)
ok, err = utils.FileExists(pemPath)
if err != nil {
return fmt.Errorf("failed to check if private key file (PEM) exists at path '%s': %w", pemPath, err)
}
if !ok {
// No file to migrate, return
return nil
}
// Load and validate the key
key, err := loadKeyPEM(pemPath)
if err != nil {
return fmt.Errorf("failed to load private key file (PEM) at path '%s': %w", pemPath, err)
}
err = service.ValidateKey(key)
if err != nil {
return fmt.Errorf("key object is invalid: %w", err)
}
// Save the key as JWK
err = service.SaveKeyJWK(key, jwkPath)
if err != nil {
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
}
// Finally, delete the PEM file
err = os.Remove(pemPath)
if err != nil {
return fmt.Errorf("failed to remove migrated key at path '%s': %w", pemPath, err)
}
return nil
}
func loadKeyPEM(path string) (jwk.Key, error) {
// Load the key from disk and parse it
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read key data: %w", err)
}
key, err := jwk.ParseKey(data, jwk.WithPEM(true))
if err != nil {
return nil, fmt.Errorf("failed to parse key: %w", err)
}
// Populate the key ID using the "legacy" algorithm
keyId, err := generateKeyID(key)
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
key.Set(jwk.KeyIDKey, keyId)
// Populate other required fields
_ = key.Set(jwk.KeyUsageKey, service.KeyUsageSigning)
service.EnsureAlgInKey(key)
return key, nil
}
// generateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key's PKIX-serialized structure.
// This is used for legacy keys, imported from PEM.
func generateKeyID(key jwk.Key) (string, error) {
// Export the public key and serialize it to PKIX (not in a PEM block)
// This is for backwards-compatibility with the algorithm used before the switch to JWK
pubKey, err := key.PublicKey()
if err != nil {
return "", fmt.Errorf("failed to get public key: %w", err)
}
var pubKeyRaw any
err = jwk.Export(pubKey, &pubKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to export public key: %w", err)
}
pubASN1, err := x509.MarshalPKIXPublicKey(pubKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to marshal public key: %w", err)
}
// Compute SHA-256 hash of the public key
hash := sha256.New()
hash.Write(pubASN1)
hashed := hash.Sum(nil)
// Truncate the hash to the first 8 bytes for a shorter Key ID
shortHash := hashed[:8]
// Return Base64 encoded truncated hash as Key ID
return base64.RawURLEncoding.EncodeToString(shortHash), nil
}

View File

@@ -0,0 +1,190 @@
package bootstrap
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"os"
"path/filepath"
"testing"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
func TestMigrateKey(t *testing.T) {
// Create a temporary directory for testing
tempDir := t.TempDir()
t.Run("no keys exist", func(t *testing.T) {
// Test when no keys exist
err := migrateKeyInternal(tempDir)
require.NoError(t, err)
})
t.Run("jwk already exists", func(t *testing.T) {
// Create a JWK file
jwkPath := filepath.Join(tempDir, service.PrivateKeyFile)
key, err := createTestRSAKey()
require.NoError(t, err)
err = service.SaveKeyJWK(key, jwkPath)
require.NoError(t, err)
// Run migration - should do nothing
err = migrateKeyInternal(tempDir)
require.NoError(t, err)
// Check the file still exists
exists, err := utils.FileExists(jwkPath)
require.NoError(t, err)
assert.True(t, exists)
// Delete for next test
err = os.Remove(jwkPath)
require.NoError(t, err)
})
t.Run("migrate pem to jwk", func(t *testing.T) {
// Create a PEM file
pemPath := filepath.Join(tempDir, privateKeyFilePem)
jwkPath := filepath.Join(tempDir, service.PrivateKeyFile)
// Generate RSA key and save as PEM
createRSAPrivateKeyPEM(t, pemPath)
// Run migration
err := migrateKeyInternal(tempDir)
require.NoError(t, err)
// Check PEM file is gone
exists, err := utils.FileExists(pemPath)
require.NoError(t, err)
assert.False(t, exists)
// Check JWK file exists
exists, err = utils.FileExists(jwkPath)
require.NoError(t, err)
assert.True(t, exists)
// Verify the JWK can be loaded
data, err := os.ReadFile(jwkPath)
require.NoError(t, err)
_, err = jwk.ParseKey(data)
require.NoError(t, err)
})
}
func TestLoadKeyPEM(t *testing.T) {
// Create a temporary directory for testing
tempDir := t.TempDir()
t.Run("successfully load PEM key", func(t *testing.T) {
pemPath := filepath.Join(tempDir, "test_key.pem")
// Generate RSA key and save as PEM
createRSAPrivateKeyPEM(t, pemPath)
// Load the key
key, err := loadKeyPEM(pemPath)
require.NoError(t, err)
// Verify key properties
assert.NotEmpty(t, key)
// Check key ID is set
var keyID string
err = key.Get(jwk.KeyIDKey, &keyID)
assert.NoError(t, err)
assert.NotEmpty(t, keyID)
// Check algorithm is set
var alg jwa.SignatureAlgorithm
err = key.Get(jwk.AlgorithmKey, &alg)
assert.NoError(t, err)
assert.NotEmpty(t, alg)
// Check key usage is set
var keyUsage string
err = key.Get(jwk.KeyUsageKey, &keyUsage)
assert.NoError(t, err)
assert.Equal(t, service.KeyUsageSigning, keyUsage)
})
t.Run("file not found", func(t *testing.T) {
key, err := loadKeyPEM(filepath.Join(tempDir, "nonexistent.pem"))
assert.Error(t, err)
assert.Nil(t, key)
})
t.Run("invalid file content", func(t *testing.T) {
invalidPath := filepath.Join(tempDir, "invalid.pem")
err := os.WriteFile(invalidPath, []byte("not a valid PEM"), 0600)
require.NoError(t, err)
key, err := loadKeyPEM(invalidPath)
assert.Error(t, err)
assert.Nil(t, key)
})
}
func TestGenerateKeyID(t *testing.T) {
key, err := createTestRSAKey()
require.NoError(t, err)
keyID, err := generateKeyID(key)
require.NoError(t, err)
// Key ID should be non-empty
assert.NotEmpty(t, keyID)
// Generate another key ID to prove it depends on the key
key2, err := createTestRSAKey()
require.NoError(t, err)
keyID2, err := generateKeyID(key2)
require.NoError(t, err)
// The two key IDs should be different
assert.NotEqual(t, keyID, keyID2)
}
// Helper functions
func createTestRSAKey() (jwk.Key, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
key, err := jwk.Import(privateKey)
if err != nil {
return nil, err
}
return key, nil
}
// createRSAPrivateKeyPEM generates an RSA private key and returns its PEM-encoded form
func createRSAPrivateKeyPEM(t *testing.T, pemPath string) ([]byte, *rsa.PrivateKey) {
// Generate RSA key
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
// Encode to PEM format
pemData := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privKey),
})
err = os.WriteFile(pemPath, pemData, 0600)
require.NoError(t, err)
return pemData, privKey
}

View File

@@ -30,13 +30,13 @@ type WellKnownController struct {
// @Success 200 {object} object "{ \"keys\": []interface{} }"
// @Router /.well-known/jwks.json [get]
func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
jwk, err := wkc.jwtService.GetJWK()
jwks, err := wkc.jwtService.GetPublicJWKSAsJSON()
if err != nil {
c.Error(err)
return
}
c.JSON(http.StatusOK, gin.H{"keys": []interface{}{jwk}})
c.Data(http.StatusOK, "application/json; charset=utf-8", jwks)
}
// openIDConfigurationHandler godoc

View File

@@ -3,14 +3,12 @@ package service
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/big"
"os"
"path/filepath"
"slices"
@@ -18,80 +16,148 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
const (
privateKeyFile = "jwt_private_key.pem"
// Path in the data/keys folder where the key is stored
// This is a JSON file containing a key encoded as JWK
PrivateKeyFile = "jwt_private_key.json"
// Size, in bits, of the RSA key to generate if none is found
RsaKeySize = 2048
// Usage for the private keys, for the "use" property
KeyUsageSigning = "sig"
)
type JwtService struct {
privateKey *rsa.PrivateKey
privateKey jwk.Key
keyId string
appConfigService *AppConfigService
jwksEncoded []byte
}
func NewJwtService(appConfigService *AppConfigService) *JwtService {
service := &JwtService{
appConfigService: appConfigService,
}
service := &JwtService{}
// Ensure keys are generated or loaded
if err := service.loadOrGenerateKey(common.EnvConfig.KeysPath); err != nil {
if err := service.init(appConfigService, common.EnvConfig.KeysPath); err != nil {
log.Fatalf("Failed to initialize jwt service: %v", err)
}
return service
}
func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) error {
s.appConfigService = appConfigService
// Ensure keys are generated or loaded
return s.loadOrGenerateKey(keysPath)
}
type AccessTokenJWTClaims struct {
jwt.RegisteredClaims
IsAdmin bool `json:"isAdmin,omitempty"`
}
type JWK struct {
Kid string `json:"kid"`
Kty string `json:"kty"`
Use string `json:"use"`
Alg string `json:"alg"`
N string `json:"n"`
E string `json:"e"`
}
// loadOrGenerateKey loads RSA keys from the given paths or generates them if they do not exist.
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
privateKeyPath := filepath.Join(keysPath, privateKeyFile)
var key jwk.Key
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
if err := s.generateKey(keysPath); err != nil {
return fmt.Errorf("can't generate key: %w", err)
// First, check if we have a JWK file
// If we do, then we just load that
jwkPath := filepath.Join(keysPath, PrivateKeyFile)
ok, err := utils.FileExists(jwkPath)
if err != nil {
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
}
if ok {
key, err = s.loadKeyJWK(jwkPath)
if err != nil {
return fmt.Errorf("failed to load private key file (JWK) at path '%s': %w", jwkPath, err)
}
// Set the key, and we are done
err = s.SetKey(key)
if err != nil {
return fmt.Errorf("failed to set private key: %w", err)
}
return nil
}
privateKeyBytes, err := os.ReadFile(privateKeyPath)
// If we are here, we need to generate a new key
key, err = s.generateNewRSAKey()
if err != nil {
return fmt.Errorf("can't read jwt private key: %w", err)
}
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
if err != nil {
return fmt.Errorf("can't parse jwt private key: %w", err)
return fmt.Errorf("failed to generate new private key: %w", err)
}
err = s.SetKey(privateKey)
// Set the key in the object, which also validates it
err = s.SetKey(key)
if err != nil {
return fmt.Errorf("failed to set private key: %w", err)
}
// Save the key as JWK
err = SaveKeyJWK(s.privateKey, jwkPath)
if err != nil {
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
}
return nil
}
func (s *JwtService) SetKey(privateKey *rsa.PrivateKey) (err error) {
func ValidateKey(privateKey jwk.Key) error {
// Validate the loaded key
err := privateKey.Validate()
if err != nil {
return fmt.Errorf("key object is invalid: %w", err)
}
keyID, ok := privateKey.KeyID()
if !ok || keyID == "" {
return errors.New("key object does not contain a key ID")
}
usage, ok := privateKey.KeyUsage()
if !ok || usage != KeyUsageSigning {
return errors.New("key object is not valid for signing")
}
ok, err = jwk.IsPrivateKey(privateKey)
if err != nil || !ok {
return errors.New("key object is not a private key")
}
return nil
}
func (s *JwtService) SetKey(privateKey jwk.Key) error {
// Validate the loaded key
err := ValidateKey(privateKey)
if err != nil {
return fmt.Errorf("private key is not valid: %w", err)
}
// Set the private key in the object
s.privateKey = privateKey
s.keyId, err = s.generateKeyID()
// Create and encode a JWKS containing the public key
publicKey, err := s.GetPublicJWK()
if err != nil {
return fmt.Errorf("can't generate key ID: %w", err)
return fmt.Errorf("failed to get public JWK: %w", err)
}
jwks := jwk.NewSet()
err = jwks.AddKey(publicKey)
if err != nil {
return fmt.Errorf("failed to add public key to JWKS: %w", err)
}
s.jwksEncoded, err = json.Marshal(jwks)
if err != nil {
return fmt.Errorf("failed to encode JWKS to JSON: %w", err)
}
return nil
@@ -112,12 +178,23 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = s.keyId
return token.SignedString(s.privateKey)
var privateKeyRaw any
err := jwk.Export(s.privateKey, &privateKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to export private key object: %w", err)
}
signed, err := token.SignedString(privateKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return signed, nil
}
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
return &s.privateKey.PublicKey, nil
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (any, error) {
return s.getPublicKeyRaw()
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
@@ -135,12 +212,12 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaim
}
func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID string, nonce string) (string, error) {
claims := jwt.MapClaims{
"aud": clientID,
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
"iat": jwt.NewNumericDate(time.Now()),
"iss": common.EnvConfig.AppURL,
}
// Initialize with capacity for userClaims, + 4 fixed claims, + 2 claims which may be set in some cases, to avoid re-allocations
claims := make(jwt.MapClaims, len(userClaims)+6)
claims["aud"] = clientID
claims["exp"] = jwt.NewNumericDate(time.Now().Add(1 * time.Hour))
claims["iat"] = jwt.NewNumericDate(time.Now())
claims["iss"] = common.EnvConfig.AppURL
for k, v := range userClaims {
claims[k] = v
@@ -153,43 +230,18 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = s.keyId
return token.SignedString(s.privateKey)
}
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
claim := jwt.RegisteredClaims{
Subject: user.ID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Audience: jwt.ClaimStrings{clientID},
Issuer: common.EnvConfig.AppURL,
var privateKeyRaw any
err := jwk.Export(s.privateKey, &privateKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to export private key object: %w", err)
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = s.keyId
return token.SignedString(s.privateKey)
}
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return &s.privateKey.PublicKey, nil
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
}
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
if !isValid {
return nil, errors.New("can't parse claims")
}
return claims, nil
return token.SignedString(privateKeyRaw)
}
func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return &s.privateKey.PublicKey, nil
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
return s.getPublicKeyRaw()
}, jwt.WithIssuer(common.EnvConfig.AppURL))
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
@@ -204,78 +256,180 @@ func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, e
return claims, nil
}
// GetJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetJWK() (JWK, error) {
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
claim := jwt.RegisteredClaims{
Subject: user.ID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Audience: jwt.ClaimStrings{clientID},
Issuer: common.EnvConfig.AppURL,
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = s.keyId
var privateKeyRaw any
err := jwk.Export(s.privateKey, &privateKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to export private key object: %w", err)
}
return token.SignedString(privateKeyRaw)
}
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
return s.getPublicKeyRaw()
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
}
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
if !isValid {
return nil, errors.New("can't parse claims")
}
return claims, nil
}
// GetPublicJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
if s.privateKey == nil {
return JWK{}, errors.New("public key is not initialized")
return nil, errors.New("key is not initialized")
}
jwk := JWK{
Kid: s.keyId,
Kty: "RSA",
Use: "sig",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.privateKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.privateKey.E)).Bytes()),
}
return jwk, nil
}
// GenerateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key.
func (s *JwtService) generateKeyID() (string, error) {
pubASN1, err := x509.MarshalPKIXPublicKey(&s.privateKey.PublicKey)
pubKey, err := s.privateKey.PublicKey()
if err != nil {
return "", fmt.Errorf("failed to marshal public key: %w", err)
return nil, fmt.Errorf("failed to get public key: %w", err)
}
// Compute SHA-256 hash of the public key
hash := sha256.New()
hash.Write(pubASN1)
hashed := hash.Sum(nil)
EnsureAlgInKey(pubKey)
// Truncate the hash to the first 8 bytes for a shorter Key ID
shortHash := hashed[:8]
// Return Base64 encoded truncated hash as Key ID
return base64.RawURLEncoding.EncodeToString(shortHash), nil
return pubKey, nil
}
// generateKey generates a new RSA key and saves it to the specified path.
func (s *JwtService) generateKey(keysPath string) error {
if err := os.MkdirAll(keysPath, 0700); err != nil {
return fmt.Errorf("failed to create directories for keys: %w", err)
// GetPublicJWKSAsJSON returns the JSON Web Key Set (JWKS) for the public key, encoded as JSON.
// The value is cached since the key is static.
func (s *JwtService) GetPublicJWKSAsJSON() ([]byte, error) {
if len(s.jwksEncoded) == 0 {
return nil, errors.New("key is not initialized")
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
return s.jwksEncoded, nil
}
func (s *JwtService) getPublicKeyRaw() (any, error) {
pubKey, err := s.privateKey.PublicKey()
if err != nil {
return fmt.Errorf("failed to generate private key: %w", err)
return nil, fmt.Errorf("failed to get public key: %w", err)
}
privateKeyPath := filepath.Join(keysPath, privateKeyFile)
if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil {
return err
var pubKeyRaw any
err = jwk.Export(pubKey, &pubKeyRaw)
if err != nil {
return nil, fmt.Errorf("failed to export raw public key: %w", err)
}
return nil
return pubKeyRaw, nil
}
// savePEMKey saves a PEM encoded key to a file.
func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error {
keyFile, err := os.Create(path)
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read key data: %w", err)
}
key, err := jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse key: %w", err)
}
return key, nil
}
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
func EnsureAlgInKey(key jwk.Key) {
_, ok := key.Algorithm()
if ok {
// Algorithm is already set
return
}
switch key.KeyType() {
case jwa.RSA():
// Default to RS256 for RSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
case jwa.EC():
// Default to ES256 for ECDSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
case jwa.OKP():
// Default to EdDSA for OKP keys
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
}
}
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
// We generate RSA keys only
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
if err != nil {
return nil, fmt.Errorf("failed to generate RSA private key: %w", err)
}
// Import the raw key
return importRawKey(rawKey)
}
func importRawKey(rawKey any) (jwk.Key, error) {
key, err := jwk.Import(rawKey)
if err != nil {
return nil, fmt.Errorf("failed to import generated private key: %w", err)
}
// Generate the key ID
kid, err := generateRandomKeyID()
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
_ = key.Set(jwk.KeyIDKey, kid)
// Set other required fields
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
EnsureAlgInKey(key)
return key, err
}
// SaveKeyJWK saves a JWK to a file
func SaveKeyJWK(key jwk.Key, path string) error {
dir := filepath.Dir(path)
err := os.MkdirAll(dir, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for key file: %w", dir, err)
}
keyFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create key file: %w", err)
}
defer keyFile.Close()
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: keyType,
Bytes: keyBytes,
})
if _, err := keyFile.Write(keyPEM); err != nil {
// Write the JSON file to disk
enc := json.NewEncoder(keyFile)
enc.SetEscapeHTML(false)
err = enc.Encode(key)
if err != nil {
return fmt.Errorf("failed to write key file: %w", err)
}
return nil
}
// generateRandomKeyID generates a random key ID.
// It is used for newly-generated keys
func generateRandomKeyID() (string, error) {
buf := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("failed to read random bytes: %w", err)
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}

View File

@@ -0,0 +1,546 @@
package service
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"os"
"path/filepath"
"testing"
"time"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
)
func TestJwtService_Init(t *testing.T) {
t.Run("should generate new key when none exists", func(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Create a mock AppConfigService
appConfigService := &AppConfigService{}
// Initialize the JWT service
service := &JwtService{}
err := service.init(appConfigService, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Verify the private key was set
require.NotNil(t, service.privateKey, "Private key should be set")
// Verify the key has been saved to disk as JWK
jwkPath := filepath.Join(tempDir, PrivateKeyFile)
_, err = os.Stat(jwkPath)
assert.NoError(t, err, "JWK file should exist")
// Verify the generated key is valid
keyData, err := os.ReadFile(jwkPath)
require.NoError(t, err)
key, err := jwk.ParseKey(keyData)
require.NoError(t, err)
// Key should have required properties
keyID, ok := key.KeyID()
assert.True(t, ok, "Key should have a key ID")
assert.NotEmpty(t, keyID)
keyUsage, ok := key.KeyUsage()
assert.True(t, ok, "Key should have a key usage")
assert.Equal(t, "sig", keyUsage)
})
t.Run("should load existing JWK key", func(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// First create a service to generate a key
firstService := &JwtService{}
err := firstService.init(&AppConfigService{}, tempDir)
require.NoError(t, err)
// Get the key ID of the first service
origKeyID, ok := firstService.privateKey.KeyID()
require.True(t, ok)
// Now create a new service that should load the existing key
secondService := &JwtService{}
err = secondService.init(&AppConfigService{}, tempDir)
require.NoError(t, err)
// Verify the loaded key has the same ID as the original
loadedKeyID, ok := secondService.privateKey.KeyID()
require.True(t, ok)
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
})
t.Run("should load existing JWK for EC keys", func(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Create a new JWK and save it to disk
origKeyID := createECKeyJWK(t, tempDir)
// Now create a new service that should load the existing key
svc := &JwtService{}
err := svc.init(&AppConfigService{}, tempDir)
require.NoError(t, err)
// Verify the loaded key has the same ID as the original
loadedKeyID, ok := svc.privateKey.KeyID()
require.True(t, ok)
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
})
}
func TestJwtService_GetPublicJWK(t *testing.T) {
t.Run("returns public key when private key is initialized", func(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Create a JWT service with initialized key
service := &JwtService{}
err := service.init(&AppConfigService{}, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Get the JWK (public key)
publicKey, err := service.GetPublicJWK()
require.NoError(t, err, "GetPublicJWK should not return an error when private key is initialized")
// Verify the returned key is valid
require.NotNil(t, publicKey, "Public key should not be nil")
// Validate it's actually a public key
isPrivate, err := jwk.IsPrivateKey(publicKey)
require.NoError(t, err)
assert.False(t, isPrivate, "Returned key should be a public key")
// Check that key has required properties
keyID, ok := publicKey.KeyID()
require.True(t, ok, "Public key should have a key ID")
assert.NotEmpty(t, keyID, "Key ID should not be empty")
alg, ok := publicKey.Algorithm()
require.True(t, ok, "Public key should have an algorithm")
assert.Equal(t, "RS256", alg.String(), "Algorithm should be RS256")
})
t.Run("returns public key when ECDSA private key is initialized", func(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Create an ECDSA key and save it as JWK
originalKeyID := createECKeyJWK(t, tempDir)
// Create a JWT service that loads the ECDSA key
service := &JwtService{}
err := service.init(&AppConfigService{}, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Get the JWK (public key)
publicKey, err := service.GetPublicJWK()
require.NoError(t, err, "GetPublicJWK should not return an error when private key is initialized")
// Verify the returned key is valid
require.NotNil(t, publicKey, "Public key should not be nil")
// Validate it's actually a public key
isPrivate, err := jwk.IsPrivateKey(publicKey)
require.NoError(t, err)
assert.False(t, isPrivate, "Returned key should be a public key")
// Check that key has required properties
keyID, ok := publicKey.KeyID()
require.True(t, ok, "Public key should have a key ID")
assert.Equal(t, originalKeyID, keyID, "Key ID should match the original key ID")
// Check that the key type is EC
assert.Equal(t, "EC", publicKey.KeyType().String(), "Key type should be EC")
// Check that the algorithm is ES256
alg, ok := publicKey.Algorithm()
require.True(t, ok, "Public key should have an algorithm")
assert.Equal(t, "ES256", alg.String(), "Algorithm should be ES256")
})
t.Run("returns error when private key is not initialized", func(t *testing.T) {
// Create a service with nil private key
service := &JwtService{
privateKey: nil,
}
// Try to get the JWK
publicKey, err := service.GetPublicJWK()
// Verify it returns an error
require.Error(t, err, "GetPublicJWK should return an error when private key is nil")
assert.Contains(t, err.Error(), "key is not initialized", "Error message should indicate key is not initialized")
assert.Nil(t, publicKey, "Public key should be nil when there's an error")
})
}
func TestGenerateVerifyAccessToken(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
},
}
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
t.Run("generates token for regular user", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
user := model.User{
Base: model.Base{
ID: "user123",
},
Email: "user@example.com",
IsAdmin: false,
}
// Generate a token
tokenString, err := service.GenerateAccessToken(user)
require.NoError(t, err, "Failed to generate access token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
claims, err := service.VerifyAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated token")
// Check the claims
assert.Equal(t, user.ID, claims.Subject, "Token subject should match user ID")
assert.Equal(t, false, claims.IsAdmin, "IsAdmin should be false")
assert.Contains(t, claims.Audience, "https://test.example.com", "Audience should contain the app URL")
// Check token expiration time is approximately 60 minutes from now
expectedExp := time.Now().Add(60 * time.Minute)
tokenExp := claims.ExpiresAt.Time
timeDiff := expectedExp.Sub(tokenExp).Minutes()
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 60 minutes")
})
t.Run("generates token for admin user", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test admin user
adminUser := model.User{
Base: model.Base{
ID: "admin123",
},
Email: "admin@example.com",
IsAdmin: true,
}
// Generate a token
tokenString, err := service.GenerateAccessToken(adminUser)
require.NoError(t, err, "Failed to generate access token")
// Verify the token
claims, err := service.VerifyAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated token")
// Check the IsAdmin claim is true
assert.Equal(t, true, claims.IsAdmin, "IsAdmin should be true for admin users")
assert.Equal(t, adminUser.ID, claims.Subject, "Token subject should match admin ID")
})
t.Run("uses session duration from config", func(t *testing.T) {
// Create a JWT service with a different session duration
customMockConfig := &AppConfigService{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "30"}, // 30 minutes
},
}
service := &JwtService{}
err := service.init(customMockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
user := model.User{
Base: model.Base{
ID: "user456",
},
}
// Generate a token
tokenString, err := service.GenerateAccessToken(user)
require.NoError(t, err, "Failed to generate access token")
// Verify the token
claims, err := service.VerifyAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated token")
// Check token expiration time is approximately 30 minutes from now
expectedExp := time.Now().Add(30 * time.Minute)
tokenExp := claims.ExpiresAt.Time
timeDiff := expectedExp.Sub(tokenExp).Minutes()
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 30 minutes")
})
}
func TestGenerateVerifyIdToken(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
},
}
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
t.Run("generates and verifies ID token with standard claims", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Create test claims
userClaims := map[string]interface{}{
"sub": "user123",
"name": "Test User",
"email": "user@example.com",
}
const clientID = "test-client-123"
// Generate a token
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
require.NoError(t, err, "Failed to generate ID token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
claims, err := service.VerifyIdToken(tokenString)
require.NoError(t, err, "Failed to verify generated ID token")
// Check the claims
assert.Equal(t, "user123", claims.Subject, "Token subject should match user ID")
assert.Contains(t, claims.Audience, clientID, "Audience should contain the client ID")
assert.Equal(t, common.EnvConfig.AppURL, claims.Issuer, "Issuer should match app URL")
// Check token expiration time is approximately 1 hour from now
expectedExp := time.Now().Add(1 * time.Hour)
tokenExp := claims.ExpiresAt.Time
timeDiff := expectedExp.Sub(tokenExp).Minutes()
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 1 hour")
})
t.Run("generates and verifies ID token with nonce", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Create test claims with nonce
userClaims := map[string]interface{}{
"sub": "user456",
"name": "Another User",
}
const clientID = "test-client-456"
nonce := "random-nonce-value"
// Generate a token with nonce
tokenString, err := service.GenerateIDToken(userClaims, clientID, nonce)
require.NoError(t, err, "Failed to generate ID token with nonce")
// Parse the token manually to check nonce
publicKey, err := service.GetPublicJWK()
require.NoError(t, err, "Failed to get public key")
token, err := jwt.Parse([]byte(tokenString), jwt.WithKey(jwa.RS256(), publicKey))
require.NoError(t, err, "Failed to parse token")
var tokenNonce string
err = token.Get("nonce", &tokenNonce)
require.NoError(t, err, "Failed to get claims")
assert.Equal(t, nonce, tokenNonce, "Token should contain the correct nonce")
})
t.Run("fails verification with incorrect issuer", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Generate a token with standard claims
userClaims := map[string]interface{}{
"sub": "user789",
}
tokenString, err := service.GenerateIDToken(userClaims, "client-789", "")
require.NoError(t, err, "Failed to generate ID token")
// Temporarily change the app URL to simulate wrong issuer
common.EnvConfig.AppURL = "https://wrong-issuer.com"
// Verify should fail due to issuer mismatch
_, err = service.VerifyIdToken(tokenString)
assert.Error(t, err, "Verification should fail with incorrect issuer")
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
})
}
func TestGenerateVerifyOauthAccessToken(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
},
}
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
t.Run("generates and verifies OAuth access token with standard claims", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
user := model.User{
Base: model.Base{
ID: "user123",
},
Email: "user@example.com",
}
const clientID = "test-client-123"
// Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token")
// Check the claims
assert.Equal(t, user.ID, claims.Subject, "Token subject should match user ID")
assert.Contains(t, claims.Audience, clientID, "Audience should contain the client ID")
assert.Equal(t, common.EnvConfig.AppURL, claims.Issuer, "Issuer should match app URL")
// Check token expiration time is approximately 1 hour from now
expectedExp := time.Now().Add(1 * time.Hour)
tokenExp := claims.ExpiresAt.Time
timeDiff := expectedExp.Sub(tokenExp).Minutes()
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 1 hour")
})
t.Run("fails verification for expired token", func(t *testing.T) {
// Create a JWT service with a mock function to generate an expired token
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
user := model.User{
Base: model.Base{
ID: "user456",
},
}
const clientID = "test-client-456"
// Generate a token using JWT directly to create an expired token
token, err := jwt.NewBuilder().
Subject(user.ID).
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
IssuedAt(time.Now().Add(-2 * time.Hour)).
Audience([]string{clientID}).
Issuer(common.EnvConfig.AppURL).
Build()
require.NoError(t, err, "Failed to build token")
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
require.NoError(t, err, "Failed to sign token")
// Verify should fail due to expiration
_, err = service.VerifyOauthAccessToken(string(signed))
assert.Error(t, err, "Verification should fail with expired token")
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
})
t.Run("fails verification with invalid signature", func(t *testing.T) {
// Create two JWT services with different keys
service1 := &JwtService{}
err := service1.init(mockConfig, t.TempDir()) // Use a different temp dir
require.NoError(t, err, "Failed to initialize first JWT service")
service2 := &JwtService{}
err = service2.init(mockConfig, t.TempDir()) // Use a different temp dir
require.NoError(t, err, "Failed to initialize second JWT service")
// Create a test user
user := model.User{
Base: model.Base{
ID: "user789",
},
}
const clientID = "test-client-789"
// Generate a token with the first service
tokenString, err := service1.GenerateOauthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token")
// Verify with the second service should fail due to different keys
_, err = service2.VerifyOauthAccessToken(tokenString)
assert.Error(t, err, "Verification should fail with invalid signature")
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
})
}
func createECKeyJWK(t *testing.T, path string) string {
t.Helper()
// Generate a new P-256 ECDSA key
privateKeyRaw, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err, "Failed to generate ECDSA key")
// Import as JWK and save to disk
privateKey, err := importRawKey(privateKeyRaw)
require.NoError(t, err, "Failed to import private key")
err = SaveKeyJWK(privateKey, filepath.Join(path, PrivateKeyFile))
require.NoError(t, err, "Failed to save key")
kid, _ := privateKey.KeyID()
require.NotEmpty(t, kid, "Key ID must be set")
return kid
}

View File

@@ -4,7 +4,6 @@ import (
"crypto/ecdsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"log"
"os"
@@ -12,14 +11,15 @@ import (
"time"
"github.com/fxamacker/cbor/v2"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/resources"
"github.com/go-webauthn/webauthn/protocol"
"github.com/lestrrat-go/jwx/v3/jwk"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/resources"
)
type TestService struct {
@@ -304,38 +304,9 @@ func (s *TestService) ResetAppConfig() error {
}
func (s *TestService) SetJWTKeys() {
privateKeyString := `-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEAyaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B
83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC+585UXacoJ0c
hUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl/4EDDTO8HwawTjwkPo
QlRzeByhlvGPVvwgB3Fn93B8QJ/cZhXKxJvjjrC/8Pk76heC/ntEMru71Ix77BoC
3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeO
Zl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJwIDAQABAoIBAQCa8wNZJ08+9y6b
RzSIQcTaBuq1XY0oyYvCuX0ToruDyVNX3lJ48udb9vDIw9XsQans9CTeXXsjldGE
WPN7sapOcUg6ArMyJqc+zuO/YQu0EwYrTE48BOC7WIZvvTFnq9y+4R9HJjd0nTOv
iOlR1W5fAqbH2srgh1mfZ0UIp+9K6ymoinPXVGEXUAuuoMuTEZW/tnA2HT9WEllT
2FyMbmXrFzutAQqk9GRmnQh2OQZLxnQWyShVqJEhYBtm6JUUH1YJbyTVzMLgdBM8
ukgjTVtRDHaW51ubRSVdGBVT2m1RRtTsYAiZCpM5bwt88aSUS9yDOUiVH+irDg/3
IHEuL7IxAoGBAP2MpXPXtOwinajUQ9hKLDAtpq4axGvY+aGP5dNEMsuPo5ggOfUP
b4sqr73kaNFO3EbxQOQVoFjehhi4dQxt1/kAala9HZ5N7s26G2+eUWFF8jy7gWSN
qusNqGrG4g8D3WOyqZFb/x/m6SE0Jcg7zvIYbnAOq1Fexeik0Fc/DNzLAoGBAMua
d4XIfu4ydtU5AIaf1ZNXywgLg+LWxK8ELNqH/Y2vLAeIiTrOVp+hw9z+zHPD5cnu
6mix783PCOYNLTylrwtAz3fxSz14lsDFQM3ntzVF/6BniTTkKddctcPyqnTvamah
0hD2dzXBS/0mTBYIIMYTNbs0Yj87FTdJZw/+qa2VAoGBAKbzQkp54W6PCIMPabD0
fg4nMRZ5F5bv4seIKcunn068QPs9VQxQ4qCfNeLykDYqGA86cgD9YHzD4UZLxv6t
IUWbCWod0m/XXwPlpIUlmO5VEUD+MiAUzFNDxf6xAE7ku5UXImJNUjseX6l2Xd5v
yz9L6QQuFI5aujQKugiIwp5rAoGATtUVGCCkPNgfOLmkYXu7dxxUCV5kB01+xAEK
2OY0n0pG8vfDophH4/D/ZC7nvJ8J9uDhs/3JStexq1lIvaWtG99RNTChIEDzpdn6
GH9yaVcb/eB4uJjrNm64FhF8PGCCwxA+xMCZMaARKwhMB2/IOMkxUbWboL3gnhJ2
rDO/QO0CgYEA2Grt6uXHm61ji3xSdkBWNtUnj19vS1+7rFJp5SoYztVQVThf/W52
BAiXKBdYZDRVoItC/VS2NvAOjeJjhYO/xQ/q3hK7MdtuXfEPpLnyXKkmWo3lrJ26
wbeF6l05LexCkI7ShsOuSt+dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI=
-----END RSA PRIVATE KEY-----
`
block, _ := pem.Decode([]byte(privateKeyString))
privateKey, _ := x509.ParsePKCS1PrivateKey(block.Bytes)
const privateKeyString = `{"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}`
privateKey, _ := jwk.ParseKey([]byte(privateKeyString))
s.jwtService.SetKey(privateKey)
}

View File

@@ -78,3 +78,15 @@ func SaveFile(file *multipart.FileHeader, dst string) error {
_, err = io.Copy(out, src)
return err
}
// FileExists returns true if a file exists on disk and is a regular file
func FileExists(path string) (bool, error) {
s, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
err = nil
}
return false, err
}
return !s.IsDir(), nil
}