diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index ea45acfb..510b073d 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -129,9 +129,6 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex // @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token" // @Router /api/oidc/token [post] func (oc *OidcController) createTokensHandler(c *gin.Context) { - // Disable cors for this endpoint - c.Writer.Header().Set("Access-Control-Allow-Origin", "*") - var input dto.OidcCreateTokensDto if err := c.ShouldBind(&input); err != nil { _ = c.Error(err) diff --git a/backend/internal/middleware/cors.go b/backend/internal/middleware/cors.go index 91fe9026..6b2a936c 100644 --- a/backend/internal/middleware/cors.go +++ b/backend/internal/middleware/cors.go @@ -4,7 +4,6 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/pocket-id/pocket-id/backend/internal/common" ) type CorsMiddleware struct{} @@ -15,17 +14,21 @@ func NewCorsMiddleware() *CorsMiddleware { func (m *CorsMiddleware) Add() gin.HandlerFunc { return func(c *gin.Context) { - // Allow all origins for the token endpoint - switch c.FullPath() { - case "/api/oidc/token", "/api/oidc/introspect": - c.Writer.Header().Set("Access-Control-Allow-Origin", "*") - default: - c.Writer.Header().Set("Access-Control-Allow-Origin", common.EnvConfig.AppURL) + path := c.FullPath() + if path == "" { + // The router doesn't map preflight requests, so we need to use the raw URL path + path = c.Request.URL.Path } - c.Writer.Header().Set("Access-Control-Allow-Headers", "*") - c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT") + if !isCorsPath(path) { + c.Next() + return + } + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST") + + // Preflight request if c.Request.Method == http.MethodOptions { c.AbortWithStatus(204) return @@ -34,3 +37,17 @@ func (m *CorsMiddleware) Add() gin.HandlerFunc { c.Next() } } + +func isCorsPath(path string) bool { + switch path { + case "/api/oidc/token", + "/api/oidc/userinfo", + "/oidc/end-session", + "/api/oidc/introspect", + "/.well-known/jwks.json", + "/.well-known/openid-configuration": + return true + default: + return false + } +}