diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index c839e42f..30ac1b94 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -63,6 +63,7 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) { rateLimitMiddleware := middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60) // Setup global middleware + r.Use(middleware.HeadMiddleware()) r.Use(middleware.NewCacheControlMiddleware().Add()) r.Use(middleware.NewCorsMiddleware().Add()) r.Use(middleware.NewCspMiddleware().Add()) @@ -111,7 +112,17 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) { srv := &http.Server{ MaxHeaderBytes: 1 << 20, ReadHeaderTimeout: 10 * time.Second, - Handler: r, + Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // HEAD requests don't get matched by Gin routes, so we convert them to GET + // middleware.HeadMiddleware will convert them back to HEAD later + if req.Method == http.MethodHead { + req.Method = http.MethodGet + ctx := context.WithValue(req.Context(), middleware.IsHeadRequestCtxKey{}, true) + req = req.WithContext(ctx) + } + + r.ServeHTTP(w, req) + }), } // Set up the listener diff --git a/backend/internal/middleware/head_middleware.go b/backend/internal/middleware/head_middleware.go new file mode 100644 index 00000000..541927f6 --- /dev/null +++ b/backend/internal/middleware/head_middleware.go @@ -0,0 +1,40 @@ +package middleware + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" +) + +type IsHeadRequestCtxKey struct{} + +type headWriter struct { + gin.ResponseWriter + size int +} + +func (w *headWriter) Write(b []byte) (int, error) { + w.size += len(b) + return w.size, nil +} + +func HeadMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // Only process if it's a HEAD request + if c.Request.Context().Value(IsHeadRequestCtxKey{}) != true { + c.Next() + return + } + + // Replace the ResponseWriter with our headWriter to swallow the body + hw := &headWriter{ResponseWriter: c.Writer} + c.Writer = hw + + c.Next() + + c.Writer.Header().Set("Content-Length", strconv.Itoa(hw.size)) + c.Request.Method = http.MethodHead + + } +}