feat: device authorization endpoint (#270)

Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Kyle Mendell
2025-04-25 12:14:51 -05:00
committed by GitHub
parent 630327c979
commit 22f7d64bf0
26 changed files with 778 additions and 80 deletions

View File

@@ -1,18 +1,20 @@
package controller
import (
"errors"
"log"
"net/http"
"net/url"
"strings"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
)
// NewOidcController creates a new controller for OIDC related endpoints
@@ -45,6 +47,10 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
group.GET("/oidc/clients/:id/logo", oc.getClientLogoHandler)
group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler)
group.POST("/oidc/clients/:id/logo", authMiddleware.Add(), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler)
group.POST("/oidc/device/authorize", oc.deviceAuthorizationHandler)
group.POST("/oidc/device/verify", authMiddleware.WithAdminNotRequired().Add(), oc.verifyDeviceCodeHandler)
group.GET("/oidc/device/info", authMiddleware.WithAdminNotRequired().Add(), oc.getDeviceCodeInfoHandler)
}
type OidcController struct {
@@ -144,26 +150,28 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
return
}
clientID := input.ClientID
clientSecret := input.ClientSecret
// Client id and secret can also be passed over the Authorization header
if clientID == "" && clientSecret == "" {
clientID, clientSecret, _ = c.Request.BasicAuth()
if input.ClientID == "" && input.ClientSecret == "" {
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
}
idToken, accessToken, refreshToken, expiresIn, err := oc.oidcService.CreateTokens(
c.Request.Context(),
input.Code,
input.GrantType,
clientID,
clientSecret,
input.CodeVerifier,
input.RefreshToken,
idToken, refreshToken, accessToken, expiresIn, err := oc.oidcService.CreateTokens(
c,
input,
)
if err != nil {
_ = c.Error(err)
switch {
case errors.Is(err, &common.OidcAuthorizationPendingError{}):
c.JSON(http.StatusBadRequest, gin.H{
"error": "authorization_pending",
})
case errors.Is(err, &common.OidcSlowDownError{}):
c.JSON(http.StatusBadRequest, gin.H{
"error": "slow_down",
})
default:
_ = c.Error(err)
}
return
}
@@ -613,3 +621,60 @@ func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
c.JSON(http.StatusOK, oidcClientDto)
}
func (oc *OidcController) deviceAuthorizationHandler(c *gin.Context) {
var input dto.OidcDeviceAuthorizationRequestDto
if err := c.ShouldBind(&input); err != nil {
_ = c.Error(err)
return
}
// Client id and secret can also be passed over the Authorization header
if input.ClientID == "" && input.ClientSecret == "" {
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
}
response, err := oc.oidcService.CreateDeviceAuthorization(input)
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, response)
}
func (oc *OidcController) verifyDeviceCodeHandler(c *gin.Context) {
userCode := c.Query("code")
if userCode == "" {
_ = c.Error(&common.ValidationError{Message: "code is required"})
return
}
// Get IP address and user agent from the request context
ipAddress := c.ClientIP()
userAgent := c.Request.UserAgent()
err := oc.oidcService.VerifyDeviceCode(c, userCode, c.GetString("userID"), ipAddress, userAgent)
if err != nil {
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
func (oc *OidcController) getDeviceCodeInfoHandler(c *gin.Context) {
userCode := c.Query("code")
if userCode == "" {
_ = c.Error(&common.ValidationError{Message: "code is required"})
return
}
deviceCodeInfo, err := oc.oidcService.GetDeviceCodeInfo(c, userCode, c.GetString("userID"))
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, deviceCodeInfo)
}