// Mirror of https://github.com/pocket-id/pocket-id.git
// (synced 2025-12-13 08:43:03 +03:00; 169 lines, 4.5 KiB, Go).

package jwk
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/ecdsa"
|
|
"crypto/ed25519"
|
|
"crypto/elliptic"
|
|
"crypto/hmac"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/sha3"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"hash"
|
|
"io"
|
|
|
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
|
|
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
)
|
|
|
|
// KeyUsageSigning is the value stored in the "use" property of the private
// keys handled by this package, marking them as signing keys.
const KeyUsageSigning = "sig"
|
|
|
|
// EncodeJWK encodes a jwk.Key to a writable stream.
|
|
func EncodeJWK(w io.Writer, key jwk.Key) error {
|
|
enc := json.NewEncoder(w)
|
|
enc.SetEscapeHTML(false)
|
|
return enc.Encode(key)
|
|
}
|
|
|
|
// EncodeJWKBytes encodes a jwk.Key to a byte slice.
|
|
func EncodeJWKBytes(key jwk.Key) ([]byte, error) {
|
|
b := &bytes.Buffer{}
|
|
err := EncodeJWK(b, key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return b.Bytes(), nil
|
|
}
|
|
|
|
// LoadKeyEncryptionKey loads the key encryption key for JWKs
|
|
func LoadKeyEncryptionKey(envConfig *common.EnvConfigSchema, instanceID string) (kek []byte, err error) {
|
|
// If there's no key, return
|
|
if len(envConfig.EncryptionKey) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
// We need a 256-bit key for encryption with AES-GCM-256
|
|
// We use HMAC with SHA3-256 here to derive the key from the one passed as input
|
|
// The key is tied to a specific instance of Pocket ID
|
|
h := hmac.New(func() hash.Hash { return sha3.New256() }, []byte(envConfig.EncryptionKey))
|
|
fmt.Fprint(h, "pocketid/"+instanceID+"/jwk-kek")
|
|
kek = h.Sum(nil)
|
|
|
|
return kek, nil
|
|
}
|
|
|
|
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
|
|
// It also populates additional fields such as the key ID, usage, and alg.
|
|
func ImportRawKey(rawKey any, alg string, crv string) (jwk.Key, error) {
|
|
key, err := jwk.Import(rawKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to import generated private key: %w", err)
|
|
}
|
|
|
|
// Generate the key ID
|
|
kid, err := generateRandomKeyID()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
|
}
|
|
_ = key.Set(jwk.KeyIDKey, kid)
|
|
|
|
// Set other required fields
|
|
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
|
|
EnsureAlgInKey(key, alg, crv)
|
|
|
|
return key, nil
|
|
}
|
|
|
|
// generateRandomKeyID returns a new random key ID: 8 bytes read from
// crypto/rand, encoded as unpadded URL-safe base64.
func generateRandomKeyID() (string, error) {
	var raw [8]byte
	if _, err := io.ReadFull(rand.Reader, raw[:]); err != nil {
		return "", fmt.Errorf("failed to read random bytes: %w", err)
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
|
|
|
|
// EnsureAlgInKey ensures that the key contains an "alg" parameter (and "crv", if needed), set depending on the key type
|
|
func EnsureAlgInKey(key jwk.Key, alg string, crv string) {
|
|
_, ok := key.Algorithm()
|
|
if ok {
|
|
// Algorithm is already set
|
|
return
|
|
}
|
|
|
|
if alg != "" {
|
|
_ = key.Set(jwk.AlgorithmKey, alg)
|
|
if crv != "" {
|
|
eca, ok := jwa.LookupEllipticCurveAlgorithm(crv)
|
|
if ok {
|
|
switch key.KeyType() {
|
|
case jwa.EC():
|
|
_ = key.Set(jwk.ECDSACrvKey, eca)
|
|
case jwa.OKP():
|
|
_ = key.Set(jwk.OKPCrvKey, eca)
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// If we don't have an algorithm, set the default for the key type
|
|
switch key.KeyType() {
|
|
case jwa.RSA():
|
|
// Default to RS256 for RSA keys
|
|
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
|
case jwa.EC():
|
|
// Default to ES256 for ECDSA keys
|
|
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
|
|
_ = key.Set(jwk.ECDSACrvKey, jwa.P256())
|
|
case jwa.OKP():
|
|
// Default to EdDSA and Ed25519 for OKP keys
|
|
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
|
|
_ = key.Set(jwk.OKPCrvKey, jwa.Ed25519())
|
|
}
|
|
}
|
|
|
|
// GenerateKey generates a new jwk.Key
|
|
func GenerateKey(alg string, crv string) (key jwk.Key, err error) {
|
|
var rawKey any
|
|
switch alg {
|
|
case jwa.RS256().String():
|
|
rawKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
case jwa.RS384().String():
|
|
rawKey, err = rsa.GenerateKey(rand.Reader, 3072)
|
|
case jwa.RS512().String():
|
|
rawKey, err = rsa.GenerateKey(rand.Reader, 4096)
|
|
case jwa.ES256().String():
|
|
rawKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
case jwa.ES384().String():
|
|
rawKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
|
case jwa.ES512().String():
|
|
rawKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
|
case jwa.EdDSA().String():
|
|
switch crv {
|
|
case jwa.Ed25519().String():
|
|
_, rawKey, err = ed25519.GenerateKey(rand.Reader)
|
|
default:
|
|
return nil, errors.New("unsupported curve for EdDSA algorithm")
|
|
}
|
|
default:
|
|
return nil, errors.New("unsupported key algorithm")
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
|
}
|
|
|
|
// Import the raw key
|
|
return ImportRawKey(rawKey, alg, crv)
|
|
}
|