2024-10-28 18:11:54 +01:00
package middleware
2024-08-23 17:04:19 +02:00
import (
"errors"
2024-10-28 18:11:54 +01:00
"fmt"
2025-02-05 18:08:01 +01:00
"net/http"
"strings"
2024-08-23 17:04:19 +02:00
"github.com/gin-gonic/gin"
2024-10-28 18:11:54 +01:00
"github.com/gin-gonic/gin/binding"
2024-08-23 17:04:19 +02:00
"github.com/go-playground/validator/v10"
2025-02-05 18:08:01 +01:00
"github.com/pocket-id/pocket-id/backend/internal/common"
2024-08-23 17:04:19 +02:00
"gorm.io/gorm"
)
2024-10-28 18:11:54 +01:00
// ErrorHandlerMiddleware is a stateless middleware type; its Add method
// returns the gin handler that converts recorded errors into JSON responses.
type ErrorHandlerMiddleware struct{}
2024-08-23 17:04:19 +02:00
2024-10-28 18:11:54 +01:00
// NewErrorHandlerMiddleware allocates and returns a new ErrorHandlerMiddleware.
// The type has no fields, so there is nothing to configure.
func NewErrorHandlerMiddleware() *ErrorHandlerMiddleware {
	return new(ErrorHandlerMiddleware)
}
func ( m * ErrorHandlerMiddleware ) Add ( ) gin . HandlerFunc {
return func ( c * gin . Context ) {
c . Next ( )
for _ , err := range c . Errors {
// Check for record not found errors
if errors . Is ( err , gorm . ErrRecordNotFound ) {
errorResponse ( c , http . StatusNotFound , "Record not found" )
return
}
// Check for validation errors
var validationErrors validator . ValidationErrors
if errors . As ( err , & validationErrors ) {
message := handleValidationError ( validationErrors )
errorResponse ( c , http . StatusBadRequest , message )
return
}
// Check for slice validation errors
var sliceValidationErrors binding . SliceValidationError
if errors . As ( err , & sliceValidationErrors ) {
if errors . As ( sliceValidationErrors [ 0 ] , & validationErrors ) {
message := handleValidationError ( validationErrors )
errorResponse ( c , http . StatusBadRequest , message )
return
}
}
2024-08-23 17:04:19 +02:00
2024-10-28 18:11:54 +01:00
var appErr common . AppError
if errors . As ( err , & appErr ) {
errorResponse ( c , appErr . HttpStatusCode ( ) , appErr . Error ( ) )
return
}
2024-08-23 17:04:19 +02:00
2024-10-28 18:11:54 +01:00
c . JSON ( http . StatusInternalServerError , gin . H { "error" : "Something went wrong" } )
}
2024-08-23 17:04:19 +02:00
}
2024-10-28 18:11:54 +01:00
}
2024-08-23 17:04:19 +02:00
2024-10-28 18:11:54 +01:00
func errorResponse ( c * gin . Context , statusCode int , message string ) {
// Capitalize the first letter of the message
message = strings . ToUpper ( message [ : 1 ] ) + message [ 1 : ]
c . JSON ( statusCode , gin . H { "error" : message } )
2024-08-23 17:04:19 +02:00
}
func handleValidationError ( validationErrors validator . ValidationErrors ) string {
var errorMessages [ ] string
for _ , ve := range validationErrors {
fieldName := ve . Field ( )
var errorMessage string
switch ve . Tag ( ) {
case "required" :
errorMessage = fmt . Sprintf ( "%s is required" , fieldName )
case "email" :
errorMessage = fmt . Sprintf ( "%s must be a valid email address" , fieldName )
case "username" :
2024-09-03 22:35:18 +02:00
errorMessage = fmt . Sprintf ( "%s must only contain lowercase letters, numbers, underscores, dots, hyphens, and '@' symbols and not start or end with a special character" , fieldName )
2024-08-23 17:04:19 +02:00
case "url" :
errorMessage = fmt . Sprintf ( "%s must be a valid URL" , fieldName )
case "min" :
errorMessage = fmt . Sprintf ( "%s must be at least %s characters long" , fieldName , ve . Param ( ) )
case "max" :
errorMessage = fmt . Sprintf ( "%s must be at most %s characters long" , fieldName , ve . Param ( ) )
default :
errorMessage = fmt . Sprintf ( "%s is invalid" , fieldName )
}
2024-08-24 00:49:08 +02:00
2024-08-23 17:04:19 +02:00
errorMessages = append ( errorMessages , errorMessage )
}
// Join all the error messages into a single string
combinedErrors := strings . Join ( errorMessages , ", " )
return combinedErrors
}