feat: implement token introspection (#405)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Andreas Schneider
2025-04-09 09:18:03 +02:00
committed by GitHub
parent 8d6c1e5c08
commit 7e5d16be9b
9 changed files with 416 additions and 14 deletions

View File

@@ -6,14 +6,13 @@ import (
"net/url"
"strings"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
)
// NewOidcController creates a new controller for OIDC related endpoints
@@ -31,6 +30,7 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
group.POST("/oidc/userinfo", oc.userInfoHandler)
group.POST("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler)
group.GET("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler)
group.POST("/oidc/introspect", oc.introspectTokenHandler)
group.GET("/oidc/clients", authMiddleware.Add(), oc.listClientsHandler)
group.POST("/oidc/clients", authMiddleware.Add(), oc.createClientHandler)
@@ -291,6 +291,38 @@ func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) {
// Implementation is the same as GET
}
// introspectToken godoc
// @Summary Introspect OIDC tokens
// @Description Pass an access_token to verify if it is considered valid.
// @Tags OIDC
// @Produce json
// @Param token formData string true "The token to be introspected."
// @Success 200 {object} dto.OidcIntrospectionResponseDto "Response with the introspection result."
// @Router /api/oidc/introspect [post]
func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
var input dto.OidcIntrospectDto
if err := c.ShouldBind(&input); err != nil {
_ = c.Error(err)
return
}
// Client id and secret have to be passed over the Authorization header. This kind of
// authentication allows us to keep the endpoint protected (since it could be used to
// find valid tokens) while still allowing it to be used by an application that is
// supposed to interact with our IdP (since that needs to have a client_id
// and client_secret anyway).
clientID, clientSecret, _ := c.Request.BasicAuth()
response, err := oc.oidcService.IntrospectToken(clientID, clientSecret, input.Token)
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, response)
}
// getClientMetaDataHandler godoc
// @Summary Get client metadata
// @Description Get OIDC client metadata for discovery and configuration

View File

@@ -74,6 +74,7 @@ func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
"token_endpoint": appUrl + "/api/oidc/token",
"userinfo_endpoint": appUrl + "/api/oidc/userinfo",
"end_session_endpoint": appUrl + "/api/oidc/end-session",
"introspection_endpoint": appUrl + "/api/oidc/introspect",
"jwks_uri": appUrl + "/.well-known/jwks.json",
"grant_types_supported": []string{"authorization_code", "refresh_token"},
"scopes_supported": []string{"openid", "profile", "email", "groups"},

View File

@@ -55,6 +55,10 @@ type OidcCreateTokensDto struct {
RefreshToken string `form:"refresh_token"`
}
type OidcIntrospectDto struct {
Token string `form:"token" binding:"required"`
}
type OidcUpdateAllowedUserGroupsDto struct {
UserGroupIDs []string `json:"userGroupIds" binding:"required"`
}
@@ -73,3 +77,16 @@ type OidcTokenResponseDto struct {
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"`
}
type OidcIntrospectionResponseDto struct {
Active bool `json:"active"`
TokenType string `json:"token_type,omitempty"`
Scope string `json:"scope,omitempty"`
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
Subject string `json:"sub,omitempty"`
Audience []string `json:"aud,omitempty"`
Issuer string `json:"iss,omitempty"`
Identifier string `json:"jti,omitempty"`
}

View File

@@ -15,7 +15,14 @@ func NewCorsMiddleware() *CorsMiddleware {
func (m *CorsMiddleware) Add() gin.HandlerFunc {
return func(c *gin.Context) {
c.Writer.Header().Set("Access-Control-Allow-Origin", common.EnvConfig.AppURL)
// Allow all origins for the token endpoint
switch c.FullPath() {
case "/api/oidc/token", "/api/oidc/introspect":
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
default:
c.Writer.Header().Set("Access-Control-Allow-Origin", common.EnvConfig.AppURL)
}
c.Writer.Header().Set("Access-Control-Allow-Headers", "*")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")

View File

@@ -11,8 +11,11 @@ import (
"log"
"os"
"path/filepath"
"strings"
"time"
"github.com/lestrrat-go/jwx/v3/jws"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
@@ -37,6 +40,12 @@ const (
// This may be omitted on non-admin tokens
IsAdminClaim = "isAdmin"
// AccessTokenJWTType is the media type for access tokens
AccessTokenJWTType = "AT+JWT"
// IDTokenJWTType is the media type for ID tokens
IDTokenJWTType = "ID+JWT"
// Acceptable clock skew for verifying tokens
clockSkew = time.Minute
)
@@ -247,8 +256,13 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string,
}
}
headers, err := CreateTokenTypeHeader(IDTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set token type: %w", err)
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey, jws.WithProtectedHeaders(headers)))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
@@ -285,6 +299,11 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
return nil, fmt.Errorf("failed to parse token: %w", err)
}
err = VerifyTokenTypeHeader(tokenString, IDTokenJWTType)
if err != nil {
return nil, fmt.Errorf("failed to verify token type: %w", err)
}
return token, nil
}
@@ -305,8 +324,13 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
headers, err := CreateTokenTypeHeader(AccessTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set token type: %w", err)
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey, jws.WithProtectedHeaders(headers)))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
@@ -327,6 +351,11 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, erro
return nil, fmt.Errorf("failed to parse token: %w", err)
}
err = VerifyTokenTypeHeader(tokenString, AccessTokenJWTType)
if err != nil {
return nil, fmt.Errorf("failed to verify token type: %w", err)
}
return token, nil
}
@@ -481,6 +510,17 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
return isAdmin, err
}
// CreateTokenTypeHeader creates a new JWS header with the given token type
func CreateTokenTypeHeader(tokenType string) (jws.Headers, error) {
headers := jws.NewHeaders()
err := headers.Set(jws.TypeKey, tokenType)
if err != nil {
return nil, fmt.Errorf("failed to set token type: %w", err)
}
return headers, nil
}
// SetIsAdmin sets the "isAdmin" claim in the token
func SetIsAdmin(token jwt.Token, isAdmin bool) error {
// Only set if true
@@ -495,3 +535,37 @@ func SetIsAdmin(token jwt.Token, isAdmin bool) error {
func SetAudienceString(token jwt.Token, audience string) error {
return token.Set(jwt.AudienceKey, audience)
}
// VerifyTokenTypeHeader verifies that the "typ" header in the token matches the expected type
func VerifyTokenTypeHeader(tokenBytes string, expectedTokenType string) error {
// Parse the raw token string purely as a JWS message structure
// We don't need to verify the signature at this stage, just inspect headers.
msg, err := jws.Parse([]byte(tokenBytes))
if err != nil {
return fmt.Errorf("failed to parse token as JWS message: %w", err)
}
// Get the list of signatures attached to the message. Usually just one for JWT.
signatures := msg.Signatures()
if len(signatures) == 0 {
return errors.New("JWS message contains no signatures")
}
protectedHeaders := signatures[0].ProtectedHeaders()
if protectedHeaders == nil {
return fmt.Errorf("JWS signature has no protected headers")
}
// Retrieve the 'typ' header value from the PROTECTED headers.
var typHeaderValue string
err = protectedHeaders.Get(jws.TypeKey, &typHeaderValue)
if err != nil {
return fmt.Errorf("token is missing required protected header '%s'", jws.TypeKey)
}
if !strings.EqualFold(typHeaderValue, expectedTokenType) {
return fmt.Errorf("'%s' header mismatch: expected '%s', got '%s'", jws.TypeKey, expectedTokenType, typHeaderValue)
}
return nil
}

View File

@@ -6,12 +6,15 @@ import (
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"fmt"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/lestrrat-go/jwx/v3/jws"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
@@ -651,8 +654,13 @@ func TestGenerateVerifyIdToken(t *testing.T) {
}
}
// Create headers with the specified type
hdrs := jws.NewHeaders()
err = hdrs.Set(jws.TypeKey, "ID+JWT")
require.NoError(t, err, "Failed to set header type")
// Sign the token
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey, jws.WithProtectedHeaders(hdrs)))
require.NoError(t, err, "Failed to sign token")
tokenString := string(signed)
@@ -1172,6 +1180,63 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
})
}
func TestVerifyTokenTypeHeader(t *testing.T) {
mockConfig := &AppConfigService{}
tempDir := t.TempDir()
// Helper function to create a token with a specific type header
createTokenWithType := func(tokenType string) (string, error) {
// Create a simple JWT token
token := jwt.New()
err := token.Set("test_claim", "test_value")
if err != nil {
return "", fmt.Errorf("failed to set claim: %w", err)
}
// Create headers with the specified type
hdrs := jws.NewHeaders()
if tokenType != "" {
err = hdrs.Set(jws.TypeKey, tokenType)
if err != nil {
return "", fmt.Errorf("failed to set type header: %w", err)
}
}
// Sign the token with the headers
service := &JwtService{}
err = service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey, jws.WithProtectedHeaders(hdrs)))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return string(signed), nil
}
t.Run("succeeds when token type matches expected type", func(t *testing.T) {
// Create a token with "JWT" type
tokenString, err := createTokenWithType("JWT")
require.NoError(t, err, "Failed to create test token")
// Verify the token type
err = VerifyTokenTypeHeader(tokenString, "JWT")
assert.NoError(t, err, "Should accept token with matching type")
})
t.Run("fails when token type doesn't match expected type", func(t *testing.T) {
// Create a token with "AT+JWT" type
tokenString, err := createTokenWithType("AT+JWT")
require.NoError(t, err, "Failed to create test token")
// Verify the token with different expected type
err = VerifyTokenTypeHeader(tokenString, "JWT")
require.Error(t, err, "Should reject token with non-matching type")
assert.Contains(t, err.Error(), "header mismatch: expected 'JWT', got 'AT+JWT'")
})
}
func importKey(t *testing.T, privateKeyRaw any, path string) string {
t.Helper()

View File

@@ -14,6 +14,8 @@ import (
"strings"
"time"
"github.com/lestrrat-go/jwx/v3/jwt"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
@@ -356,6 +358,93 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
return accessToken, newRefreshToken, 3600, nil
}
func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
if clientID == "" || clientSecret == "" {
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
// Get the client to check if we are authorized.
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return introspectDto, &common.OidcClientSecretInvalidError{}
}
// Verify the client secret. This endpoint may not be used by public clients.
if client.IsPublic {
return introspectDto, &common.OidcClientSecretInvalidError{}
}
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
return introspectDto, &common.OidcClientSecretInvalidError{}
}
token, err := s.jwtService.VerifyOauthAccessToken(tokenString)
if err != nil {
if errors.Is(err, jwt.ParseError()) {
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
return s.introspectRefreshToken(tokenString)
}
// Every failure we get means the token is invalid. Nothing more to do with the error.
introspectDto.Active = false
return introspectDto, nil
}
introspectDto.Active = true
introspectDto.TokenType = "access_token"
if token.Has("scope") {
var asString string
var asStrings []string
if err := token.Get("scope", &asString); err == nil {
introspectDto.Scope = asString
} else if err := token.Get("scope", &asStrings); err == nil {
introspectDto.Scope = strings.Join(asStrings, " ")
}
}
if expiration, hasExpiration := token.Expiration(); hasExpiration {
introspectDto.Expiration = expiration.Unix()
}
if issuedAt, hasIssuedAt := token.IssuedAt(); hasIssuedAt {
introspectDto.IssuedAt = issuedAt.Unix()
}
if notBefore, hasNotBefore := token.NotBefore(); hasNotBefore {
introspectDto.NotBefore = notBefore.Unix()
}
if subject, hasSubject := token.Subject(); hasSubject {
introspectDto.Subject = subject
}
if audience, hasAudience := token.Audience(); hasAudience {
introspectDto.Audience = audience
}
if issuer, hasIssuer := token.Issuer(); hasIssuer {
introspectDto.Issuer = issuer
}
if identifier, hasIdentifier := token.JwtID(); hasIdentifier {
introspectDto.Identifier = identifier
}
return introspectDto, nil
}
func (s *OidcService) introspectRefreshToken(refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
var storedRefreshToken model.OidcRefreshToken
err = s.db.Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
First(&storedRefreshToken).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
introspectDto.Active = false
return introspectDto, nil
}
return introspectDto, err
}
introspectDto.Active = true
introspectDto.TokenType = "refresh_token"
return introspectDto, nil
}
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
return s.getClientInternal(ctx, clientID, s.db)
}