mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-28 17:25:04 +03:00
feat: device authorization endpoint (#270)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
@@ -296,3 +296,51 @@ func (e *UserDisabledError) Error() string {
|
||||
func (e *UserDisabledError) HttpStatusCode() int {
|
||||
return http.StatusForbidden
|
||||
}
|
||||
|
||||
type ValidationError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ValidationError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func (e *ValidationError) HttpStatusCode() int {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
|
||||
type OidcDeviceCodeExpiredError struct{}
|
||||
|
||||
func (e *OidcDeviceCodeExpiredError) Error() string {
|
||||
return "device code has expired"
|
||||
}
|
||||
func (e *OidcDeviceCodeExpiredError) HttpStatusCode() int {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
|
||||
type OidcInvalidDeviceCodeError struct{}
|
||||
|
||||
func (e *OidcInvalidDeviceCodeError) Error() string {
|
||||
return "invalid device code"
|
||||
}
|
||||
func (e *OidcInvalidDeviceCodeError) HttpStatusCode() int {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
|
||||
type OidcSlowDownError struct{}
|
||||
|
||||
func (e *OidcSlowDownError) Error() string {
|
||||
return "polling too frequently"
|
||||
}
|
||||
func (e *OidcSlowDownError) HttpStatusCode() int {
|
||||
return http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
type OidcAuthorizationPendingError struct{}
|
||||
|
||||
func (e *OidcAuthorizationPendingError) Error() string {
|
||||
return "authorization is still pending"
|
||||
}
|
||||
func (e *OidcAuthorizationPendingError) HttpStatusCode() int {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/middleware"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
|
||||
)
|
||||
|
||||
// NewOidcController creates a new controller for OIDC related endpoints
|
||||
@@ -45,6 +47,10 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
|
||||
group.GET("/oidc/clients/:id/logo", oc.getClientLogoHandler)
|
||||
group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler)
|
||||
group.POST("/oidc/clients/:id/logo", authMiddleware.Add(), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler)
|
||||
|
||||
group.POST("/oidc/device/authorize", oc.deviceAuthorizationHandler)
|
||||
group.POST("/oidc/device/verify", authMiddleware.WithAdminNotRequired().Add(), oc.verifyDeviceCodeHandler)
|
||||
group.GET("/oidc/device/info", authMiddleware.WithAdminNotRequired().Add(), oc.getDeviceCodeInfoHandler)
|
||||
}
|
||||
|
||||
type OidcController struct {
|
||||
@@ -144,26 +150,28 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
clientID := input.ClientID
|
||||
clientSecret := input.ClientSecret
|
||||
|
||||
// Client id and secret can also be passed over the Authorization header
|
||||
if clientID == "" && clientSecret == "" {
|
||||
clientID, clientSecret, _ = c.Request.BasicAuth()
|
||||
if input.ClientID == "" && input.ClientSecret == "" {
|
||||
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
|
||||
}
|
||||
|
||||
idToken, accessToken, refreshToken, expiresIn, err := oc.oidcService.CreateTokens(
|
||||
c.Request.Context(),
|
||||
input.Code,
|
||||
input.GrantType,
|
||||
clientID,
|
||||
clientSecret,
|
||||
input.CodeVerifier,
|
||||
input.RefreshToken,
|
||||
idToken, refreshToken, accessToken, expiresIn, err := oc.oidcService.CreateTokens(
|
||||
c,
|
||||
input,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
switch {
|
||||
case errors.Is(err, &common.OidcAuthorizationPendingError{}):
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "authorization_pending",
|
||||
})
|
||||
case errors.Is(err, &common.OidcSlowDownError{}):
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "slow_down",
|
||||
})
|
||||
default:
|
||||
_ = c.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -613,3 +621,60 @@ func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, oidcClientDto)
|
||||
}
|
||||
|
||||
func (oc *OidcController) deviceAuthorizationHandler(c *gin.Context) {
|
||||
var input dto.OidcDeviceAuthorizationRequestDto
|
||||
if err := c.ShouldBind(&input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Client id and secret can also be passed over the Authorization header
|
||||
if input.ClientID == "" && input.ClientSecret == "" {
|
||||
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
|
||||
}
|
||||
|
||||
response, err := oc.oidcService.CreateDeviceAuthorization(input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (oc *OidcController) verifyDeviceCodeHandler(c *gin.Context) {
|
||||
userCode := c.Query("code")
|
||||
if userCode == "" {
|
||||
_ = c.Error(&common.ValidationError{Message: "code is required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get IP address and user agent from the request context
|
||||
ipAddress := c.ClientIP()
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
err := oc.oidcService.VerifyDeviceCode(c, userCode, c.GetString("userID"), ipAddress, userAgent)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (oc *OidcController) getDeviceCodeInfoHandler(c *gin.Context) {
|
||||
userCode := c.Query("code")
|
||||
if userCode == "" {
|
||||
_ = c.Error(&common.ValidationError{Message: "code is required"})
|
||||
return
|
||||
}
|
||||
|
||||
deviceCodeInfo, err := oc.oidcService.GetDeviceCodeInfo(c, userCode, c.GetString("userID"))
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, deviceCodeInfo)
|
||||
}
|
||||
|
||||
@@ -75,8 +75,9 @@ func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
|
||||
"userinfo_endpoint": appUrl + "/api/oidc/userinfo",
|
||||
"end_session_endpoint": appUrl + "/api/oidc/end-session",
|
||||
"introspection_endpoint": appUrl + "/api/oidc/introspect",
|
||||
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
|
||||
"jwks_uri": appUrl + "/.well-known/jwks.json",
|
||||
"grant_types_supported": []string{"authorization_code", "refresh_token"},
|
||||
"grant_types_supported": []string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code"},
|
||||
"scopes_supported": []string{"openid", "profile", "email", "groups"},
|
||||
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
|
||||
"response_types_supported": []string{"code", "id_token"},
|
||||
|
||||
@@ -49,6 +49,7 @@ type AuthorizationRequiredDto struct {
|
||||
type OidcCreateTokensDto struct {
|
||||
GrantType string `form:"grant_type" binding:"required"`
|
||||
Code string `form:"code"`
|
||||
DeviceCode string `form:"device_code"`
|
||||
ClientID string `form:"client_id"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
CodeVerifier string `form:"code_verifier"`
|
||||
@@ -90,3 +91,32 @@ type OidcIntrospectionResponseDto struct {
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
Identifier string `json:"jti,omitempty"`
|
||||
}
|
||||
|
||||
type OidcDeviceAuthorizationRequestDto struct {
|
||||
ClientID string `form:"client_id" binding:"required"`
|
||||
Scope string `form:"scope" binding:"required"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
}
|
||||
|
||||
type OidcDeviceAuthorizationResponseDto struct {
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Interval int `json:"interval"`
|
||||
RequiresAuthorization bool `json:"requires_authorization"`
|
||||
}
|
||||
|
||||
type OidcDeviceTokenRequestDto struct {
|
||||
GrantType string `form:"grant_type" binding:"required,eq=urn:ietf:params:oauth:grant-type:device_code"`
|
||||
DeviceCode string `form:"device_code" binding:"required"`
|
||||
ClientID string `form:"client_id"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
}
|
||||
|
||||
type DeviceCodeInfoDto struct {
|
||||
Scope string `json:"scope"`
|
||||
AuthorizationRequired bool `json:"authorizationRequired"`
|
||||
Client OidcClientMetaDataDto `json:"client"`
|
||||
}
|
||||
|
||||
@@ -26,10 +26,12 @@ type AuditLogData map[string]string //nolint:recvcheck
|
||||
type AuditLogEvent string //nolint:recvcheck
|
||||
|
||||
const (
|
||||
AuditLogEventSignIn AuditLogEvent = "SIGN_IN"
|
||||
AuditLogEventOneTimeAccessTokenSignIn AuditLogEvent = "TOKEN_SIGN_IN"
|
||||
AuditLogEventClientAuthorization AuditLogEvent = "CLIENT_AUTHORIZATION"
|
||||
AuditLogEventNewClientAuthorization AuditLogEvent = "NEW_CLIENT_AUTHORIZATION"
|
||||
AuditLogEventSignIn AuditLogEvent = "SIGN_IN"
|
||||
AuditLogEventOneTimeAccessTokenSignIn AuditLogEvent = "TOKEN_SIGN_IN"
|
||||
AuditLogEventClientAuthorization AuditLogEvent = "CLIENT_AUTHORIZATION"
|
||||
AuditLogEventNewClientAuthorization AuditLogEvent = "NEW_CLIENT_AUTHORIZATION"
|
||||
AuditLogEventDeviceCodeAuthorization AuditLogEvent = "DEVICE_CODE_AUTHORIZATION"
|
||||
AuditLogEventNewDeviceCodeAuthorization AuditLogEvent = "NEW_DEVICE_CODE_AUTHORIZATION"
|
||||
)
|
||||
|
||||
// Scan and Value methods for GORM to handle the custom type
|
||||
|
||||
@@ -87,3 +87,17 @@ func (cu *UrlList) Scan(value interface{}) error {
|
||||
func (cu UrlList) Value() (driver.Value, error) {
|
||||
return json.Marshal(cu)
|
||||
}
|
||||
|
||||
type OidcDeviceCode struct {
|
||||
Base
|
||||
DeviceCode string
|
||||
UserCode string
|
||||
Scope string
|
||||
ExpiresAt datatype.DateTime
|
||||
IsAuthorized bool
|
||||
|
||||
UserID *string
|
||||
User User
|
||||
ClientID string
|
||||
Client OidcClient
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -180,45 +181,99 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
|
||||
return isAllowedToAuthorize
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateTokens(ctx context.Context, code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
|
||||
switch grantType {
|
||||
func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
|
||||
switch input.GrantType {
|
||||
case "authorization_code":
|
||||
return s.createTokenFromAuthorizationCode(ctx, code, clientID, clientSecret, codeVerifier)
|
||||
return s.createTokenFromAuthorizationCode(ctx, input.Code, input.ClientID, input.ClientSecret, input.CodeVerifier)
|
||||
case "refresh_token":
|
||||
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, refreshToken, clientID, clientSecret)
|
||||
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, input.RefreshToken, input.ClientID, input.ClientSecret)
|
||||
return "", accessToken, newRefreshToken, exp, err
|
||||
case "urn:ietf:params:oauth:grant-type:device_code":
|
||||
return s.createTokenFromDeviceCode(ctx, input.DeviceCode, input.ClientID, input.ClientSecret)
|
||||
default:
|
||||
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, deviceCode, clientID string, clientSecret string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
// Get the device authorization from database with explicit query conditions
|
||||
var deviceAuth model.OidcDeviceCode
|
||||
if err := tx.WithContext(ctx).Preload("User").Where("device_code = ? AND client_id = ?", deviceCode, clientID).First(&deviceAuth).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", "", "", 0, &common.OidcInvalidDeviceCodeError{}
|
||||
}
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
// Check if device code has expired
|
||||
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
|
||||
return "", "", "", 0, &common.OidcDeviceCodeExpiredError{}
|
||||
}
|
||||
|
||||
// Check if device code has been authorized
|
||||
if !deviceAuth.IsAuthorized || deviceAuth.UserID == nil {
|
||||
return "", "", "", 0, &common.OidcAuthorizationPendingError{}
|
||||
}
|
||||
|
||||
// Get user claims for the ID token - ensure UserID is not nil
|
||||
if deviceAuth.UserID == nil {
|
||||
return "", "", "", 0, &common.OidcAuthorizationPendingError{}
|
||||
}
|
||||
|
||||
userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, clientID, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
// Explicitly use the input clientID for the audience claim to ensure consistency
|
||||
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, "")
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
refreshToken, err = s.createRefreshToken(ctx, clientID, *deviceAuth.UserID, deviceAuth.Scope, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
accessToken, err = s.jwtService.GenerateOauthAccessToken(deviceAuth.User, clientID)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
// Delete the used device code
|
||||
if err := tx.WithContext(ctx).Delete(&deviceAuth).Error; err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
return idToken, accessToken, refreshToken, 3600, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
client, err := s.VerifyClientCredentials(ctx, clientID, clientSecret, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
// Verify the client secret if the client is not public
|
||||
if !client.IsPublic {
|
||||
if clientID == "" || clientSecret == "" {
|
||||
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
if err != nil {
|
||||
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
}
|
||||
|
||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
@@ -287,28 +342,11 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
// Get the client to check if it's public
|
||||
var client model.OidcClient
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, tx)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
||||
// Verify the client secret if the client is not public
|
||||
if !client.IsPublic {
|
||||
if clientID == "" || clientSecret == "" {
|
||||
return "", "", 0, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
if err != nil {
|
||||
return "", "", 0, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify refresh token
|
||||
var storedRefreshToken model.OidcRefreshToken
|
||||
err = tx.
|
||||
@@ -363,19 +401,9 @@ func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
// Get the client to check if we are authorized.
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return introspectDto, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
|
||||
// Verify the client secret. This endpoint may not be used by public clients.
|
||||
if client.IsPublic {
|
||||
return introspectDto, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
|
||||
return introspectDto, &common.OidcClientSecretInvalidError{}
|
||||
_, err = s.VerifyClientCredentials(context.Background(), clientID, clientSecret, s.db)
|
||||
if err != nil {
|
||||
return introspectDto, err
|
||||
}
|
||||
|
||||
token, err := s.jwtService.VerifyOauthAccessToken(tokenString)
|
||||
@@ -968,6 +996,162 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
|
||||
return "", &common.OidcInvalidCallbackURLError{}
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateDeviceAuthorization(input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
|
||||
client, err := s.VerifyClientCredentials(context.Background(), input.ClientID, input.ClientSecret, s.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate codes
|
||||
deviceCode, err := utils.GenerateRandomAlphanumericString(32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userCode, err := utils.GenerateRandomAlphanumericString(8)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create device authorization
|
||||
deviceAuth := &model.OidcDeviceCode{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
Scope: input.Scope,
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
|
||||
IsAuthorized: false,
|
||||
ClientID: client.ID,
|
||||
}
|
||||
|
||||
if err := s.db.Create(deviceAuth).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dto.OidcDeviceAuthorizationResponseDto{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
VerificationURI: common.EnvConfig.AppURL + "/device",
|
||||
VerificationURIComplete: common.EnvConfig.AppURL + "/device?code=" + userCode,
|
||||
ExpiresIn: 900, // 15 minutes
|
||||
Interval: 5,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, userID string, ipAddress string, userAgent string) error {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var deviceAuth model.OidcDeviceCode
|
||||
if err := tx.WithContext(ctx).Preload("Client.AllowedUserGroups").First(&deviceAuth, "user_code = ?", userCode).Error; err != nil {
|
||||
log.Printf("Error finding device code with user_code %s: %v", userCode, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
|
||||
return &common.OidcDeviceCodeExpiredError{}
|
||||
}
|
||||
|
||||
// Check if the user group is allowed to authorize the client
|
||||
var user model.User
|
||||
if err := tx.WithContext(ctx).Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !s.IsUserGroupAllowedToAuthorize(user, deviceAuth.Client) {
|
||||
return &common.OidcAccessDeniedError{}
|
||||
}
|
||||
|
||||
if err := tx.WithContext(ctx).Preload("Client").First(&deviceAuth, "user_code = ?", userCode).Error; err != nil {
|
||||
log.Printf("Error finding device code with user_code %s: %v", userCode, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
|
||||
return &common.OidcDeviceCodeExpiredError{}
|
||||
}
|
||||
|
||||
deviceAuth.UserID = &userID
|
||||
deviceAuth.IsAuthorized = true
|
||||
|
||||
if err := tx.WithContext(ctx).Save(&deviceAuth).Error; err != nil {
|
||||
log.Printf("Error saving device auth: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify the update was successful
|
||||
var verifiedAuth model.OidcDeviceCode
|
||||
if err := tx.WithContext(ctx).First(&verifiedAuth, "device_code = ?", deviceAuth.DeviceCode).Error; err != nil {
|
||||
log.Printf("Error verifying update: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create user authorization if needed
|
||||
hasAuthorizedClient, err := s.hasAuthorizedClientInternal(ctx, deviceAuth.ClientID, userID, deviceAuth.Scope, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !hasAuthorizedClient {
|
||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
UserID: userID,
|
||||
ClientID: deviceAuth.ClientID,
|
||||
Scope: deviceAuth.Scope,
|
||||
}
|
||||
|
||||
if err := tx.WithContext(ctx).Create(&userAuthorizedClient).Error; err != nil {
|
||||
if !errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return err
|
||||
}
|
||||
// If duplicate, update scope
|
||||
if err := tx.WithContext(ctx).Model(&model.UserAuthorizedOidcClient{}).
|
||||
Where("user_id = ? AND client_id = ?", userID, deviceAuth.ClientID).
|
||||
Update("scope", deviceAuth.Scope).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.auditLogService.Create(ctx, model.AuditLogEventNewDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx)
|
||||
} else {
|
||||
s.auditLogService.Create(ctx, model.AuditLogEventDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx)
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
func (s *OidcService) GetDeviceCodeInfo(ctx context.Context, userCode string, userID string) (*dto.DeviceCodeInfoDto, error) {
|
||||
var deviceAuth model.OidcDeviceCode
|
||||
if err := s.db.Preload("Client").First(&deviceAuth, "user_code = ?", userCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, &common.OidcInvalidDeviceCodeError{}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
|
||||
return nil, &common.OidcDeviceCodeExpiredError{}
|
||||
}
|
||||
|
||||
// Check if the user has already authorized this client with this scope
|
||||
hasAuthorizedClient := false
|
||||
if userID != "" {
|
||||
var err error
|
||||
hasAuthorizedClient, err = s.HasAuthorizedClient(ctx, deviceAuth.ClientID, userID, deviceAuth.Scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &dto.DeviceCodeInfoDto{
|
||||
Client: dto.OidcClientMetaDataDto{
|
||||
ID: deviceAuth.Client.ID,
|
||||
Name: deviceAuth.Client.Name,
|
||||
HasLogo: deviceAuth.Client.HasLogo,
|
||||
},
|
||||
Scope: deviceAuth.Scope,
|
||||
AuthorizationRequired: !hasAuthorizedClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, tx *gorm.DB) (string, error) {
|
||||
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
|
||||
if err != nil {
|
||||
@@ -996,3 +1180,25 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
|
||||
|
||||
return refreshToken, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) VerifyClientCredentials(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
|
||||
if clientID == "" {
|
||||
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
if !client.IsPublic {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
|
||||
return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE oidc_device_codes;
|
||||
@@ -0,0 +1,12 @@
|
||||
CREATE TABLE oidc_device_codes
|
||||
(
|
||||
id UUID NOT NULL PRIMARY KEY,
|
||||
created_at TIMESTAMPTZ,
|
||||
device_code TEXT NOT NULL UNIQUE,
|
||||
user_code TEXT NOT NULL UNIQUE,
|
||||
scope TEXT NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
is_authorized BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
user_id UUID REFERENCES users ON DELETE CASCADE,
|
||||
client_id UUID NOT NULL REFERENCES oidc_clients ON DELETE CASCADE
|
||||
);
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE oidc_device_codes;
|
||||
@@ -0,0 +1,12 @@
|
||||
CREATE TABLE oidc_device_codes
|
||||
(
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
created_at DATETIME,
|
||||
device_code TEXT NOT NULL UNIQUE,
|
||||
user_code TEXT NOT NULL UNIQUE,
|
||||
scope TEXT NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
is_authorized BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
user_id TEXT REFERENCES users ON DELETE CASCADE,
|
||||
client_id TEXT NOT NULL REFERENCES oidc_clients ON DELETE CASCADE
|
||||
);
|
||||
Reference in New Issue
Block a user