mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-18 09:13:26 +03:00
feat: add "key-rotate" command (#709)
This commit is contained in:
committed by
GitHub
parent
15ece0ab30
commit
8c8fc2304d
@@ -1,15 +1,9 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
_ "time/tzdata"
|
_ "time/tzdata"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/cmds"
|
"github.com/pocket-id/pocket-id/backend/internal/cmds"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// @title Pocket ID API
|
// @title Pocket ID API
|
||||||
@@ -17,27 +11,5 @@ import (
|
|||||||
// @description.markdown
|
// @description.markdown
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Get the command
|
cmds.Execute()
|
||||||
// By default, this starts the server
|
|
||||||
var cmd string
|
|
||||||
flag.Parse()
|
|
||||||
args := flag.Args()
|
|
||||||
if len(args) > 0 {
|
|
||||||
cmd = args[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
switch cmd {
|
|
||||||
case "version":
|
|
||||||
fmt.Println("pocket-id " + common.Version)
|
|
||||||
case "one-time-access-token":
|
|
||||||
err = cmds.OneTimeAccessToken(args)
|
|
||||||
default:
|
|
||||||
// Start the server
|
|
||||||
err = bootstrap.Bootstrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err.Error())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ require (
|
|||||||
github.com/lestrrat-go/jwx/v3 v3.0.1
|
github.com/lestrrat-go/jwx/v3 v3.0.1
|
||||||
github.com/mileusna/useragent v1.3.5
|
github.com/mileusna/useragent v1.3.5
|
||||||
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
|
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
|
||||||
|
github.com/spf13/cobra v1.9.1
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
go.opentelemetry.io/contrib/exporters/autoexport v0.59.0
|
go.opentelemetry.io/contrib/exporters/autoexport v0.59.0
|
||||||
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.60.0
|
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.60.0
|
||||||
@@ -69,6 +70,7 @@ require (
|
|||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect
|
||||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
github.com/jackc/pgx/v5 v5.7.2 // indirect
|
github.com/jackc/pgx/v5 v5.7.2 // indirect
|
||||||
@@ -99,6 +101,7 @@ require (
|
|||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||||
github.com/segmentio/asm v1.2.0 // indirect
|
github.com/segmentio/asm v1.2.0 // indirect
|
||||||
|
github.com/spf13/pflag v1.0.6 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
github.com/x448/float16 v0.8.4 // indirect
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
|
|||||||
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
|
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
|
||||||
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||||
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||||
|
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
@@ -120,6 +121,8 @@ github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9
|
|||||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||||
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
|
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
|
||||||
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
@@ -227,8 +230,13 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
|||||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||||
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
|
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
|
||||||
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
||||||
|
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
|
||||||
|
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
|
||||||
|
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
|
||||||
|
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
|||||||
107
backend/internal/cmds/key_rotate.go
Normal file
107
backend/internal/cmds/key_rotate.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package cmds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
|
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||||
|
)
|
||||||
|
|
||||||
|
type keyRotateFlags struct {
|
||||||
|
Alg string
|
||||||
|
Crv string
|
||||||
|
Yes bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var flags keyRotateFlags
|
||||||
|
|
||||||
|
keyRotateCmd := &cobra.Command{
|
||||||
|
Use: "key-rotate",
|
||||||
|
Short: "Generates a new token signing key and replaces the current one",
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
db := bootstrap.NewDatabase()
|
||||||
|
|
||||||
|
return keyRotate(cmd.Context(), flags, db, &common.EnvConfig)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
keyRotateCmd.Flags().StringVarP(&flags.Alg, "alg", "a", "RS256", "Key algorithm. Supported values: RS256, RS384, RS512, ES256, ES384, ES512, EdDSA")
|
||||||
|
keyRotateCmd.Flags().StringVarP(&flags.Crv, "crv", "c", "", "Curve name when using EdDSA keys. Supported values: Ed25519")
|
||||||
|
keyRotateCmd.Flags().BoolVarP(&flags.Yes, "yes", "y", false, "Do not prompt for confirmation")
|
||||||
|
|
||||||
|
rootCmd.AddCommand(keyRotateCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func keyRotate(ctx context.Context, flags keyRotateFlags, db *gorm.DB, envConfig *common.EnvConfigSchema) error {
|
||||||
|
// Validate the flags
|
||||||
|
switch strings.ToUpper(flags.Alg) {
|
||||||
|
case jwa.RS256().String(), jwa.RS384().String(), jwa.RS512().String(),
|
||||||
|
jwa.ES256().String(), jwa.ES384().String(), jwa.ES512().String():
|
||||||
|
// All good, but uppercase it for consistency
|
||||||
|
flags.Alg = strings.ToUpper(flags.Alg)
|
||||||
|
case strings.ToUpper(jwa.EdDSA().String()):
|
||||||
|
// Ensure Crv is set and valid
|
||||||
|
switch strings.ToUpper(flags.Crv) {
|
||||||
|
case strings.ToUpper(jwa.Ed25519().String()):
|
||||||
|
// All good, but ensure consistency in casing
|
||||||
|
flags.Crv = jwa.Ed25519().String()
|
||||||
|
case "":
|
||||||
|
return errors.New("a curve name is required when algorithm is EdDSA")
|
||||||
|
default:
|
||||||
|
return errors.New("unsupported EdDSA curve; supported values: Ed25519")
|
||||||
|
}
|
||||||
|
case "":
|
||||||
|
return errors.New("key algorithm is required")
|
||||||
|
default:
|
||||||
|
return errors.New("unsupported key algorithm; supported values: RS256, RS384, RS512, ES256, ES384, ES512, EdDSA")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !flags.Yes {
|
||||||
|
fmt.Println("WARNING: Rotating the private key will invalidate all existing tokens. Both pocket-id and all client applications will likely need to be restarted.")
|
||||||
|
ok, err := utils.PromptForConfirmation("Confirm")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
fmt.Println("Aborted")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init the services we need
|
||||||
|
appConfigService := service.NewAppConfigService(ctx, db)
|
||||||
|
|
||||||
|
// Get the key provider
|
||||||
|
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, appConfigService.GetDbConfig().InstanceID.Value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get key provider: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new key
|
||||||
|
key, err := jwkutils.GenerateKey(flags.Alg, flags.Crv)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the key
|
||||||
|
err = keyProvider.SaveKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store new key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Key rotated successfully")
|
||||||
|
fmt.Println("Note: if pocket-id is running, you will need to restart it for the new key to be loaded")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
214
backend/internal/cmds/key_rotate_test.go
Normal file
214
backend/internal/cmds/key_rotate_test.go
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
package cmds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||||
|
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||||
|
testingutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKeyRotate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
flags keyRotateFlags
|
||||||
|
wantErr bool
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid RS256",
|
||||||
|
flags: keyRotateFlags{
|
||||||
|
Alg: "RS256",
|
||||||
|
Yes: true,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid EdDSA with Ed25519",
|
||||||
|
flags: keyRotateFlags{
|
||||||
|
Alg: "EdDSA",
|
||||||
|
Crv: "Ed25519",
|
||||||
|
Yes: true,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid algorithm",
|
||||||
|
flags: keyRotateFlags{
|
||||||
|
Alg: "INVALID",
|
||||||
|
Yes: true,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "unsupported key algorithm",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "EdDSA without curve",
|
||||||
|
flags: keyRotateFlags{
|
||||||
|
Alg: "EdDSA",
|
||||||
|
Yes: true,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "a curve name is required when algorithm is EdDSA",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty algorithm",
|
||||||
|
flags: keyRotateFlags{
|
||||||
|
Alg: "",
|
||||||
|
Yes: true,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "key algorithm is required",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Run("file storage", func(t *testing.T) {
|
||||||
|
testKeyRotateWithFileStorage(t, tt.flags, tt.wantErr, tt.errMsg)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("database storage", func(t *testing.T) {
|
||||||
|
testKeyRotateWithDatabaseStorage(t, tt.flags, tt.wantErr, tt.errMsg)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testKeyRotateWithFileStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) {
|
||||||
|
// Create temporary directory for keys
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
keysPath := filepath.Join(tempDir, "keys")
|
||||||
|
err := os.MkdirAll(keysPath, 0755)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Set up file storage config
|
||||||
|
envConfig := &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: keysPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test database
|
||||||
|
db := testingutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
|
// Initialize app config service and create instance
|
||||||
|
appConfigService := service.NewAppConfigService(t.Context(), db)
|
||||||
|
instanceID := appConfigService.GetDbConfig().InstanceID.Value
|
||||||
|
|
||||||
|
// Check if key exists before rotation
|
||||||
|
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, instanceID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Run the key rotation
|
||||||
|
err = keyRotate(t.Context(), flags, db, envConfig)
|
||||||
|
|
||||||
|
if wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if errMsg != "" {
|
||||||
|
require.ErrorContains(t, err, errMsg)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify key was created
|
||||||
|
key, err := keyProvider.LoadKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, key)
|
||||||
|
|
||||||
|
// Verify the algorithm matches what we requested
|
||||||
|
alg, _ := key.Algorithm()
|
||||||
|
assert.NotEmpty(t, alg)
|
||||||
|
if flags.Alg != "" {
|
||||||
|
expectedAlg := flags.Alg
|
||||||
|
if expectedAlg == "EdDSA" {
|
||||||
|
// EdDSA keys should have the EdDSA algorithm
|
||||||
|
assert.Equal(t, "EdDSA", alg.String())
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, expectedAlg, alg.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) {
|
||||||
|
// Set up database storage config
|
||||||
|
envConfig := &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "database",
|
||||||
|
EncryptionKey: "test-encryption-key-characters-long",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test database
|
||||||
|
db := testingutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
|
// Initialize app config service and create instance
|
||||||
|
appConfigService := service.NewAppConfigService(t.Context(), db)
|
||||||
|
instanceID := appConfigService.GetDbConfig().InstanceID.Value
|
||||||
|
|
||||||
|
// Get key provider
|
||||||
|
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, instanceID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Run the key rotation
|
||||||
|
err = keyRotate(t.Context(), flags, db, envConfig)
|
||||||
|
|
||||||
|
if wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if errMsg != "" {
|
||||||
|
require.ErrorContains(t, err, errMsg)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify key was created
|
||||||
|
key, err := keyProvider.LoadKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, key)
|
||||||
|
|
||||||
|
// Verify the algorithm matches what we requested
|
||||||
|
alg, _ := key.Algorithm()
|
||||||
|
assert.NotEmpty(t, alg)
|
||||||
|
if flags.Alg != "" {
|
||||||
|
expectedAlg := flags.Alg
|
||||||
|
if expectedAlg == "EdDSA" {
|
||||||
|
// EdDSA keys should have the EdDSA algorithm
|
||||||
|
assert.Equal(t, "EdDSA", alg.String())
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, expectedAlg, alg.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyRotateMultipleAlgorithms(t *testing.T) {
|
||||||
|
algorithms := []struct {
|
||||||
|
alg string
|
||||||
|
crv string
|
||||||
|
}{
|
||||||
|
{"RS256", ""},
|
||||||
|
{"RS384", ""},
|
||||||
|
// Skip RSA-4096 key generation test as it can take a long time
|
||||||
|
// {"RS512", ""},
|
||||||
|
{"ES256", ""},
|
||||||
|
{"ES384", ""},
|
||||||
|
{"ES512", ""},
|
||||||
|
{"EdDSA", "Ed25519"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, algo := range algorithms {
|
||||||
|
t.Run(algo.alg, func(t *testing.T) {
|
||||||
|
// Test with database storage for all algorithms
|
||||||
|
testKeyRotateWithDatabaseStorage(t, keyRotateFlags{
|
||||||
|
Alg: algo.alg,
|
||||||
|
Crv: algo.crv,
|
||||||
|
Yes: true,
|
||||||
|
}, false, "")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,27 +6,22 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
|
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils/signals"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// OneTimeAccessToken creates a one-time access token for the given user
|
var oneTimeAccessTokenCmd = &cobra.Command{
|
||||||
// Args must contain the username or email of the user
|
Use: "one-time-access-token [username or email]",
|
||||||
func OneTimeAccessToken(args []string) error {
|
Short: "Generates a one-time access token for the given user",
|
||||||
// Get a context that is canceled when the application is stopping
|
Args: cobra.ExactArgs(1),
|
||||||
ctx := signals.SignalContext(context.Background())
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
|
||||||
// Get the username or email of the user
|
// Get the username or email of the user
|
||||||
// Note length is 2 because the first argument is always the command (one-time-access-token)
|
userArg := args[0]
|
||||||
if len(args) != 2 {
|
|
||||||
return errors.New("missing username or email of user; usage: one-time-access-token <username or email>")
|
|
||||||
}
|
|
||||||
userArg := args[1]
|
|
||||||
|
|
||||||
// Connect to the database
|
// Connect to the database
|
||||||
db := bootstrap.NewDatabase()
|
db := bootstrap.NewDatabase()
|
||||||
@@ -36,7 +31,7 @@ func OneTimeAccessToken(args []string) error {
|
|||||||
err := db.Transaction(func(tx *gorm.DB) error {
|
err := db.Transaction(func(tx *gorm.DB) error {
|
||||||
// Load the user to retrieve the user ID
|
// Load the user to retrieve the user ID
|
||||||
var user model.User
|
var user model.User
|
||||||
queryCtx, queryCancel := context.WithTimeout(ctx, 10*time.Second)
|
queryCtx, queryCancel := context.WithTimeout(cmd.Context(), 10*time.Second)
|
||||||
defer queryCancel()
|
defer queryCancel()
|
||||||
txErr := tx.
|
txErr := tx.
|
||||||
WithContext(queryCtx).
|
WithContext(queryCtx).
|
||||||
@@ -58,7 +53,7 @@ func OneTimeAccessToken(args []string) error {
|
|||||||
return fmt.Errorf("failed to generate access token: %w", txErr)
|
return fmt.Errorf("failed to generate access token: %w", txErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
queryCtx, queryCancel = context.WithTimeout(ctx, 10*time.Second)
|
queryCtx, queryCancel = context.WithTimeout(cmd.Context(), 10*time.Second)
|
||||||
defer queryCancel()
|
defer queryCancel()
|
||||||
txErr = tx.
|
txErr = tx.
|
||||||
WithContext(queryCtx).
|
WithContext(queryCtx).
|
||||||
@@ -79,4 +74,9 @@ func OneTimeAccessToken(args []string) error {
|
|||||||
fmt.Printf("Use the following URL to sign in once: %s/lc/%s\n", common.EnvConfig.AppURL, oneTimeAccessToken.Token)
|
fmt.Printf("Use the following URL to sign in once: %s/lc/%s\n", common.EnvConfig.AppURL, oneTimeAccessToken.Token)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(oneTimeAccessTokenCmd)
|
||||||
}
|
}
|
||||||
|
|||||||
31
backend/internal/cmds/root.go
Normal file
31
backend/internal/cmds/root.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package cmds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
|
||||||
|
)
|
||||||
|
|
||||||
|
var rootCmd = &cobra.Command{
|
||||||
|
Use: "pocket-id",
|
||||||
|
Short: "A simple and easy-to-use OIDC provider that allows users to authenticate with their passkeys to your services.",
|
||||||
|
Long: "By default, this command starts the pocket-id server.",
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
// Start the server
|
||||||
|
err := bootstrap.Bootstrap()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to run pocket-id", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func Execute() {
|
||||||
|
err := rootCmd.Execute()
|
||||||
|
if err != nil {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
19
backend/internal/cmds/version.go
Normal file
19
backend/internal/cmds/version.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package cmds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(&cobra.Command{
|
||||||
|
Use: "version",
|
||||||
|
Short: "Print the version number",
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
fmt.Println("pocket-id " + common.Version)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
24
backend/internal/utils/cmd_util.go
Normal file
24
backend/internal/utils/cmd_util.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PromptForConfirmation prompts the user to answer "y" in the terminal
|
||||||
|
func PromptForConfirmation(prompt string) (bool, error) {
|
||||||
|
fmt.Print(prompt + " [y/N]: ")
|
||||||
|
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
r, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
r = strings.TrimSpace(strings.ToLower(r))
|
||||||
|
|
||||||
|
ok := r == "yes" || r == "y"
|
||||||
|
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user