feat: device authorization endpoint (#270)

Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Kyle Mendell
2025-04-25 12:14:51 -05:00
committed by GitHub
parent 630327c979
commit 22f7d64bf0
26 changed files with 778 additions and 80 deletions

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"mime/multipart"
"os"
"regexp"
@@ -180,45 +181,99 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
return isAllowedToAuthorize
}
func (s *OidcService) CreateTokens(ctx context.Context, code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
switch grantType {
func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
switch input.GrantType {
case "authorization_code":
return s.createTokenFromAuthorizationCode(ctx, code, clientID, clientSecret, codeVerifier)
return s.createTokenFromAuthorizationCode(ctx, input.Code, input.ClientID, input.ClientSecret, input.CodeVerifier)
case "refresh_token":
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, refreshToken, clientID, clientSecret)
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, input.RefreshToken, input.ClientID, input.ClientSecret)
return "", accessToken, newRefreshToken, exp, err
case "urn:ietf:params:oauth:grant-type:device_code":
return s.createTokenFromDeviceCode(ctx, input.DeviceCode, input.ClientID, input.ClientSecret)
default:
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
}
}
func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, deviceCode, clientID string, clientSecret string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, tx)
if err != nil {
return "", "", "", 0, err
}
// Get the device authorization from database with explicit query conditions
var deviceAuth model.OidcDeviceCode
if err := tx.WithContext(ctx).Preload("User").Where("device_code = ? AND client_id = ?", deviceCode, clientID).First(&deviceAuth).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", "", "", 0, &common.OidcInvalidDeviceCodeError{}
}
return "", "", "", 0, err
}
// Check if device code has expired
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
return "", "", "", 0, &common.OidcDeviceCodeExpiredError{}
}
// Check if device code has been authorized
if !deviceAuth.IsAuthorized || deviceAuth.UserID == nil {
return "", "", "", 0, &common.OidcAuthorizationPendingError{}
}
// Get user claims for the ID token - ensure UserID is not nil
if deviceAuth.UserID == nil {
return "", "", "", 0, &common.OidcAuthorizationPendingError{}
}
userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, clientID, tx)
if err != nil {
return "", "", "", 0, err
}
// Explicitly use the input clientID for the audience claim to ensure consistency
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, "")
if err != nil {
return "", "", "", 0, err
}
refreshToken, err = s.createRefreshToken(ctx, clientID, *deviceAuth.UserID, deviceAuth.Scope, tx)
if err != nil {
return "", "", "", 0, err
}
accessToken, err = s.jwtService.GenerateOauthAccessToken(deviceAuth.User, clientID)
if err != nil {
return "", "", "", 0, err
}
// Delete the used device code
if err := tx.WithContext(ctx).Delete(&deviceAuth).Error; err != nil {
return "", "", "", 0, err
}
if err := tx.Commit().Error; err != nil {
return "", "", "", 0, err
}
return idToken, accessToken, refreshToken, 3600, nil
}
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var client model.OidcClient
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
client, err := s.VerifyClientCredentials(ctx, clientID, clientSecret, tx)
if err != nil {
return "", "", "", 0, err
}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
}
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err = tx.
WithContext(ctx).
@@ -287,28 +342,11 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
tx.Rollback()
}()
// Get the client to check if it's public
var client model.OidcClient
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, tx)
if err != nil {
return "", "", 0, err
}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", 0, &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", 0, &common.OidcClientSecretInvalidError{}
}
}
// Verify refresh token
var storedRefreshToken model.OidcRefreshToken
err = tx.
@@ -363,19 +401,9 @@ func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
// Get the client to check if we are authorized.
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return introspectDto, &common.OidcClientSecretInvalidError{}
}
// Verify the client secret. This endpoint may not be used by public clients.
if client.IsPublic {
return introspectDto, &common.OidcClientSecretInvalidError{}
}
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
return introspectDto, &common.OidcClientSecretInvalidError{}
_, err = s.VerifyClientCredentials(context.Background(), clientID, clientSecret, s.db)
if err != nil {
return introspectDto, err
}
token, err := s.jwtService.VerifyOauthAccessToken(tokenString)
@@ -968,6 +996,162 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
return "", &common.OidcInvalidCallbackURLError{}
}
func (s *OidcService) CreateDeviceAuthorization(input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
client, err := s.VerifyClientCredentials(context.Background(), input.ClientID, input.ClientSecret, s.db)
if err != nil {
return nil, err
}
// Generate codes
deviceCode, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return nil, err
}
userCode, err := utils.GenerateRandomAlphanumericString(8)
if err != nil {
return nil, err
}
// Create device authorization
deviceAuth := &model.OidcDeviceCode{
DeviceCode: deviceCode,
UserCode: userCode,
Scope: input.Scope,
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
IsAuthorized: false,
ClientID: client.ID,
}
if err := s.db.Create(deviceAuth).Error; err != nil {
return nil, err
}
return &dto.OidcDeviceAuthorizationResponseDto{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: common.EnvConfig.AppURL + "/device",
VerificationURIComplete: common.EnvConfig.AppURL + "/device?code=" + userCode,
ExpiresIn: 900, // 15 minutes
Interval: 5,
}, nil
}
func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, userID string, ipAddress string, userAgent string) error {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var deviceAuth model.OidcDeviceCode
if err := tx.WithContext(ctx).Preload("Client.AllowedUserGroups").First(&deviceAuth, "user_code = ?", userCode).Error; err != nil {
log.Printf("Error finding device code with user_code %s: %v", userCode, err)
return err
}
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
return &common.OidcDeviceCodeExpiredError{}
}
// Check if the user group is allowed to authorize the client
var user model.User
if err := tx.WithContext(ctx).Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil {
return err
}
if !s.IsUserGroupAllowedToAuthorize(user, deviceAuth.Client) {
return &common.OidcAccessDeniedError{}
}
if err := tx.WithContext(ctx).Preload("Client").First(&deviceAuth, "user_code = ?", userCode).Error; err != nil {
log.Printf("Error finding device code with user_code %s: %v", userCode, err)
return err
}
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
return &common.OidcDeviceCodeExpiredError{}
}
deviceAuth.UserID = &userID
deviceAuth.IsAuthorized = true
if err := tx.WithContext(ctx).Save(&deviceAuth).Error; err != nil {
log.Printf("Error saving device auth: %v", err)
return err
}
// Verify the update was successful
var verifiedAuth model.OidcDeviceCode
if err := tx.WithContext(ctx).First(&verifiedAuth, "device_code = ?", deviceAuth.DeviceCode).Error; err != nil {
log.Printf("Error verifying update: %v", err)
return err
}
// Create user authorization if needed
hasAuthorizedClient, err := s.hasAuthorizedClientInternal(ctx, deviceAuth.ClientID, userID, deviceAuth.Scope, tx)
if err != nil {
return err
}
if !hasAuthorizedClient {
userAuthorizedClient := model.UserAuthorizedOidcClient{
UserID: userID,
ClientID: deviceAuth.ClientID,
Scope: deviceAuth.Scope,
}
if err := tx.WithContext(ctx).Create(&userAuthorizedClient).Error; err != nil {
if !errors.Is(err, gorm.ErrDuplicatedKey) {
return err
}
// If duplicate, update scope
if err := tx.WithContext(ctx).Model(&model.UserAuthorizedOidcClient{}).
Where("user_id = ? AND client_id = ?", userID, deviceAuth.ClientID).
Update("scope", deviceAuth.Scope).Error; err != nil {
return err
}
}
s.auditLogService.Create(ctx, model.AuditLogEventNewDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx)
} else {
s.auditLogService.Create(ctx, model.AuditLogEventDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx)
}
return tx.Commit().Error
}
func (s *OidcService) GetDeviceCodeInfo(ctx context.Context, userCode string, userID string) (*dto.DeviceCodeInfoDto, error) {
var deviceAuth model.OidcDeviceCode
if err := s.db.Preload("Client").First(&deviceAuth, "user_code = ?", userCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, &common.OidcInvalidDeviceCodeError{}
}
return nil, err
}
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
return nil, &common.OidcDeviceCodeExpiredError{}
}
// Check if the user has already authorized this client with this scope
hasAuthorizedClient := false
if userID != "" {
var err error
hasAuthorizedClient, err = s.HasAuthorizedClient(ctx, deviceAuth.ClientID, userID, deviceAuth.Scope)
if err != nil {
return nil, err
}
}
return &dto.DeviceCodeInfoDto{
Client: dto.OidcClientMetaDataDto{
ID: deviceAuth.Client.ID,
Name: deviceAuth.Client.Name,
HasLogo: deviceAuth.Client.HasLogo,
},
Scope: deviceAuth.Scope,
AuthorizationRequired: !hasAuthorizedClient,
}, nil
}
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, tx *gorm.DB) (string, error) {
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
if err != nil {
@@ -996,3 +1180,25 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
return refreshToken, nil
}
func (s *OidcService) VerifyClientCredentials(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
if clientID == "" {
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
}
var client model.OidcClient
err := tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return model.OidcClient{}, err
}
if !client.IsPublic {
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
}
}
return client, nil
}