diff --git a/backend/internal/cmds/one_time_access_token.go b/backend/internal/cmds/one_time_access_token.go index 6b9804a5..736ae733 100644 --- a/backend/internal/cmds/one_time_access_token.go +++ b/backend/internal/cmds/one_time_access_token.go @@ -51,7 +51,7 @@ var oneTimeAccessTokenCmd = &cobra.Command{ } // Create a new access token that expires in 1 hour - oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Now().Add(time.Hour)) + oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Hour) if txErr != nil { return fmt.Errorf("failed to generate access token: %w", txErr) } diff --git a/backend/internal/controller/user_controller.go b/backend/internal/controller/user_controller.go index 3ab88abb..83eb6525 100644 --- a/backend/internal/controller/user_controller.go +++ b/backend/internal/controller/user_controller.go @@ -14,6 +14,11 @@ import ( "golang.org/x/time/rate" ) +const ( + defaultOneTimeAccessTokenDuration = 15 * time.Minute + defaultSignupTokenDuration = time.Hour +) + // NewUserController creates a new controller for user management endpoints // @Summary User management controller // @Description Initializes all user-related API endpoints @@ -331,10 +336,17 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context, own bo return } + var ttl time.Duration if own { input.UserID = c.GetString("userID") + ttl = defaultOneTimeAccessTokenDuration + } else { + ttl = input.TTL.Duration + if ttl <= 0 { + ttl = defaultOneTimeAccessTokenDuration + } } - token, err := uc.userService.CreateOneTimeAccessToken(c.Request.Context(), input.UserID, input.ExpiresAt) + token, err := uc.userService.CreateOneTimeAccessToken(c.Request.Context(), input.UserID, ttl) if err != nil { _ = c.Error(err) return @@ -411,7 +423,11 @@ func (uc *UserController) RequestOneTimeAccessEmailAsAdminHandler(c *gin.Context userID := c.Param("id") - err := uc.userService.RequestOneTimeAccessEmailAsAdmin(c.Request.Context(), userID, input.ExpiresAt) + ttl := input.TTL.Duration + if ttl <= 0 { + ttl = defaultOneTimeAccessTokenDuration + } + err := uc.userService.RequestOneTimeAccessEmailAsAdmin(c.Request.Context(), userID, ttl) if err != nil { _ = c.Error(err) return @@ -526,14 +542,20 @@ func (uc *UserController) createSignupTokenHandler(c *gin.Context) { return } - signupToken, err := uc.userService.CreateSignupToken(c.Request.Context(), input.ExpiresAt, input.UsageLimit) + ttl := input.TTL.Duration + if ttl <= 0 { + ttl = defaultSignupTokenDuration + } + + signupToken, err := uc.userService.CreateSignupToken(c.Request.Context(), ttl, input.UsageLimit) if err != nil { _ = c.Error(err) return } var tokenDto dto.SignupTokenDto - if err := dto.MapStruct(signupToken, &tokenDto); err != nil { + err = dto.MapStruct(signupToken, &tokenDto) + if err != nil { _ = c.Error(err) return } diff --git a/backend/internal/dto/signup_token_dto.go b/backend/internal/dto/signup_token_dto.go index a1d9ca89..92bb374a 100644 --- a/backend/internal/dto/signup_token_dto.go +++ b/backend/internal/dto/signup_token_dto.go @@ -1,14 +1,13 @@ package dto import ( - "time" - datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" + "github.com/pocket-id/pocket-id/backend/internal/utils" ) type SignupTokenCreateDto struct { - ExpiresAt time.Time `json:"expiresAt" binding:"required"` - UsageLimit int `json:"usageLimit" binding:"required,min=1,max=100"` + TTL utils.JSONDuration `json:"ttl" binding:"required,ttl"` + UsageLimit int `json:"usageLimit" binding:"required,min=1,max=100"` } type SignupTokenDto struct { diff --git a/backend/internal/dto/user_dto.go b/backend/internal/dto/user_dto.go index 55fff4ce..401b15b8 100644 --- a/backend/internal/dto/user_dto.go +++ b/backend/internal/dto/user_dto.go @@ -1,7 +1,7 @@ package dto import ( - "time" + "github.com/pocket-id/pocket-id/backend/internal/utils" ) type UserDto struct { @@ -30,8 +30,8 @@ type UserCreateDto struct { } type OneTimeAccessTokenCreateDto struct { - UserID string `json:"userId"` - ExpiresAt time.Time `json:"expiresAt" binding:"required"` + UserID string `json:"userId"` + TTL utils.JSONDuration `json:"ttl" binding:"ttl"` } type OneTimeAccessEmailAsUnauthenticatedUserDto struct { @@ -40,7 +40,7 @@ type OneTimeAccessEmailAsUnauthenticatedUserDto struct { } type OneTimeAccessEmailAsAdminDto struct { - ExpiresAt time.Time `json:"expiresAt" binding:"required"` + TTL utils.JSONDuration `json:"ttl" binding:"ttl"` } type UserUpdateUserGroupDto struct { diff --git a/backend/internal/dto/validations.go b/backend/internal/dto/validations.go index b98ac27d..4db3695d 100644 --- a/backend/internal/dto/validations.go +++ b/backend/internal/dto/validations.go @@ -1,29 +1,42 @@ package dto import ( - "log/slog" - "os" "regexp" + "time" + + "github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/gin-gonic/gin/binding" "github.com/go-playground/validator/v10" ) -// [a-zA-Z0-9] : The username must start with an alphanumeric character -// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols -// [a-zA-Z0-9]$ : The username must end with an alphanumeric character -var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$") - -var validateUsername validator.Func = func(fl validator.FieldLevel) bool { - return validateUsernameRegex.MatchString(fl.Field().String()) -} - func init() { - v, _ := binding.Validator.Engine().(*validator.Validate) - err := v.RegisterValidation("username", validateUsername) + v := binding.Validator.Engine().(*validator.Validate) + + // [a-zA-Z0-9] : The username must start with an alphanumeric character + // [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols + // [a-zA-Z0-9]$ : The username must end with an alphanumeric character + var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$") + + // Maximum allowed value for TTLs + const maxTTL = 31 * 24 * time.Hour + + // Errors here are development-time ones + err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool { + return validateUsernameRegex.MatchString(fl.Field().String()) + }) if err != nil { - slog.Error("Failed to register custom validation", slog.Any("error", err)) - os.Exit(1) - return + panic("Failed to register custom validation for username: " + err.Error()) + } + err = v.RegisterValidation("ttl", func(fl validator.FieldLevel) bool { + ttl, ok := fl.Field().Interface().(utils.JSONDuration) + if !ok { + return false + } + // Allow zero, which means the field wasn't set + return ttl.Duration == 0 || ttl.Duration > time.Second && ttl.Duration <= maxTTL + }) + if err != nil { + panic("Failed to register custom validation for ttl: " + err.Error()) } } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 1ed1795c..b791e1f1 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -348,13 +348,13 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd return user, nil } -func (s *UserService) RequestOneTimeAccessEmailAsAdmin(ctx context.Context, userID string, expiration time.Time) error { +func (s *UserService) RequestOneTimeAccessEmailAsAdmin(ctx context.Context, userID string, ttl time.Duration) error { isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsAdminEnabled.IsTrue() if isDisabled { return &common.OneTimeAccessDisabledError{} } - return s.requestOneTimeAccessEmailInternal(ctx, userID, "", expiration) + return s.requestOneTimeAccessEmailInternal(ctx, userID, "", ttl) } func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context.Context, userID, redirectPath string) error { @@ -374,11 +374,10 @@ func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context } } - expiration := time.Now().Add(15 * time.Minute) - return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, expiration) + return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute) } -func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, userID, redirectPath string, expiration time.Time) error { +func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, userID, redirectPath string, ttl time.Duration) error { tx := s.db.Begin() defer func() { tx.Rollback() @@ -389,7 +388,7 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use return err } - oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, expiration, tx) + oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, ttl, tx) if err != nil { return err } @@ -421,7 +420,7 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use Code: oneTimeAccessToken, LoginLink: link, LoginLinkWithCode: linkWithCode, - ExpirationString: utils.DurationToString(time.Until(expiration).Round(time.Second)), + ExpirationString: utils.DurationToString(ttl), }) if errInternal != nil { slog.ErrorContext(innerCtx, "Failed to send one-time access token email", slog.Any("error", errInternal), slog.String("address", user.Email)) @@ -432,17 +431,18 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use return nil } -func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, expiresAt time.Time) (string, error) { - return s.createOneTimeAccessTokenInternal(ctx, userID, expiresAt, s.db) +func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, ttl time.Duration) (string, error) { + return s.createOneTimeAccessTokenInternal(ctx, userID, ttl, s.db) } -func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, expiresAt time.Time, tx *gorm.DB) (string, error) { - oneTimeAccessToken, err := NewOneTimeAccessToken(userID, expiresAt) +func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, ttl time.Duration, tx *gorm.DB) (string, error) { + oneTimeAccessToken, err := NewOneTimeAccessToken(userID, ttl) if err != nil { return "", err } - if err := tx.WithContext(ctx).Create(oneTimeAccessToken).Error; err != nil { + err = tx.WithContext(ctx).Create(oneTimeAccessToken).Error + if err != nil { return "", err } @@ -642,17 +642,14 @@ func (s *UserService) disableUserInternal(ctx context.Context, userID string, tx Error } -func (s *UserService) CreateSignupToken(ctx context.Context, expiresAt time.Time, usageLimit int) (model.SignupToken, error) { - return s.createSignupTokenInternal(ctx, expiresAt, usageLimit, s.db) -} - -func (s *UserService) createSignupTokenInternal(ctx context.Context, expiresAt time.Time, usageLimit int, tx *gorm.DB) (model.SignupToken, error) { - signupToken, err := NewSignupToken(expiresAt, usageLimit) +func (s *UserService) CreateSignupToken(ctx context.Context, ttl time.Duration, usageLimit int) (model.SignupToken, error) { + signupToken, err := NewSignupToken(ttl, usageLimit) if err != nil { return model.SignupToken{}, err } - if err := tx.WithContext(ctx).Create(signupToken).Error; err != nil { + err = s.db.WithContext(ctx).Create(signupToken).Error + if err != nil { return model.SignupToken{}, err } @@ -746,10 +743,10 @@ func (s *UserService) DeleteSignupToken(ctx context.Context, tokenID string) err return s.db.WithContext(ctx).Delete(&model.SignupToken{}, "id = ?", tokenID).Error } -func NewOneTimeAccessToken(userID string, expiresAt time.Time) (*model.OneTimeAccessToken, error) { +func NewOneTimeAccessToken(userID string, ttl time.Duration) (*model.OneTimeAccessToken, error) { // If expires at is less than 15 minutes, use a 6-character token instead of 16 tokenLength := 16 - if time.Until(expiresAt) <= 15*time.Minute { + if ttl <= 15*time.Minute { tokenLength = 6 } @@ -758,25 +755,27 @@ func NewOneTimeAccessToken(userID string, expiresAt time.Time) (*model.OneTimeAc return nil, err } + now := time.Now().Round(time.Second) o := &model.OneTimeAccessToken{ UserID: userID, - ExpiresAt: datatype.DateTime(expiresAt), + ExpiresAt: datatype.DateTime(now.Add(ttl)), Token: randomString, } return o, nil } -func NewSignupToken(expiresAt time.Time, usageLimit int) (*model.SignupToken, error) { +func NewSignupToken(ttl time.Duration, usageLimit int) (*model.SignupToken, error) { // Generate a random token randomString, err := utils.GenerateRandomAlphanumericString(16) if err != nil { return nil, err } + now := time.Now().Round(time.Second) token := &model.SignupToken{ Token: randomString, - ExpiresAt: datatype.DateTime(expiresAt), + ExpiresAt: datatype.DateTime(now.Add(ttl)), UsageLimit: usageLimit, UsageCount: 0, } diff --git a/backend/internal/utils/json_util.go b/backend/internal/utils/json_util.go new file mode 100644 index 00000000..ddbf423c --- /dev/null +++ b/backend/internal/utils/json_util.go @@ -0,0 +1,42 @@ +package utils + +import ( + "encoding/json" + "errors" + "time" +) + +// JSONDuration is a type that allows marshalling/unmarshalling a Duration +type JSONDuration struct { + time.Duration +} + +func (d JSONDuration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +func (d *JSONDuration) UnmarshalJSON(b []byte) error { + var v any + err := json.Unmarshal(b, &v) + if err != nil { + return err + } + switch value := v.(type) { + case float64: + // If the value is a number, interpret it as a number of seconds + d.Duration = time.Duration(value) * time.Second + return nil + case string: + if v == "" { + return nil + } + var err error + d.Duration, err = time.ParseDuration(value) + if err != nil { + return err + } + return nil + default: + return errors.New("invalid duration") + } +} diff --git a/backend/internal/utils/json_util_test.go b/backend/internal/utils/json_util_test.go new file mode 100644 index 00000000..923d1614 --- /dev/null +++ b/backend/internal/utils/json_util_test.go @@ -0,0 +1,64 @@ +package utils + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONDuration_MarshalJSON(t *testing.T) { + tests := []struct { + duration time.Duration + want string + }{ + {time.Minute + 30*time.Second, "1m30s"}, + {0, "0s"}, + } + for _, tc := range tests { + d := JSONDuration{Duration: tc.duration} + b, err := json.Marshal(d) + require.NoError(t, err) + assert.Equal(t, `"`+tc.want+`"`, string(b)) + } +} + +func TestJSONDuration_UnmarshalJSON_String(t *testing.T) { + var d JSONDuration + err := json.Unmarshal([]byte(`"2h15m5s"`), &d) + require.NoError(t, err) + want := 2*time.Hour + 15*time.Minute + 5*time.Second + assert.Equal(t, want, d.Duration) +} + +func TestJSONDuration_UnmarshalJSON_NumberSeconds(t *testing.T) { + tests := []struct { + json string + want time.Duration + }{ + {"0", 0}, + {"1", 1 * time.Second}, + {"2.25", 2 * time.Second}, // Milliseconds are truncated + } + for _, tc := range tests { + var d JSONDuration + err := json.Unmarshal([]byte(tc.json), &d) + require.NoError(t, err, "input: %s", tc.json) + assert.Equal(t, tc.want, d.Duration, "input: %s", tc.json) + } +} + +func TestJSONDuration_UnmarshalJSON_Invalid(t *testing.T) { + cases := [][]byte{ + []byte(`true`), + []byte(`{}`), + []byte(`"not-a-duration"`), + } + for _, b := range cases { + var d JSONDuration + err := json.Unmarshal(b, &d) + require.Error(t, err, "input: %s", string(b)) + } +} diff --git a/frontend/src/lib/components/one-time-link-modal.svelte b/frontend/src/lib/components/one-time-link-modal.svelte index d45805c9..4c201fc1 100644 --- a/frontend/src/lib/components/one-time-link-modal.svelte +++ b/frontend/src/lib/components/one-time-link-modal.svelte @@ -36,8 +36,7 @@ async function createLoginCode() { try { - const expiration = new Date(Date.now() + availableExpirations[selectedExpiration] * 1000); - code = await userService.createOneTimeAccessToken(expiration, userId!); + code = await userService.createOneTimeAccessToken(userId!, availableExpirations[selectedExpiration]); oneTimeLink = `${page.url.origin}/lc/${code}`; } catch (e) { axiosErrorToast(e); @@ -46,8 +45,7 @@ async function sendLoginCodeEmail() { try { - const expiration = new Date(Date.now() + availableExpirations[selectedExpiration] * 1000); - await userService.requestOneTimeAccessEmailAsAdmin(userId!, expiration); + await userService.requestOneTimeAccessEmailAsAdmin(userId!, availableExpirations[selectedExpiration]); toast.success(m.login_code_email_success()); onOpenChange(false); } catch (e) { @@ -81,7 +79,7 @@ value={Object.keys(availableExpirations)[0]} onValueChange={(v) => (selectedExpiration = v! as keyof typeof availableExpirations)} > - + {selectedExpiration} @@ -111,7 +109,7 @@

{code}

-
+

{m.or_visit()}

diff --git a/frontend/src/lib/components/signup/signup-token-modal.svelte b/frontend/src/lib/components/signup/signup-token-modal.svelte index d3799f09..aa3b59fd 100644 --- a/frontend/src/lib/components/signup/signup-token-modal.svelte +++ b/frontend/src/lib/components/signup/signup-token-modal.svelte @@ -37,8 +37,7 @@ async function createSignupToken() { try { - const expiration = new Date(Date.now() + availableExpirations[selectedExpiration] * 1000); - signupToken = await userService.createSignupToken(expiration, usageLimit); + signupToken = await userService.createSignupToken(availableExpirations[selectedExpiration], usageLimit); signupLink = `${page.url.origin}/st/${signupToken}`; if (onTokenCreated) { diff --git a/frontend/src/lib/services/user-service.ts b/frontend/src/lib/services/user-service.ts index 3ac705df..7726ba84 100644 --- a/frontend/src/lib/services/user-service.ts +++ b/frontend/src/lib/services/user-service.ts @@ -75,17 +75,17 @@ export default class UserService extends APIService { cachedProfilePicture.bustCache(userId); } - async createOneTimeAccessToken(expiresAt: Date, userId: string) { + async createOneTimeAccessToken(userId: string = 'me', ttl?: string|number) { const res = await this.api.post(`/users/${userId}/one-time-access-token`, { userId, - expiresAt + ttl, }); return res.data.token; } - async createSignupToken(expiresAt: Date, usageLimit: number) { + async createSignupToken(ttl: string|number, usageLimit: number) { const res = await this.api.post(`/signup-tokens`, { - expiresAt, + ttl, usageLimit }); return res.data.token; @@ -100,8 +100,8 @@ export default class UserService extends APIService { await this.api.post('/one-time-access-email', { email, redirectPath }); } - async requestOneTimeAccessEmailAsAdmin(userId: string, expiresAt: Date) { - await this.api.post(`/users/${userId}/one-time-access-email`, { expiresAt }); + async requestOneTimeAccessEmailAsAdmin(userId: string, ttl: string|number) { + await this.api.post(`/users/${userId}/one-time-access-email`, { ttl }); } async updateUserGroups(id: string, userGroupIds: string[]) { diff --git a/frontend/src/routes/settings/account/login-code-modal.svelte b/frontend/src/routes/settings/account/login-code-modal.svelte index 97239181..a98c004e 100644 --- a/frontend/src/routes/settings/account/login-code-modal.svelte +++ b/frontend/src/routes/settings/account/login-code-modal.svelte @@ -22,9 +22,8 @@ $effect(() => { if (show) { - const expiration = new Date(Date.now() + 15 * 60 * 1000); userService - .createOneTimeAccessToken(expiration, 'me') + .createOneTimeAccessToken('me') .then((c) => { code = c; loginCodeLink = page.url.origin + '/lc/' + code; @@ -54,7 +53,7 @@

{code}

-
+

{m.or_visit()}