mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-16 01:11:16 +03:00
325 lines
7.9 KiB
Go
325 lines
7.9 KiB
Go
package jwk
|
|
|
|
import (
|
|
"crypto/ecdsa"
|
|
"crypto/ed25519"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"testing"
|
|
|
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestGenerateKey(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
alg string
|
|
crv string
|
|
expectError bool
|
|
expectedAlg jwa.SignatureAlgorithm
|
|
}{
|
|
{
|
|
name: "RS256",
|
|
alg: jwa.RS256().String(),
|
|
crv: "",
|
|
expectError: false,
|
|
expectedAlg: jwa.RS256(),
|
|
},
|
|
{
|
|
name: "RS384",
|
|
alg: jwa.RS384().String(),
|
|
crv: "",
|
|
expectError: false,
|
|
expectedAlg: jwa.RS384(),
|
|
},
|
|
// Skip the RS512 test as generating a RSA-4096 key can take some time
|
|
/* {
|
|
name: "RS512",
|
|
alg: jwa.RS512().String(),
|
|
crv: "",
|
|
expectError: false,
|
|
expectedAlg: jwa.RS512(),
|
|
}, */
|
|
{
|
|
name: "ES256",
|
|
alg: jwa.ES256().String(),
|
|
crv: jwa.P256().String(),
|
|
expectError: false,
|
|
expectedAlg: jwa.ES256(),
|
|
},
|
|
{
|
|
name: "ES384",
|
|
alg: jwa.ES384().String(),
|
|
crv: jwa.P384().String(),
|
|
expectError: false,
|
|
expectedAlg: jwa.ES384(),
|
|
},
|
|
{
|
|
name: "ES512",
|
|
alg: jwa.ES512().String(),
|
|
crv: jwa.P521().String(),
|
|
expectError: false,
|
|
expectedAlg: jwa.ES512(),
|
|
},
|
|
{
|
|
name: "EdDSA with Ed25519",
|
|
alg: jwa.EdDSA().String(),
|
|
crv: jwa.Ed25519().String(),
|
|
expectError: false,
|
|
expectedAlg: jwa.EdDSA(),
|
|
},
|
|
{
|
|
name: "EdDSA with unsupported curve",
|
|
alg: jwa.EdDSA().String(),
|
|
crv: "unsupported",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "Unsupported algorithm",
|
|
alg: "UNSUPPORTED",
|
|
crv: "",
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
key, err := GenerateKey(tt.alg, tt.crv)
|
|
|
|
if tt.expectError {
|
|
require.Error(t, err)
|
|
assert.Nil(t, key)
|
|
return
|
|
}
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, key)
|
|
|
|
// Verify the algorithm is set correctly
|
|
alg, ok := key.Algorithm()
|
|
require.True(t, ok, "algorithm should be set in the key")
|
|
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
|
|
|
// Verify other required fields are set
|
|
kid, ok := key.KeyID()
|
|
assert.True(t, ok, "key ID should be set")
|
|
assert.NotEmpty(t, kid, "key ID should not be empty")
|
|
|
|
usage, ok := key.KeyUsage()
|
|
assert.True(t, ok, "key usage should be set")
|
|
assert.Equal(t, KeyUsageSigning, usage)
|
|
|
|
var crv any
|
|
_ = key.Get("crv", &crv)
|
|
|
|
// Verify key type matches expected algorithm
|
|
switch tt.expectedAlg {
|
|
case jwa.RS256(), jwa.RS384(), jwa.RS512():
|
|
assert.Equal(t, jwa.RSA(), key.KeyType())
|
|
assert.Nil(t, crv)
|
|
case jwa.ES256(), jwa.ES384(), jwa.ES512():
|
|
assert.Equal(t, jwa.EC(), key.KeyType())
|
|
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
|
_ = assert.NotNil(t, crv) &&
|
|
assert.True(t, ok) &&
|
|
assert.Equal(t, tt.crv, eca.String())
|
|
case jwa.EdDSA():
|
|
assert.Equal(t, jwa.OKP(), key.KeyType())
|
|
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
|
_ = assert.NotNil(t, crv) &&
|
|
assert.True(t, ok) &&
|
|
assert.Equal(t, tt.crv, eca.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestEnsureAlgInKey(t *testing.T) {
|
|
// Generate an RSA-2048 key
|
|
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
|
|
t.Run("does not change alg already set", func(t *testing.T) {
|
|
// Import the RSA key
|
|
key, err := jwk.Import(rsaKey)
|
|
require.NoError(t, err)
|
|
|
|
// Pre-set the algorithm
|
|
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
|
|
|
// Call EnsureAlgInKey with a different algorithm
|
|
EnsureAlgInKey(key, jwa.RS384().String(), "")
|
|
|
|
// Verify the algorithm wasn't changed
|
|
alg, ok := key.Algorithm()
|
|
require.True(t, ok)
|
|
assert.Equal(t, jwa.RS256().String(), alg.String())
|
|
})
|
|
|
|
t.Run("set algorithm to explicitly-provided value", func(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
keyGen func() (any, error)
|
|
alg string
|
|
crv string
|
|
expectedAlg jwa.SignatureAlgorithm
|
|
expectedCrv string
|
|
}{
|
|
{
|
|
name: "RSA key with RS384",
|
|
keyGen: func() (any, error) {
|
|
return rsaKey, nil
|
|
},
|
|
alg: jwa.RS384().String(),
|
|
crv: "",
|
|
expectedAlg: jwa.RS384(),
|
|
expectedCrv: "",
|
|
},
|
|
{
|
|
name: "ECDSA key with ES384",
|
|
keyGen: func() (any, error) {
|
|
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
},
|
|
alg: jwa.ES384().String(),
|
|
crv: jwa.P384().String(),
|
|
expectedAlg: jwa.ES384(),
|
|
expectedCrv: jwa.P384().String(),
|
|
},
|
|
{
|
|
name: "Ed25519 key with EdDSA",
|
|
keyGen: func() (any, error) {
|
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
return priv, err
|
|
},
|
|
alg: jwa.EdDSA().String(),
|
|
crv: jwa.Ed25519().String(),
|
|
expectedAlg: jwa.EdDSA(),
|
|
expectedCrv: jwa.Ed25519().String(),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
rawKey, err := tt.keyGen()
|
|
require.NoError(t, err)
|
|
|
|
key, err := jwk.Import(rawKey)
|
|
require.NoError(t, err)
|
|
|
|
// Ensure no algorithm is set initially
|
|
_, ok := key.Algorithm()
|
|
assert.False(t, ok)
|
|
|
|
// Call EnsureAlgInKey
|
|
EnsureAlgInKey(key, tt.alg, tt.crv)
|
|
|
|
// Verify the algorithm was set correctly
|
|
alg, ok := key.Algorithm()
|
|
require.True(t, ok)
|
|
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
|
|
|
// Verify curve if expected
|
|
if tt.expectedCrv != "" {
|
|
var crv any
|
|
_ = key.Get("crv", &crv)
|
|
require.NotNil(t, crv)
|
|
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
|
require.True(t, ok)
|
|
assert.Equal(t, tt.expectedCrv, eca.String())
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("set default algorithms if not present", func(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
keyGen func() (any, error)
|
|
expectedAlg jwa.SignatureAlgorithm
|
|
expectedCrv string
|
|
}{
|
|
{
|
|
name: "RSA key defaults to RS256",
|
|
keyGen: func() (any, error) {
|
|
return rsaKey, nil
|
|
},
|
|
expectedAlg: jwa.RS256(),
|
|
expectedCrv: "",
|
|
},
|
|
{
|
|
name: "ECDSA key defaults to ES256 with P256",
|
|
keyGen: func() (any, error) {
|
|
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
},
|
|
expectedAlg: jwa.ES256(),
|
|
expectedCrv: jwa.P256().String(),
|
|
},
|
|
{
|
|
name: "Ed25519 key defaults to EdDSA with Ed25519",
|
|
keyGen: func() (any, error) {
|
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
return priv, err
|
|
},
|
|
expectedAlg: jwa.EdDSA(),
|
|
expectedCrv: jwa.Ed25519().String(),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
rawKey, err := tt.keyGen()
|
|
require.NoError(t, err)
|
|
|
|
key, err := jwk.Import(rawKey)
|
|
require.NoError(t, err)
|
|
|
|
// Ensure no algorithm is set initially
|
|
_, ok := key.Algorithm()
|
|
assert.False(t, ok)
|
|
|
|
// Call EnsureAlgInKey with empty parameters
|
|
EnsureAlgInKey(key, "", "")
|
|
|
|
// Verify the default algorithm was set
|
|
alg, ok := key.Algorithm()
|
|
require.True(t, ok)
|
|
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
|
|
|
// Verify curve if expected
|
|
if tt.expectedCrv != "" {
|
|
var crv any
|
|
_ = key.Get("crv", &crv)
|
|
require.NotNil(t, crv)
|
|
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
|
require.True(t, ok)
|
|
assert.Equal(t, tt.expectedCrv, eca.String())
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("invalid curve should not set curve parameter", func(t *testing.T) {
|
|
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
|
|
key, err := jwk.Import(rsaKey)
|
|
require.NoError(t, err)
|
|
|
|
// Call EnsureAlgInKey with invalid curve
|
|
EnsureAlgInKey(key, jwa.RS256().String(), "invalid-curve")
|
|
|
|
// Verify algorithm was set but curve was not
|
|
alg, ok := key.Algorithm()
|
|
require.True(t, ok)
|
|
assert.Equal(t, jwa.RS256().String(), alg.String())
|
|
|
|
var crv any
|
|
_ = key.Get("crv", &crv)
|
|
assert.Nil(t, crv)
|
|
})
|
|
}
|