feat: implement token introspection (#405)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Andreas Schneider
2025-04-09 09:18:03 +02:00
committed by GitHub
parent 8d6c1e5c08
commit 7e5d16be9b
9 changed files with 416 additions and 14 deletions

View File

@@ -11,8 +11,11 @@ import (
"log"
"os"
"path/filepath"
"strings"
"time"
"github.com/lestrrat-go/jwx/v3/jws"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
@@ -37,6 +40,12 @@ const (
// This may be omitted on non-admin tokens
IsAdminClaim = "isAdmin"
// AccessTokenJWTType is the media type for access tokens
AccessTokenJWTType = "AT+JWT"
// IDTokenJWTType is the media type for ID tokens
IDTokenJWTType = "ID+JWT"
// Acceptable clock skew for verifying tokens
clockSkew = time.Minute
)
@@ -247,8 +256,13 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string,
}
}
headers, err := CreateTokenTypeHeader(IDTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set token type: %w", err)
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey, jws.WithProtectedHeaders(headers)))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
@@ -285,6 +299,11 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
return nil, fmt.Errorf("failed to parse token: %w", err)
}
err = VerifyTokenTypeHeader(tokenString, IDTokenJWTType)
if err != nil {
return nil, fmt.Errorf("failed to verify token type: %w", err)
}
return token, nil
}
@@ -305,8 +324,13 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
headers, err := CreateTokenTypeHeader(AccessTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set token type: %w", err)
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey, jws.WithProtectedHeaders(headers)))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
@@ -327,6 +351,11 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, erro
return nil, fmt.Errorf("failed to parse token: %w", err)
}
err = VerifyTokenTypeHeader(tokenString, AccessTokenJWTType)
if err != nil {
return nil, fmt.Errorf("failed to verify token type: %w", err)
}
return token, nil
}
@@ -481,6 +510,17 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
return isAdmin, err
}
// CreateTokenTypeHeader creates a new JWS header with the given token type
func CreateTokenTypeHeader(tokenType string) (jws.Headers, error) {
headers := jws.NewHeaders()
err := headers.Set(jws.TypeKey, tokenType)
if err != nil {
return nil, fmt.Errorf("failed to set token type: %w", err)
}
return headers, nil
}
// SetIsAdmin sets the "isAdmin" claim in the token
func SetIsAdmin(token jwt.Token, isAdmin bool) error {
// Only set if true
@@ -495,3 +535,37 @@ func SetIsAdmin(token jwt.Token, isAdmin bool) error {
func SetAudienceString(token jwt.Token, audience string) error {
return token.Set(jwt.AudienceKey, audience)
}
// VerifyTokenTypeHeader verifies that the "typ" header in the token matches the expected type
func VerifyTokenTypeHeader(tokenBytes string, expectedTokenType string) error {
// Parse the raw token string purely as a JWS message structure
// We don't need to verify the signature at this stage, just inspect headers.
msg, err := jws.Parse([]byte(tokenBytes))
if err != nil {
return fmt.Errorf("failed to parse token as JWS message: %w", err)
}
// Get the list of signatures attached to the message. Usually just one for JWT.
signatures := msg.Signatures()
if len(signatures) == 0 {
return errors.New("JWS message contains no signatures")
}
protectedHeaders := signatures[0].ProtectedHeaders()
if protectedHeaders == nil {
return fmt.Errorf("JWS signature has no protected headers")
}
// Retrieve the 'typ' header value from the PROTECTED headers.
var typHeaderValue string
err = protectedHeaders.Get(jws.TypeKey, &typHeaderValue)
if err != nil {
return fmt.Errorf("token is missing required protected header '%s'", jws.TypeKey)
}
if !strings.EqualFold(typHeaderValue, expectedTokenType) {
return fmt.Errorf("'%s' header mismatch: expected '%s', got '%s'", jws.TypeKey, expectedTokenType, typHeaderValue)
}
return nil
}