fix: for one-time access tokens and signup tokens, pass TTLs instead of absolute expiration date (#855)

This commit is contained in:
Alessandro (Ale) Segala
2025-08-21 23:02:56 -07:00
committed by GitHub
parent 49f0fa423c
commit 7ab0fd3028
12 changed files with 205 additions and 70 deletions

View File

@@ -51,7 +51,7 @@ var oneTimeAccessTokenCmd = &cobra.Command{
} }
// Create a new access token that expires in 1 hour // Create a new access token that expires in 1 hour
oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Now().Add(time.Hour)) oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Hour)
if txErr != nil { if txErr != nil {
return fmt.Errorf("failed to generate access token: %w", txErr) return fmt.Errorf("failed to generate access token: %w", txErr)
} }

View File

@@ -14,6 +14,11 @@ import (
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
const (
defaultOneTimeAccessTokenDuration = 15 * time.Minute
defaultSignupTokenDuration = time.Hour
)
// NewUserController creates a new controller for user management endpoints // NewUserController creates a new controller for user management endpoints
// @Summary User management controller // @Summary User management controller
// @Description Initializes all user-related API endpoints // @Description Initializes all user-related API endpoints
@@ -331,10 +336,17 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context, own bo
return return
} }
var ttl time.Duration
if own { if own {
input.UserID = c.GetString("userID") input.UserID = c.GetString("userID")
ttl = defaultOneTimeAccessTokenDuration
} else {
ttl = input.TTL.Duration
if ttl <= 0 {
ttl = defaultOneTimeAccessTokenDuration
} }
token, err := uc.userService.CreateOneTimeAccessToken(c.Request.Context(), input.UserID, input.ExpiresAt) }
token, err := uc.userService.CreateOneTimeAccessToken(c.Request.Context(), input.UserID, ttl)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -411,7 +423,11 @@ func (uc *UserController) RequestOneTimeAccessEmailAsAdminHandler(c *gin.Context
userID := c.Param("id") userID := c.Param("id")
err := uc.userService.RequestOneTimeAccessEmailAsAdmin(c.Request.Context(), userID, input.ExpiresAt) ttl := input.TTL.Duration
if ttl <= 0 {
ttl = defaultOneTimeAccessTokenDuration
}
err := uc.userService.RequestOneTimeAccessEmailAsAdmin(c.Request.Context(), userID, ttl)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -526,14 +542,20 @@ func (uc *UserController) createSignupTokenHandler(c *gin.Context) {
return return
} }
signupToken, err := uc.userService.CreateSignupToken(c.Request.Context(), input.ExpiresAt, input.UsageLimit) ttl := input.TTL.Duration
if ttl <= 0 {
ttl = defaultSignupTokenDuration
}
signupToken, err := uc.userService.CreateSignupToken(c.Request.Context(), ttl, input.UsageLimit)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
} }
var tokenDto dto.SignupTokenDto var tokenDto dto.SignupTokenDto
if err := dto.MapStruct(signupToken, &tokenDto); err != nil { err = dto.MapStruct(signupToken, &tokenDto)
if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
} }

View File

@@ -1,13 +1,12 @@
package dto package dto
import ( import (
"time"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
) )
type SignupTokenCreateDto struct { type SignupTokenCreateDto struct {
ExpiresAt time.Time `json:"expiresAt" binding:"required"` TTL utils.JSONDuration `json:"ttl" binding:"required,ttl"`
UsageLimit int `json:"usageLimit" binding:"required,min=1,max=100"` UsageLimit int `json:"usageLimit" binding:"required,min=1,max=100"`
} }

View File

@@ -1,7 +1,7 @@
package dto package dto
import ( import (
"time" "github.com/pocket-id/pocket-id/backend/internal/utils"
) )
type UserDto struct { type UserDto struct {
@@ -31,7 +31,7 @@ type UserCreateDto struct {
type OneTimeAccessTokenCreateDto struct { type OneTimeAccessTokenCreateDto struct {
UserID string `json:"userId"` UserID string `json:"userId"`
ExpiresAt time.Time `json:"expiresAt" binding:"required"` TTL utils.JSONDuration `json:"ttl" binding:"ttl"`
} }
type OneTimeAccessEmailAsUnauthenticatedUserDto struct { type OneTimeAccessEmailAsUnauthenticatedUserDto struct {
@@ -40,7 +40,7 @@ type OneTimeAccessEmailAsUnauthenticatedUserDto struct {
} }
type OneTimeAccessEmailAsAdminDto struct { type OneTimeAccessEmailAsAdminDto struct {
ExpiresAt time.Time `json:"expiresAt" binding:"required"` TTL utils.JSONDuration `json:"ttl" binding:"ttl"`
} }
type UserUpdateUserGroupDto struct { type UserUpdateUserGroupDto struct {

View File

@@ -1,29 +1,42 @@
package dto package dto
import ( import (
"log/slog"
"os"
"regexp" "regexp"
"time"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
) )
// [a-zA-Z0-9] : The username must start with an alphanumeric character
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
return validateUsernameRegex.MatchString(fl.Field().String())
}
func init() { func init() {
v, _ := binding.Validator.Engine().(*validator.Validate) v := binding.Validator.Engine().(*validator.Validate)
err := v.RegisterValidation("username", validateUsername)
// [a-zA-Z0-9] : The username must start with an alphanumeric character
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
// Maximum allowed value for TTLs
const maxTTL = 31 * 24 * time.Hour
// Errors here are development-time ones
err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool {
return validateUsernameRegex.MatchString(fl.Field().String())
})
if err != nil { if err != nil {
slog.Error("Failed to register custom validation", slog.Any("error", err)) panic("Failed to register custom validation for username: " + err.Error())
os.Exit(1) }
return err = v.RegisterValidation("ttl", func(fl validator.FieldLevel) bool {
ttl, ok := fl.Field().Interface().(utils.JSONDuration)
if !ok {
return false
}
// Allow zero, which means the field wasn't set
return ttl.Duration == 0 || ttl.Duration > time.Second && ttl.Duration <= maxTTL
})
if err != nil {
panic("Failed to register custom validation for ttl: " + err.Error())
} }
} }

View File

@@ -348,13 +348,13 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
return user, nil return user, nil
} }
func (s *UserService) RequestOneTimeAccessEmailAsAdmin(ctx context.Context, userID string, expiration time.Time) error { func (s *UserService) RequestOneTimeAccessEmailAsAdmin(ctx context.Context, userID string, ttl time.Duration) error {
isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsAdminEnabled.IsTrue() isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsAdminEnabled.IsTrue()
if isDisabled { if isDisabled {
return &common.OneTimeAccessDisabledError{} return &common.OneTimeAccessDisabledError{}
} }
return s.requestOneTimeAccessEmailInternal(ctx, userID, "", expiration) return s.requestOneTimeAccessEmailInternal(ctx, userID, "", ttl)
} }
func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context.Context, userID, redirectPath string) error { func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context.Context, userID, redirectPath string) error {
@@ -374,11 +374,10 @@ func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context
} }
} }
expiration := time.Now().Add(15 * time.Minute) return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute)
return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, expiration)
} }
func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, userID, redirectPath string, expiration time.Time) error { func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, userID, redirectPath string, ttl time.Duration) error {
tx := s.db.Begin() tx := s.db.Begin()
defer func() { defer func() {
tx.Rollback() tx.Rollback()
@@ -389,7 +388,7 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
return err return err
} }
oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, expiration, tx) oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, ttl, tx)
if err != nil { if err != nil {
return err return err
} }
@@ -421,7 +420,7 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
Code: oneTimeAccessToken, Code: oneTimeAccessToken,
LoginLink: link, LoginLink: link,
LoginLinkWithCode: linkWithCode, LoginLinkWithCode: linkWithCode,
ExpirationString: utils.DurationToString(time.Until(expiration).Round(time.Second)), ExpirationString: utils.DurationToString(ttl),
}) })
if errInternal != nil { if errInternal != nil {
slog.ErrorContext(innerCtx, "Failed to send one-time access token email", slog.Any("error", errInternal), slog.String("address", user.Email)) slog.ErrorContext(innerCtx, "Failed to send one-time access token email", slog.Any("error", errInternal), slog.String("address", user.Email))
@@ -432,17 +431,18 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
return nil return nil
} }
func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, expiresAt time.Time) (string, error) { func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, ttl time.Duration) (string, error) {
return s.createOneTimeAccessTokenInternal(ctx, userID, expiresAt, s.db) return s.createOneTimeAccessTokenInternal(ctx, userID, ttl, s.db)
} }
func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, expiresAt time.Time, tx *gorm.DB) (string, error) { func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, ttl time.Duration, tx *gorm.DB) (string, error) {
oneTimeAccessToken, err := NewOneTimeAccessToken(userID, expiresAt) oneTimeAccessToken, err := NewOneTimeAccessToken(userID, ttl)
if err != nil { if err != nil {
return "", err return "", err
} }
if err := tx.WithContext(ctx).Create(oneTimeAccessToken).Error; err != nil { err = tx.WithContext(ctx).Create(oneTimeAccessToken).Error
if err != nil {
return "", err return "", err
} }
@@ -642,17 +642,14 @@ func (s *UserService) disableUserInternal(ctx context.Context, userID string, tx
Error Error
} }
func (s *UserService) CreateSignupToken(ctx context.Context, expiresAt time.Time, usageLimit int) (model.SignupToken, error) { func (s *UserService) CreateSignupToken(ctx context.Context, ttl time.Duration, usageLimit int) (model.SignupToken, error) {
return s.createSignupTokenInternal(ctx, expiresAt, usageLimit, s.db) signupToken, err := NewSignupToken(ttl, usageLimit)
}
func (s *UserService) createSignupTokenInternal(ctx context.Context, expiresAt time.Time, usageLimit int, tx *gorm.DB) (model.SignupToken, error) {
signupToken, err := NewSignupToken(expiresAt, usageLimit)
if err != nil { if err != nil {
return model.SignupToken{}, err return model.SignupToken{}, err
} }
if err := tx.WithContext(ctx).Create(signupToken).Error; err != nil { err = s.db.WithContext(ctx).Create(signupToken).Error
if err != nil {
return model.SignupToken{}, err return model.SignupToken{}, err
} }
@@ -746,10 +743,10 @@ func (s *UserService) DeleteSignupToken(ctx context.Context, tokenID string) err
return s.db.WithContext(ctx).Delete(&model.SignupToken{}, "id = ?", tokenID).Error return s.db.WithContext(ctx).Delete(&model.SignupToken{}, "id = ?", tokenID).Error
} }
func NewOneTimeAccessToken(userID string, expiresAt time.Time) (*model.OneTimeAccessToken, error) { func NewOneTimeAccessToken(userID string, ttl time.Duration) (*model.OneTimeAccessToken, error) {
// If expires at is less than 15 minutes, use a 6-character token instead of 16 // If expires at is less than 15 minutes, use a 6-character token instead of 16
tokenLength := 16 tokenLength := 16
if time.Until(expiresAt) <= 15*time.Minute { if ttl <= 15*time.Minute {
tokenLength = 6 tokenLength = 6
} }
@@ -758,25 +755,27 @@ func NewOneTimeAccessToken(userID string, expiresAt time.Time) (*model.OneTimeAc
return nil, err return nil, err
} }
now := time.Now().Round(time.Second)
o := &model.OneTimeAccessToken{ o := &model.OneTimeAccessToken{
UserID: userID, UserID: userID,
ExpiresAt: datatype.DateTime(expiresAt), ExpiresAt: datatype.DateTime(now.Add(ttl)),
Token: randomString, Token: randomString,
} }
return o, nil return o, nil
} }
func NewSignupToken(expiresAt time.Time, usageLimit int) (*model.SignupToken, error) { func NewSignupToken(ttl time.Duration, usageLimit int) (*model.SignupToken, error) {
// Generate a random token // Generate a random token
randomString, err := utils.GenerateRandomAlphanumericString(16) randomString, err := utils.GenerateRandomAlphanumericString(16)
if err != nil { if err != nil {
return nil, err return nil, err
} }
now := time.Now().Round(time.Second)
token := &model.SignupToken{ token := &model.SignupToken{
Token: randomString, Token: randomString,
ExpiresAt: datatype.DateTime(expiresAt), ExpiresAt: datatype.DateTime(now.Add(ttl)),
UsageLimit: usageLimit, UsageLimit: usageLimit,
UsageCount: 0, UsageCount: 0,
} }

View File

@@ -0,0 +1,42 @@
package utils
import (
"encoding/json"
"errors"
"time"
)
// JSONDuration is a type that allows marshalling/unmarshalling a Duration
type JSONDuration struct {
time.Duration
}
func (d JSONDuration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
func (d *JSONDuration) UnmarshalJSON(b []byte) error {
var v any
err := json.Unmarshal(b, &v)
if err != nil {
return err
}
switch value := v.(type) {
case float64:
// If the value is a number, interpret it as a number of seconds
d.Duration = time.Duration(value) * time.Second
return nil
case string:
if v == "" {
return nil
}
var err error
d.Duration, err = time.ParseDuration(value)
if err != nil {
return err
}
return nil
default:
return errors.New("invalid duration")
}
}

View File

@@ -0,0 +1,64 @@
package utils
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestJSONDuration_MarshalJSON(t *testing.T) {
tests := []struct {
duration time.Duration
want string
}{
{time.Minute + 30*time.Second, "1m30s"},
{0, "0s"},
}
for _, tc := range tests {
d := JSONDuration{Duration: tc.duration}
b, err := json.Marshal(d)
require.NoError(t, err)
assert.Equal(t, `"`+tc.want+`"`, string(b))
}
}
func TestJSONDuration_UnmarshalJSON_String(t *testing.T) {
var d JSONDuration
err := json.Unmarshal([]byte(`"2h15m5s"`), &d)
require.NoError(t, err)
want := 2*time.Hour + 15*time.Minute + 5*time.Second
assert.Equal(t, want, d.Duration)
}
func TestJSONDuration_UnmarshalJSON_NumberSeconds(t *testing.T) {
tests := []struct {
json string
want time.Duration
}{
{"0", 0},
{"1", 1 * time.Second},
{"2.25", 2 * time.Second}, // Milliseconds are truncated
}
for _, tc := range tests {
var d JSONDuration
err := json.Unmarshal([]byte(tc.json), &d)
require.NoError(t, err, "input: %s", tc.json)
assert.Equal(t, tc.want, d.Duration, "input: %s", tc.json)
}
}
func TestJSONDuration_UnmarshalJSON_Invalid(t *testing.T) {
cases := [][]byte{
[]byte(`true`),
[]byte(`{}`),
[]byte(`"not-a-duration"`),
}
for _, b := range cases {
var d JSONDuration
err := json.Unmarshal(b, &d)
require.Error(t, err, "input: %s", string(b))
}
}

View File

@@ -36,8 +36,7 @@
async function createLoginCode() { async function createLoginCode() {
try { try {
const expiration = new Date(Date.now() + availableExpirations[selectedExpiration] * 1000); code = await userService.createOneTimeAccessToken(userId!, availableExpirations[selectedExpiration]);
code = await userService.createOneTimeAccessToken(expiration, userId!);
oneTimeLink = `${page.url.origin}/lc/${code}`; oneTimeLink = `${page.url.origin}/lc/${code}`;
} catch (e) { } catch (e) {
axiosErrorToast(e); axiosErrorToast(e);
@@ -46,8 +45,7 @@
async function sendLoginCodeEmail() { async function sendLoginCodeEmail() {
try { try {
const expiration = new Date(Date.now() + availableExpirations[selectedExpiration] * 1000); await userService.requestOneTimeAccessEmailAsAdmin(userId!, availableExpirations[selectedExpiration]);
await userService.requestOneTimeAccessEmailAsAdmin(userId!, expiration);
toast.success(m.login_code_email_success()); toast.success(m.login_code_email_success());
onOpenChange(false); onOpenChange(false);
} catch (e) { } catch (e) {
@@ -81,7 +79,7 @@
value={Object.keys(availableExpirations)[0]} value={Object.keys(availableExpirations)[0]}
onValueChange={(v) => (selectedExpiration = v! as keyof typeof availableExpirations)} onValueChange={(v) => (selectedExpiration = v! as keyof typeof availableExpirations)}
> >
<Select.Trigger id="expiration" class="h-9 w-full"> <Select.Trigger id="expiration" class="w-full h-9">
{selectedExpiration} {selectedExpiration}
</Select.Trigger> </Select.Trigger>
<Select.Content> <Select.Content>
@@ -111,7 +109,7 @@
<p class="text-3xl font-code">{code}</p> <p class="text-3xl font-code">{code}</p>
</CopyToClipboard> </CopyToClipboard>
<div class="text-muted-foreground my-2 flex items-center justify-center gap-3"> <div class="flex items-center justify-center gap-3 my-2 text-muted-foreground">
<Separator /> <Separator />
<p class="text-xs text-nowrap">{m.or_visit()}</p> <p class="text-xs text-nowrap">{m.or_visit()}</p>
<Separator /> <Separator />

View File

@@ -37,8 +37,7 @@
async function createSignupToken() { async function createSignupToken() {
try { try {
const expiration = new Date(Date.now() + availableExpirations[selectedExpiration] * 1000); signupToken = await userService.createSignupToken(availableExpirations[selectedExpiration], usageLimit);
signupToken = await userService.createSignupToken(expiration, usageLimit);
signupLink = `${page.url.origin}/st/${signupToken}`; signupLink = `${page.url.origin}/st/${signupToken}`;
if (onTokenCreated) { if (onTokenCreated) {

View File

@@ -75,17 +75,17 @@ export default class UserService extends APIService {
cachedProfilePicture.bustCache(userId); cachedProfilePicture.bustCache(userId);
} }
async createOneTimeAccessToken(expiresAt: Date, userId: string) { async createOneTimeAccessToken(userId: string = 'me', ttl?: string|number) {
const res = await this.api.post(`/users/${userId}/one-time-access-token`, { const res = await this.api.post(`/users/${userId}/one-time-access-token`, {
userId, userId,
expiresAt ttl,
}); });
return res.data.token; return res.data.token;
} }
async createSignupToken(expiresAt: Date, usageLimit: number) { async createSignupToken(ttl: string|number, usageLimit: number) {
const res = await this.api.post(`/signup-tokens`, { const res = await this.api.post(`/signup-tokens`, {
expiresAt, ttl,
usageLimit usageLimit
}); });
return res.data.token; return res.data.token;
@@ -100,8 +100,8 @@ export default class UserService extends APIService {
await this.api.post('/one-time-access-email', { email, redirectPath }); await this.api.post('/one-time-access-email', { email, redirectPath });
} }
async requestOneTimeAccessEmailAsAdmin(userId: string, expiresAt: Date) { async requestOneTimeAccessEmailAsAdmin(userId: string, ttl: string|number) {
await this.api.post(`/users/${userId}/one-time-access-email`, { expiresAt }); await this.api.post(`/users/${userId}/one-time-access-email`, { ttl });
} }
async updateUserGroups(id: string, userGroupIds: string[]) { async updateUserGroups(id: string, userGroupIds: string[]) {

View File

@@ -22,9 +22,8 @@
$effect(() => { $effect(() => {
if (show) { if (show) {
const expiration = new Date(Date.now() + 15 * 60 * 1000);
userService userService
.createOneTimeAccessToken(expiration, 'me') .createOneTimeAccessToken('me')
.then((c) => { .then((c) => {
code = c; code = c;
loginCodeLink = page.url.origin + '/lc/' + code; loginCodeLink = page.url.origin + '/lc/' + code;
@@ -54,7 +53,7 @@
<CopyToClipboard value={code!}> <CopyToClipboard value={code!}>
<p class="text-3xl font-code">{code}</p> <p class="text-3xl font-code">{code}</p>
</CopyToClipboard> </CopyToClipboard>
<div class="text-muted-foreground my-2 flex items-center justify-center gap-3"> <div class="flex items-center justify-center gap-3 my-2 text-muted-foreground">
<Separator /> <Separator />
<p class="text-xs text-nowrap">{m.or_visit()}</p> <p class="text-xs text-nowrap">{m.or_visit()}</p>
<Separator /> <Separator />