mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-05-04 18:00:38 +03:00
124 lines
3.5 KiB
Go
124 lines
3.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gin-gonic/gin/binding"
|
|
"github.com/go-playground/validator/v10"
|
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// ErrorHandlerMiddleware converts errors recorded on a gin context into JSON
// error responses. It carries no state; its zero value is ready to use.
type ErrorHandlerMiddleware struct{}

// NewErrorHandlerMiddleware returns a new, stateless ErrorHandlerMiddleware.
func NewErrorHandlerMiddleware() *ErrorHandlerMiddleware {
	return new(ErrorHandlerMiddleware)
}
|
|
|
|
func (m *ErrorHandlerMiddleware) Add() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
c.Next()
|
|
for _, err := range c.Errors {
|
|
// Check for record not found errors
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
errorResponse(c, http.StatusNotFound, "Record not found")
|
|
return
|
|
}
|
|
|
|
// Check for validation errors
|
|
var validationErrors validator.ValidationErrors
|
|
if errors.As(err, &validationErrors) {
|
|
message := handleValidationError(validationErrors)
|
|
errorResponse(c, http.StatusBadRequest, message)
|
|
return
|
|
}
|
|
|
|
// Check for slice validation errors
|
|
svErr, ok := errors.AsType[binding.SliceValidationError](err)
|
|
if ok {
|
|
if errors.As(svErr[0], &validationErrors) {
|
|
message := handleValidationError(validationErrors)
|
|
errorResponse(c, http.StatusBadRequest, message)
|
|
return
|
|
}
|
|
}
|
|
|
|
// AppError with description
|
|
appDescErr, ok := errors.AsType[common.AppErrorDescription](err)
|
|
if ok {
|
|
errorResponseWithDescription(c, appDescErr.HttpStatusCode(), appDescErr.Error(), appDescErr.Description())
|
|
return
|
|
}
|
|
|
|
// AppError (without description)
|
|
appErr, ok := errors.AsType[common.AppError](err)
|
|
if ok {
|
|
errorResponse(c, appErr.HttpStatusCode(), appErr.Error())
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusInternalServerError, errorResponseBody{
|
|
Error: "Something went wrong",
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
type errorResponseBody struct {
|
|
Error string `json:"error"`
|
|
ErrorDescription string `json:"error_description,omitempty"`
|
|
}
|
|
|
|
func errorResponse(c *gin.Context, statusCode int, message string) {
|
|
// Capitalize the first letter of the message
|
|
message = strings.ToUpper(message[:1]) + message[1:]
|
|
c.JSON(statusCode, errorResponseBody{
|
|
Error: message,
|
|
})
|
|
}
|
|
|
|
func errorResponseWithDescription(c *gin.Context, statusCode int, message string, description string) {
|
|
// Capitalize the first letter of the message
|
|
message = strings.ToUpper(message[:1]) + message[1:]
|
|
c.JSON(statusCode, errorResponseBody{
|
|
Error: message,
|
|
ErrorDescription: description,
|
|
})
|
|
}
|
|
|
|
func handleValidationError(validationErrors validator.ValidationErrors) string {
|
|
var errorMessages []string
|
|
|
|
for _, ve := range validationErrors {
|
|
fieldName := ve.Field()
|
|
var errorMessage string
|
|
switch ve.Tag() {
|
|
case "required":
|
|
errorMessage = fmt.Sprintf("%s is required", fieldName)
|
|
case "email":
|
|
errorMessage = fmt.Sprintf("%s must be a valid email address", fieldName)
|
|
case "username":
|
|
errorMessage = fmt.Sprintf("%s must only contain letters, numbers, underscores, dots, hyphens, and '@' symbols and not start or end with a special character", fieldName)
|
|
case "url":
|
|
errorMessage = fmt.Sprintf("%s must be a valid URL", fieldName)
|
|
case "min":
|
|
errorMessage = fmt.Sprintf("%s must be at least %s characters long", fieldName, ve.Param())
|
|
case "max":
|
|
errorMessage = fmt.Sprintf("%s must be at most %s characters long", fieldName, ve.Param())
|
|
default:
|
|
errorMessage = fmt.Sprintf("%s is invalid", fieldName)
|
|
}
|
|
|
|
errorMessages = append(errorMessages, errorMessage)
|
|
}
|
|
|
|
// Join all the error messages into a single string
|
|
combinedErrors := strings.Join(errorMessages, ", ")
|
|
|
|
return combinedErrors
|
|
}
|