feat: add OIDC refresh_token support (#325)

Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Kyle Mendell
2025-03-23 15:14:26 -05:00
committed by GitHub
parent 7888d70656
commit b8dcda8049
14 changed files with 339 additions and 55 deletions

View File

@@ -255,3 +255,33 @@ type APIKeyExpirationDateError struct{}
func (e *APIKeyExpirationDateError) Error() string { func (e *APIKeyExpirationDateError) Error() string {
return "API Key expiration time must be in the future" return "API Key expiration time must be in the future"
} }
type OidcInvalidRefreshTokenError struct{}
func (e *OidcInvalidRefreshTokenError) Error() string {
return "refresh token is invalid or expired"
}
func (e *OidcInvalidRefreshTokenError) HttpStatusCode() int {
return http.StatusBadRequest
}
type OidcMissingRefreshTokenError struct{}
func (e *OidcMissingRefreshTokenError) Error() string {
return "refresh token is required"
}
func (e *OidcMissingRefreshTokenError) HttpStatusCode() int {
return http.StatusBadRequest
}
type OidcMissingAuthorizationCodeError struct{}
func (e *OidcMissingAuthorizationCodeError) Error() string {
return "authorization code is required"
}
func (e *OidcMissingAuthorizationCodeError) HttpStatusCode() int {
return http.StatusBadRequest
}

View File

@@ -111,28 +111,39 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex
// createTokensHandler godoc // createTokensHandler godoc
// @Summary Create OIDC tokens // @Summary Create OIDC tokens
// @Description Exchange authorization code for ID and access tokens // @Description Exchange authorization code or refresh token for access tokens
// @Tags OIDC // @Tags OIDC
// @Accept application/x-www-form-urlencoded
// @Produce json // @Produce json
// @Param client_id formData string false "Client ID (if not using Basic Auth)" // @Param client_id formData string false "Client ID (if not using Basic Auth)"
// @Param client_secret formData string false "Client secret (if not using Basic Auth)" // @Param client_secret formData string false "Client secret (if not using Basic Auth)"
// @Param code formData string true "Authorization code" // @Param code formData string false "Authorization code (required for 'authorization_code' grant)"
// @Param grant_type formData string true "Grant type (must be 'authorization_code')" // @Param grant_type formData string true "Grant type ('authorization_code' or 'refresh_token')"
// @Param code_verifier formData string false "PKCE code verifier" // @Param code_verifier formData string false "PKCE code verifier (for authorization_code with PKCE)"
// @Success 200 {object} object "{ \"id_token\": \"string\", \"access_token\": \"string\", \"token_type\": \"Bearer\" }" // @Param refresh_token formData string false "Refresh token (required for 'refresh_token' grant)"
// @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token"
// @Router /api/oidc/token [post] // @Router /api/oidc/token [post]
func (oc *OidcController) createTokensHandler(c *gin.Context) { func (oc *OidcController) createTokensHandler(c *gin.Context) {
// Disable cors for this endpoint // Disable cors for this endpoint
c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
var input dto.OidcCreateTokensDto var input dto.OidcCreateTokensDto
if err := c.ShouldBind(&input); err != nil { if err := c.ShouldBind(&input); err != nil {
c.Error(err) c.Error(err)
return return
} }
// Validate that code is provided for authorization_code grant type
if input.GrantType == "authorization_code" && input.Code == "" {
c.Error(&common.OidcMissingAuthorizationCodeError{})
return
}
// Validate that refresh_token is provided for refresh_token grant type
if input.GrantType == "refresh_token" && input.RefreshToken == "" {
c.Error(&common.OidcMissingRefreshTokenError{})
return
}
clientID := input.ClientID clientID := input.ClientID
clientSecret := input.ClientSecret clientSecret := input.ClientSecret
@@ -141,13 +152,37 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
clientID, clientSecret, _ = c.Request.BasicAuth() clientID, clientSecret, _ = c.Request.BasicAuth()
} }
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret, input.CodeVerifier) idToken, accessToken, refreshToken, expiresIn, err := oc.oidcService.CreateTokens(
input.Code,
input.GrantType,
clientID,
clientSecret,
input.CodeVerifier,
input.RefreshToken,
)
if err != nil { if err != nil {
c.Error(err) c.Error(err)
return return
} }
c.JSON(http.StatusOK, gin.H{"id_token": idToken, "access_token": accessToken, "token_type": "Bearer"}) response := dto.OidcTokenResponseDto{
AccessToken: accessToken,
TokenType: "Bearer",
ExpiresIn: expiresIn,
}
// Include ID token only for authorization_code grant
if idToken != "" {
response.IdToken = idToken
}
// Include refresh token if generated
if refreshToken != "" {
response.RefreshToken = refreshToken
}
c.JSON(http.StatusOK, response)
} }
// userInfoHandler godoc // userInfoHandler godoc

View File

@@ -54,6 +54,7 @@ func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) {
"userinfo_endpoint": appUrl + "/api/oidc/userinfo", "userinfo_endpoint": appUrl + "/api/oidc/userinfo",
"end_session_endpoint": appUrl + "/api/oidc/end-session", "end_session_endpoint": appUrl + "/api/oidc/end-session",
"jwks_uri": appUrl + "/.well-known/jwks.json", "jwks_uri": appUrl + "/.well-known/jwks.json",
"grant_types_supported": []string{"authorization_code", "refresh_token"},
"scopes_supported": []string{"openid", "profile", "email", "groups"}, "scopes_supported": []string{"openid", "profile", "email", "groups"},
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"}, "claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
"response_types_supported": []string{"code", "id_token"}, "response_types_supported": []string{"code", "id_token"},

View File

@@ -48,10 +48,11 @@ type AuthorizationRequiredDto struct {
type OidcCreateTokensDto struct { type OidcCreateTokensDto struct {
GrantType string `form:"grant_type" binding:"required"` GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"` Code string `form:"code"`
ClientID string `form:"client_id"` ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"` ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"` CodeVerifier string `form:"code_verifier"`
RefreshToken string `form:"refresh_token"`
} }
type OidcUpdateAllowedUserGroupsDto struct { type OidcUpdateAllowedUserGroupsDto struct {
@@ -64,3 +65,11 @@ type OidcLogoutDto struct {
PostLogoutRedirectUri string `form:"post_logout_redirect_uri"` PostLogoutRedirectUri string `form:"post_logout_redirect_uri"`
State string `form:"state"` State string `form:"state"`
} }
type OidcTokenResponseDto struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IdToken string `json:"id_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"`
}

View File

@@ -22,6 +22,7 @@ func RegisterDbCleanupJobs(db *gorm.DB) {
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions) registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens) registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes) registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
registerJob(scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
scheduler.Start() scheduler.Start()
} }
@@ -44,6 +45,11 @@ func (j *Jobs) clearOidcAuthorizationCodes() error {
return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).Error return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
} }
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *Jobs) clearOidcRefreshTokens() error {
return j.db.Delete(&model.OidcRefreshToken{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
}
// ClearAuditLogs deletes audit logs older than 90 days // ClearAuditLogs deletes audit logs older than 90 days
func (j *Jobs) clearAuditLogs() error { func (j *Jobs) clearAuditLogs() error {
return j.db.Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).Error return j.db.Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).Error

View File

@@ -51,6 +51,20 @@ type OidcClient struct {
CreatedBy User CreatedBy User
} }
type OidcRefreshToken struct {
Base
Token string
ExpiresAt datatype.DateTime
Scope string
UserID string
User User
ClientID string
Client OidcClient
}
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) { func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
// Compute HasLogo field // Compute HasLogo field
c.HasLogo = c.ImageType != nil && *c.ImageType != "" c.HasLogo = c.ImageType != nil && *c.ImageType != ""

View File

@@ -145,60 +145,121 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
return isAllowedToAuthorize return isAllowedToAuthorize
} }
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) { func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (string, string, string, int, error) {
if grantType != "authorization_code" { if grantType == "authorization_code" {
return "", "", &common.OidcGrantTypeNotSupportedError{} var client model.OidcClient
} if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", "", 0, err
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", err
}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{}
} }
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) // Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
}
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
if err != nil { if err != nil {
return "", "", &common.OidcClientSecretInvalidError{} return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
} }
}
var authorizationCodeMetaData model.OidcAuthorizationCode // If the client is public or PKCE is enabled, the code verifier must match the code challenge
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error if client.IsPublic || client.PkceEnabled {
if err != nil { if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", &common.OidcInvalidAuthorizationCodeError{} return "", "", "", 0, &common.OidcInvalidCodeVerifierError{}
} }
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
if client.IsPublic || client.PkceEnabled {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", &common.OidcInvalidCodeVerifierError{}
} }
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
}
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
if err != nil {
return "", "", "", 0, err
}
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
if err != nil {
return "", "", "", 0, err
}
// Generate a refresh token
refreshToken, err := s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope)
if err != nil {
return "", "", "", 0, err
}
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
s.db.Delete(&authorizationCodeMetaData)
return idToken, accessToken, refreshToken, 3600, nil
} }
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) { if grantType == "refresh_token" {
return "", "", &common.OidcInvalidAuthorizationCodeError{} if refreshToken == "" {
return "", "", "", 0, &common.OidcMissingRefreshTokenError{}
}
// Get the client to check if it's public
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", "", 0, err
}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
}
}
// Verify refresh token
var storedRefreshToken model.OidcRefreshToken
if err := s.db.Preload("User").Where("token = ? AND expires_at > ?", refreshToken, datatype.DateTime(time.Now())).First(&storedRefreshToken).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", "", "", 0, &common.OidcInvalidRefreshTokenError{}
}
return "", "", "", 0, err
}
// Verify that the refresh token belongs to the provided client
if storedRefreshToken.ClientID != clientID {
return "", "", "", 0, &common.OidcInvalidRefreshTokenError{}
}
// Generate a new access token
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID)
if err != nil {
return "", "", "", 0, err
}
// Generate a new refresh token and invalidate the old one
newRefreshToken, err := s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope)
if err != nil {
return "", "", "", 0, err
}
// Delete the used refresh token
s.db.Delete(&storedRefreshToken)
return "", accessToken, newRefreshToken, 3600, nil
} }
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID) return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
if err != nil {
return "", "", err
}
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
if err != nil {
return "", "", err
}
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
s.db.Delete(&authorizationCodeMetaData)
return idToken, accessToken, nil
} }
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) { func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
@@ -567,3 +628,24 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
return "", &common.OidcInvalidCallbackURLError{} return "", &common.OidcInvalidCallbackURLError{}
} }
func (s *OidcService) createRefreshToken(clientID string, userID string, scope string) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(40)
if err != nil {
return "", err
}
refreshToken := model.OidcRefreshToken{
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
Token: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
}
if err := s.db.Create(&refreshToken).Error; err != nil {
return "", err
}
return randomString, nil
}

View File

@@ -152,6 +152,17 @@ func (s *TestService) SeedDatabase() error {
return err return err
} }
refreshToken := model.OidcRefreshToken{
Token: "ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo",
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
Scope: "openid profile email",
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
}
if err := tx.Create(&refreshToken).Error; err != nil {
return err
}
accessToken := model.OneTimeAccessToken{ accessToken := model.OneTimeAccessToken{
Token: "one-time-token", Token: "one-time-token",
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)), ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),

View File

@@ -0,0 +1,2 @@
DROP INDEX IF EXISTS idx_oidc_refresh_tokens_token;
DROP TABLE IF EXISTS oidc_refresh_tokens;

View File

@@ -0,0 +1,11 @@
CREATE TABLE oidc_refresh_tokens (
id UUID NOT NULL PRIMARY KEY,
created_at TIMESTAMPTZ,
token VARCHAR(255) NOT NULL UNIQUE,
expires_at TIMESTAMPTZ NOT NULL,
scope TEXT NOT NULL,
user_id UUID NOT NULL REFERENCES users ON DELETE CASCADE,
client_id UUID NOT NULL REFERENCES oidc_clients ON DELETE CASCADE
);
CREATE INDEX idx_oidc_refresh_tokens_token ON oidc_refresh_tokens(token);

View File

@@ -0,0 +1,2 @@
DROP INDEX IF EXISTS idx_oidc_refresh_tokens_token;
DROP TABLE IF EXISTS oidc_refresh_tokens;

View File

@@ -0,0 +1,11 @@
CREATE TABLE oidc_refresh_tokens (
id TEXT NOT NULL PRIMARY KEY,
created_at DATETIME,
token TEXT NOT NULL UNIQUE,
expires_at DATETIME NOT NULL,
scope TEXT NOT NULL,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
client_id TEXT NOT NULL REFERENCES oidc_clients(id) ON DELETE CASCADE
);
CREATE INDEX idx_oidc_refresh_tokens_token ON oidc_refresh_tokens(token);

View File

@@ -69,3 +69,16 @@ export const apiKeys = [
name: 'Test API Key' name: 'Test API Key'
} }
]; ];
export const refreshTokens = [
{
token: 'ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo',
clientId: oidcClients.nextcloud.id,
expired: false
},
{
token: 'X4vqwtRyCUaq51UafHea4Fsg8Km6CAns6vp3tuX4',
clientId: oidcClients.nextcloud.id,
expired: true
}
];

View File

@@ -1,5 +1,5 @@
import test, { expect } from '@playwright/test'; import test, { expect } from '@playwright/test';
import { oidcClients } from './data'; import { oidcClients, refreshTokens } from './data';
import { cleanupBackend } from './utils/cleanup.util'; import { cleanupBackend } from './utils/cleanup.util';
import passkeyUtil from './utils/passkey.util'; import passkeyUtil from './utils/passkey.util';
@@ -134,3 +134,60 @@ test('End session with id token hint redirects to callback URL', async ({ page }
expect(redirectedCorrectly).toBeTruthy(); expect(redirectedCorrectly).toBeTruthy();
}); });
test('Successfully refresh tokens with valid refresh token', async ({ request }) => {
const { token, clientId } = refreshTokens.filter((token) => !token.expired)[0];
const clientSecret = 'w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY';
const refreshResponse = await request.post('/api/oidc/token', {
headers: {
'Content-Type': 'application/x-www-form-urlencoded'
},
form: {
grant_type: 'refresh_token',
client_id: clientId,
refresh_token: token,
client_secret: clientSecret
}
});
// Verify we got new tokens
const tokenData = await refreshResponse.json();
expect(tokenData.access_token).toBeDefined();
expect(tokenData.refresh_token).toBeDefined();
expect(tokenData.token_type).toBe('Bearer');
expect(tokenData.expires_in).toBe(3600);
// The new refresh token should be different from the old one
expect(tokenData.refresh_token).not.toBe(token);
});
test('Using refresh token invalidates it for future use', async ({ request }) => {
const { token, clientId } = refreshTokens.filter((token) => !token.expired)[0];
const clientSecret = 'w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY';
await request.post('/api/oidc/token', {
headers: {
'Content-Type': 'application/x-www-form-urlencoded'
},
form: {
grant_type: 'refresh_token',
client_id: clientId,
refresh_token: token,
client_secret: clientSecret
}
});
const refreshResponse = await request.post('/api/oidc/token', {
headers: {
'Content-Type': 'application/x-www-form-urlencoded'
},
form: {
grant_type: 'refresh_token',
client_id: clientId,
refresh_token: token,
client_secret: clientSecret
}
});
expect(refreshResponse.status()).toBe(400);
});