From 8c8fc2304d8f33c1fea54b1138b109f282e78b8b Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Thu, 3 Jul 2025 13:23:24 -0700 Subject: [PATCH] feat: add "key-rotate" command (#709) --- backend/cmd/main.go | 30 +-- backend/go.mod | 3 + backend/go.sum | 8 + backend/internal/cmds/key_rotate.go | 107 +++++++++ backend/internal/cmds/key_rotate_test.go | 214 ++++++++++++++++++ .../internal/cmds/one_time_access_token.go | 116 +++++----- backend/internal/cmds/root.go | 31 +++ backend/internal/cmds/version.go | 19 ++ backend/internal/utils/cmd_util.go | 24 ++ 9 files changed, 465 insertions(+), 87 deletions(-) create mode 100644 backend/internal/cmds/key_rotate.go create mode 100644 backend/internal/cmds/key_rotate_test.go create mode 100644 backend/internal/cmds/root.go create mode 100644 backend/internal/cmds/version.go create mode 100644 backend/internal/utils/cmd_util.go diff --git a/backend/cmd/main.go b/backend/cmd/main.go index 56b746c4..32204aca 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -1,15 +1,9 @@ package main import ( - "flag" - "fmt" - "log" - _ "time/tzdata" - "github.com/pocket-id/pocket-id/backend/internal/bootstrap" "github.com/pocket-id/pocket-id/backend/internal/cmds" - "github.com/pocket-id/pocket-id/backend/internal/common" ) // @title Pocket ID API @@ -17,27 +11,5 @@ import ( // @description.markdown func main() { - // Get the command - // By default, this starts the server - var cmd string - flag.Parse() - args := flag.Args() - if len(args) > 0 { - cmd = args[0] - } - - var err error - switch cmd { - case "version": - fmt.Println("pocket-id " + common.Version) - case "one-time-access-token": - err = cmds.OneTimeAccessToken(args) - default: - // Start the server - err = bootstrap.Bootstrap() - } - - if err != nil { - log.Fatal(err.Error()) - } + cmds.Execute() } diff --git a/backend/go.mod b/backend/go.mod index 339dbde3..ca31ebc8 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -25,6 +25,7 @@ require ( github.com/lestrrat-go/jwx/v3 v3.0.1 github.com/mileusna/useragent v1.3.5 github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2 + github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.10.0 go.opentelemetry.io/contrib/exporters/autoexport v0.59.0 go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.60.0 @@ -69,6 +70,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.7.2 // indirect @@ -99,6 +101,7 @@ require ( github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/segmentio/asm v1.2.0 // indirect + github.com/spf13/pflag v1.0.6 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect github.com/x448/float16 v0.8.4 // indirect diff --git a/backend/go.sum b/backend/go.sum index 9fccec08..0f86f8f3 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -24,6 +24,7 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -120,6 +121,8 @@ github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9 github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -227,8 +230,13 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= +github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= +github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= +github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/backend/internal/cmds/key_rotate.go b/backend/internal/cmds/key_rotate.go new file mode 100644 index 00000000..5e29f7d3 --- /dev/null +++ b/backend/internal/cmds/key_rotate.go @@ -0,0 +1,107 @@ +package cmds + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/spf13/cobra" + "gorm.io/gorm" + + "github.com/pocket-id/pocket-id/backend/internal/bootstrap" + "github.com/pocket-id/pocket-id/backend/internal/common" + "github.com/pocket-id/pocket-id/backend/internal/service" + "github.com/pocket-id/pocket-id/backend/internal/utils" + jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk" +) + +type keyRotateFlags struct { + Alg string + Crv string + Yes bool +} + +func init() { + var flags keyRotateFlags + + keyRotateCmd := &cobra.Command{ + Use: "key-rotate", + Short: "Generates a new token signing key and replaces the current one", + RunE: func(cmd *cobra.Command, args []string) error { + db := bootstrap.NewDatabase() + + return keyRotate(cmd.Context(), flags, db, &common.EnvConfig) + }, + } + + keyRotateCmd.Flags().StringVarP(&flags.Alg, "alg", "a", "RS256", "Key algorithm. Supported values: RS256, RS384, RS512, ES256, ES384, ES512, EdDSA") + keyRotateCmd.Flags().StringVarP(&flags.Crv, "crv", "c", "", "Curve name when using EdDSA keys. Supported values: Ed25519") + keyRotateCmd.Flags().BoolVarP(&flags.Yes, "yes", "y", false, "Do not prompt for confirmation") + + rootCmd.AddCommand(keyRotateCmd) +} + +func keyRotate(ctx context.Context, flags keyRotateFlags, db *gorm.DB, envConfig *common.EnvConfigSchema) error { + // Validate the flags + switch strings.ToUpper(flags.Alg) { + case jwa.RS256().String(), jwa.RS384().String(), jwa.RS512().String(), + jwa.ES256().String(), jwa.ES384().String(), jwa.ES512().String(): + // All good, but uppercase it for consistency + flags.Alg = strings.ToUpper(flags.Alg) + case strings.ToUpper(jwa.EdDSA().String()): + // Ensure Crv is set and valid + switch strings.ToUpper(flags.Crv) { + case strings.ToUpper(jwa.Ed25519().String()): + // All good, but ensure consistency in casing + flags.Crv = jwa.Ed25519().String() + case "": + return errors.New("a curve name is required when algorithm is EdDSA") + default: + return errors.New("unsupported EdDSA curve; supported values: Ed25519") + } + case "": + return errors.New("key algorithm is required") + default: + return errors.New("unsupported key algorithm; supported values: RS256, RS384, RS512, ES256, ES384, ES512, EdDSA") + } + + if !flags.Yes { + fmt.Println("WARNING: Rotating the private key will invalidate all existing tokens. Both pocket-id and all client applications will likely need to be restarted.") + ok, err := utils.PromptForConfirmation("Confirm") + if err != nil { + return err + } + if !ok { + fmt.Println("Aborted") + return nil + } + } + + // Init the services we need + appConfigService := service.NewAppConfigService(ctx, db) + + // Get the key provider + keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, appConfigService.GetDbConfig().InstanceID.Value) + if err != nil { + return fmt.Errorf("failed to get key provider: %w", err) + } + + // Generate a new key + key, err := jwkutils.GenerateKey(flags.Alg, flags.Crv) + if err != nil { + return fmt.Errorf("failed to generate key: %w", err) + } + + // Save the key + err = keyProvider.SaveKey(key) + if err != nil { + return fmt.Errorf("failed to store new key: %w", err) + } + + fmt.Println("Key rotated successfully") + fmt.Println("Note: if pocket-id is running, you will need to restart it for the new key to be loaded") + + return nil +} diff --git a/backend/internal/cmds/key_rotate_test.go b/backend/internal/cmds/key_rotate_test.go new file mode 100644 index 00000000..cd020a34 --- /dev/null +++ b/backend/internal/cmds/key_rotate_test.go @@ -0,0 +1,214 @@ +package cmds + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pocket-id/pocket-id/backend/internal/common" + "github.com/pocket-id/pocket-id/backend/internal/service" + jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk" + testingutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing" +) + +func TestKeyRotate(t *testing.T) { + tests := []struct { + name string + flags keyRotateFlags + wantErr bool + errMsg string + }{ + { + name: "valid RS256", + flags: keyRotateFlags{ + Alg: "RS256", + Yes: true, + }, + wantErr: false, + }, + { + name: "valid EdDSA with Ed25519", + flags: keyRotateFlags{ + Alg: "EdDSA", + Crv: "Ed25519", + Yes: true, + }, + wantErr: false, + }, + { + name: "invalid algorithm", + flags: keyRotateFlags{ + Alg: "INVALID", + Yes: true, + }, + wantErr: true, + errMsg: "unsupported key algorithm", + }, + { + name: "EdDSA without curve", + flags: keyRotateFlags{ + Alg: "EdDSA", + Yes: true, + }, + wantErr: true, + errMsg: "a curve name is required when algorithm is EdDSA", + }, + { + name: "empty algorithm", + flags: keyRotateFlags{ + Alg: "", + Yes: true, + }, + wantErr: true, + errMsg: "key algorithm is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Run("file storage", func(t *testing.T) { + testKeyRotateWithFileStorage(t, tt.flags, tt.wantErr, tt.errMsg) + }) + + t.Run("database storage", func(t *testing.T) { + testKeyRotateWithDatabaseStorage(t, tt.flags, tt.wantErr, tt.errMsg) + }) + }) + } +} + +func testKeyRotateWithFileStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) { + // Create temporary directory for keys + tempDir := t.TempDir() + keysPath := filepath.Join(tempDir, "keys") + err := os.MkdirAll(keysPath, 0755) + require.NoError(t, err) + + // Set up file storage config + envConfig := &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: keysPath, + } + + // Create test database + db := testingutils.NewDatabaseForTest(t) + + // Initialize app config service and create instance + appConfigService := service.NewAppConfigService(t.Context(), db) + instanceID := appConfigService.GetDbConfig().InstanceID.Value + + // Check if key exists before rotation + keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, instanceID) + require.NoError(t, err) + + // Run the key rotation + err = keyRotate(t.Context(), flags, db, envConfig) + + if wantErr { + require.Error(t, err) + if errMsg != "" { + require.ErrorContains(t, err, errMsg) + } + return + } + + require.NoError(t, err) + + // Verify key was created + key, err := keyProvider.LoadKey() + require.NoError(t, err) + require.NotNil(t, key) + + // Verify the algorithm matches what we requested + alg, _ := key.Algorithm() + assert.NotEmpty(t, alg) + if flags.Alg != "" { + expectedAlg := flags.Alg + if expectedAlg == "EdDSA" { + // EdDSA keys should have the EdDSA algorithm + assert.Equal(t, "EdDSA", alg.String()) + } else { + assert.Equal(t, expectedAlg, alg.String()) + } + } +} + +func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) { + // Set up database storage config + envConfig := &common.EnvConfigSchema{ + KeysStorage: "database", + EncryptionKey: "test-encryption-key-characters-long", + } + + // Create test database + db := testingutils.NewDatabaseForTest(t) + + // Initialize app config service and create instance + appConfigService := service.NewAppConfigService(t.Context(), db) + instanceID := appConfigService.GetDbConfig().InstanceID.Value + + // Get key provider + keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, instanceID) + require.NoError(t, err) + + // Run the key rotation + err = keyRotate(t.Context(), flags, db, envConfig) + + if wantErr { + require.Error(t, err) + if errMsg != "" { + require.ErrorContains(t, err, errMsg) + } + return + } + + require.NoError(t, err) + + // Verify key was created + key, err := keyProvider.LoadKey() + require.NoError(t, err) + require.NotNil(t, key) + + // Verify the algorithm matches what we requested + alg, _ := key.Algorithm() + assert.NotEmpty(t, alg) + if flags.Alg != "" { + expectedAlg := flags.Alg + if expectedAlg == "EdDSA" { + // EdDSA keys should have the EdDSA algorithm + assert.Equal(t, "EdDSA", alg.String()) + } else { + assert.Equal(t, expectedAlg, alg.String()) + } + } +} + +func TestKeyRotateMultipleAlgorithms(t *testing.T) { + algorithms := []struct { + alg string + crv string + }{ + {"RS256", ""}, + {"RS384", ""}, + // Skip RSA-4096 key generation test as it can take a long time + // {"RS512", ""}, + {"ES256", ""}, + {"ES384", ""}, + {"ES512", ""}, + {"EdDSA", "Ed25519"}, + } + + for _, algo := range algorithms { + t.Run(algo.alg, func(t *testing.T) { + // Test with database storage for all algorithms + testKeyRotateWithDatabaseStorage(t, keyRotateFlags{ + Alg: algo.alg, + Crv: algo.crv, + Yes: true, + }, false, "") + }) + } +} diff --git a/backend/internal/cmds/one_time_access_token.go b/backend/internal/cmds/one_time_access_token.go index 1b3f242b..65b94e87 100644 --- a/backend/internal/cmds/one_time_access_token.go +++ b/backend/internal/cmds/one_time_access_token.go @@ -6,77 +6,77 @@ import ( "fmt" "time" + "github.com/spf13/cobra" "gorm.io/gorm" "github.com/pocket-id/pocket-id/backend/internal/bootstrap" "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/service" - "github.com/pocket-id/pocket-id/backend/internal/utils/signals" ) -// OneTimeAccessToken creates a one-time access token for the given user -// Args must contain the username or email of the user -func OneTimeAccessToken(args []string) error { - // Get a context that is canceled when the application is stopping - ctx := signals.SignalContext(context.Background()) +var oneTimeAccessTokenCmd = &cobra.Command{ + Use: "one-time-access-token [username or email]", + Short: "Generates a one-time access token for the given user", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + // Get the username or email of the user + userArg := args[0] - // Get the username or email of the user - // Note length is 2 because the first argument is always the command (one-time-access-token) - if len(args) != 2 { - return errors.New("missing username or email of user; usage: one-time-access-token ") - } - userArg := args[1] + // Connect to the database + db := bootstrap.NewDatabase() - // Connect to the database - db := bootstrap.NewDatabase() + // Create the access token + var oneTimeAccessToken *model.OneTimeAccessToken + err := db.Transaction(func(tx *gorm.DB) error { + // Load the user to retrieve the user ID + var user model.User + queryCtx, queryCancel := context.WithTimeout(cmd.Context(), 10*time.Second) + defer queryCancel() + txErr := tx. + WithContext(queryCtx). + Where("username = ? OR email = ?", userArg, userArg). + First(&user). + Error + switch { + case errors.Is(txErr, gorm.ErrRecordNotFound): + return errors.New("user not found") + case txErr != nil: + return fmt.Errorf("failed to query for user: %w", txErr) + case user.ID == "": + return errors.New("invalid user loaded: ID is empty") + } - // Create the access token - var oneTimeAccessToken *model.OneTimeAccessToken - err := db.Transaction(func(tx *gorm.DB) error { - // Load the user to retrieve the user ID - var user model.User - queryCtx, queryCancel := context.WithTimeout(ctx, 10*time.Second) - defer queryCancel() - txErr := tx. - WithContext(queryCtx). - Where("username = ? OR email = ?", userArg, userArg). - First(&user). - Error - switch { - case errors.Is(txErr, gorm.ErrRecordNotFound): - return errors.New("user not found") - case txErr != nil: - return fmt.Errorf("failed to query for user: %w", txErr) - case user.ID == "": - return errors.New("invalid user loaded: ID is empty") + // Create a new access token that expires in 1 hour + oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Now().Add(time.Hour)) + if txErr != nil { + return fmt.Errorf("failed to generate access token: %w", txErr) + } + + queryCtx, queryCancel = context.WithTimeout(cmd.Context(), 10*time.Second) + defer queryCancel() + txErr = tx. + WithContext(queryCtx). + Create(oneTimeAccessToken). + Error + if txErr != nil { + return fmt.Errorf("failed to save access token: %w", txErr) + } + + return nil + }) + if err != nil { + return err } - // Create a new access token that expires in 1 hour - oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Now().Add(time.Hour)) - if txErr != nil { - return fmt.Errorf("failed to generate access token: %w", txErr) - } - - queryCtx, queryCancel = context.WithTimeout(ctx, 10*time.Second) - defer queryCancel() - txErr = tx. - WithContext(queryCtx). - Create(oneTimeAccessToken). - Error - if txErr != nil { - return fmt.Errorf("failed to save access token: %w", txErr) - } + // Print the result + fmt.Printf(`A one-time access token valid for 1 hour has been created for "%s".`+"\n", userArg) + fmt.Printf("Use the following URL to sign in once: %s/lc/%s\n", common.EnvConfig.AppURL, oneTimeAccessToken.Token) return nil - }) - if err != nil { - return err - } - - // Print the result - fmt.Printf(`A one-time access token valid for 1 hour has been created for "%s".`+"\n", userArg) - fmt.Printf("Use the following URL to sign in once: %s/lc/%s\n", common.EnvConfig.AppURL, oneTimeAccessToken.Token) - - return nil + }, +} + +func init() { + rootCmd.AddCommand(oneTimeAccessTokenCmd) } diff --git a/backend/internal/cmds/root.go b/backend/internal/cmds/root.go new file mode 100644 index 00000000..d6f6bb5d --- /dev/null +++ b/backend/internal/cmds/root.go @@ -0,0 +1,31 @@ +package cmds + +import ( + "log/slog" + "os" + + "github.com/spf13/cobra" + + "github.com/pocket-id/pocket-id/backend/internal/bootstrap" +) + +var rootCmd = &cobra.Command{ + Use: "pocket-id", + Short: "A simple and easy-to-use OIDC provider that allows users to authenticate with their passkeys to your services.", + Long: "By default, this command starts the pocket-id server.", + Run: func(cmd *cobra.Command, args []string) { + // Start the server + err := bootstrap.Bootstrap() + if err != nil { + slog.Error("Failed to run pocket-id", "error", err) + os.Exit(1) + } + }, +} + +func Execute() { + err := rootCmd.Execute() + if err != nil { + os.Exit(1) + } +} diff --git a/backend/internal/cmds/version.go b/backend/internal/cmds/version.go new file mode 100644 index 00000000..daf34056 --- /dev/null +++ b/backend/internal/cmds/version.go @@ -0,0 +1,19 @@ +package cmds + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/pocket-id/pocket-id/backend/internal/common" +) + +func init() { + rootCmd.AddCommand(&cobra.Command{ + Use: "version", + Short: "Print the version number", + Run: func(cmd *cobra.Command, args []string) { + fmt.Println("pocket-id " + common.Version) + }, + }) +} diff --git a/backend/internal/utils/cmd_util.go b/backend/internal/utils/cmd_util.go new file mode 100644 index 00000000..667ee4d5 --- /dev/null +++ b/backend/internal/utils/cmd_util.go @@ -0,0 +1,24 @@ +package utils + +import ( + "bufio" + "fmt" + "os" + "strings" +) + +// PromptForConfirmation prompts the user to answer "y" in the terminal +func PromptForConfirmation(prompt string) (bool, error) { + fmt.Print(prompt + " [y/N]: ") + + reader := bufio.NewReader(os.Stdin) + r, err := reader.ReadString('\n') + if err != nil { + return false, fmt.Errorf("failed to read response: %w", err) + } + r = strings.TrimSpace(strings.ToLower(r)) + + ok := r == "yes" || r == "y" + + return ok, nil +}