mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-12 00:03:03 +03:00
fix: hash the refresh token in the DB (security) (#379)
This commit is contained in:
committed by
GitHub
parent
26b2de4f00
commit
8c963818bb
@@ -145,121 +145,133 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
|
|||||||
return isAllowedToAuthorize
|
return isAllowedToAuthorize
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (string, string, string, int, error) {
|
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
|
||||||
if grantType == "authorization_code" {
|
switch grantType {
|
||||||
var client model.OidcClient
|
case "authorization_code":
|
||||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
return s.createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier)
|
||||||
return "", "", "", 0, err
|
case "refresh_token":
|
||||||
}
|
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(refreshToken, clientID, clientSecret)
|
||||||
|
return "", accessToken, newRefreshToken, exp, err
|
||||||
|
default:
|
||||||
|
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Verify the client secret if the client is not public
|
func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
|
||||||
if !client.IsPublic {
|
var client model.OidcClient
|
||||||
if clientID == "" || clientSecret == "" {
|
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
|
return "", "", "", 0, err
|
||||||
}
|
|
||||||
|
|
||||||
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
|
||||||
if err != nil {
|
|
||||||
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
|
||||||
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
|
|
||||||
if err != nil {
|
|
||||||
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
|
|
||||||
if client.IsPublic || client.PkceEnabled {
|
|
||||||
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
|
|
||||||
return "", "", "", 0, &common.OidcInvalidCodeVerifierError{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
|
|
||||||
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
|
||||||
}
|
|
||||||
|
|
||||||
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", "", 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", "", 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate a refresh token
|
|
||||||
refreshToken, err := s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", "", 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
|
|
||||||
|
|
||||||
s.db.Delete(&authorizationCodeMetaData)
|
|
||||||
|
|
||||||
return idToken, accessToken, refreshToken, 3600, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if grantType == "refresh_token" {
|
// Verify the client secret if the client is not public
|
||||||
if refreshToken == "" {
|
if !client.IsPublic {
|
||||||
return "", "", "", 0, &common.OidcMissingRefreshTokenError{}
|
if clientID == "" || clientSecret == "" {
|
||||||
|
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the client to check if it's public
|
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||||
var client model.OidcClient
|
|
||||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
|
||||||
return "", "", "", 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the client secret if the client is not public
|
|
||||||
if !client.IsPublic {
|
|
||||||
if clientID == "" || clientSecret == "" {
|
|
||||||
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
|
||||||
if err != nil {
|
|
||||||
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify refresh token
|
|
||||||
var storedRefreshToken model.OidcRefreshToken
|
|
||||||
if err := s.db.Preload("User").Where("token = ? AND expires_at > ?", refreshToken, datatype.DateTime(time.Now())).First(&storedRefreshToken).Error; err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
return "", "", "", 0, &common.OidcInvalidRefreshTokenError{}
|
|
||||||
}
|
|
||||||
return "", "", "", 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify that the refresh token belongs to the provided client
|
|
||||||
if storedRefreshToken.ClientID != clientID {
|
|
||||||
return "", "", "", 0, &common.OidcInvalidRefreshTokenError{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate a new access token
|
|
||||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a new refresh token and invalidate the old one
|
|
||||||
newRefreshToken, err := s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", "", 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete the used refresh token
|
|
||||||
s.db.Delete(&storedRefreshToken)
|
|
||||||
|
|
||||||
return "", accessToken, newRefreshToken, 3600, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
|
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||||
|
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
|
||||||
|
if err != nil {
|
||||||
|
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
|
||||||
|
if client.IsPublic || client.PkceEnabled {
|
||||||
|
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
|
||||||
|
return "", "", "", 0, &common.OidcInvalidCodeVerifierError{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
|
||||||
|
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a refresh token
|
||||||
|
refreshToken, err = s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
|
||||||
|
|
||||||
|
s.db.Delete(&authorizationCodeMetaData)
|
||||||
|
|
||||||
|
return idToken, accessToken, refreshToken, 3600, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
|
||||||
|
if refreshToken == "" {
|
||||||
|
return "", "", 0, &common.OidcMissingRefreshTokenError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the client to check if it's public
|
||||||
|
var client model.OidcClient
|
||||||
|
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
|
return "", "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the client secret if the client is not public
|
||||||
|
if !client.IsPublic {
|
||||||
|
if clientID == "" || clientSecret == "" {
|
||||||
|
return "", "", 0, &common.OidcMissingClientCredentialsError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||||
|
if err != nil {
|
||||||
|
return "", "", 0, &common.OidcClientSecretInvalidError{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify refresh token
|
||||||
|
var storedRefreshToken model.OidcRefreshToken
|
||||||
|
err = s.db.Preload("User").
|
||||||
|
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
|
||||||
|
First(&storedRefreshToken).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
|
||||||
|
}
|
||||||
|
return "", "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the refresh token belongs to the provided client
|
||||||
|
if storedRefreshToken.ClientID != clientID {
|
||||||
|
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new access token
|
||||||
|
accessToken, err = s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new refresh token and invalidate the old one
|
||||||
|
newRefreshToken, err = s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the used refresh token
|
||||||
|
s.db.Delete(&storedRefreshToken)
|
||||||
|
|
||||||
|
return accessToken, newRefreshToken, 3600, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
|
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
|
||||||
@@ -630,22 +642,26 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) createRefreshToken(clientID string, userID string, scope string) (string, error) {
|
func (s *OidcService) createRefreshToken(clientID string, userID string, scope string) (string, error) {
|
||||||
randomString, err := utils.GenerateRandomAlphanumericString(40)
|
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken := model.OidcRefreshToken{
|
// Compute the hash of the refresh token to store in the DB
|
||||||
|
// Refresh tokens are pretty long already, so a "simple" SHA-256 hash is enough
|
||||||
|
refreshTokenHash := utils.CreateSha256Hash(refreshToken)
|
||||||
|
|
||||||
|
m := model.OidcRefreshToken{
|
||||||
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
|
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
|
||||||
Token: randomString,
|
Token: refreshTokenHash,
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
Scope: scope,
|
Scope: scope,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.db.Create(&refreshToken).Error; err != nil {
|
if err := s.db.Create(&m).Error; err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return randomString, nil
|
return refreshToken, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ func (s *TestService) SeedDatabase() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
refreshToken := model.OidcRefreshToken{
|
refreshToken := model.OidcRefreshToken{
|
||||||
Token: "ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo",
|
Token: utils.CreateSha256Hash("ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo"),
|
||||||
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
|
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
|
||||||
Scope: "openid profile email",
|
Scope: "openid profile email",
|
||||||
UserID: users[0].ID,
|
UserID: users[0].ID,
|
||||||
|
|||||||
Reference in New Issue
Block a user