You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

146 lines
4.5 KiB
Go

package main
import (
"context"
"fmt"
"log/slog"
"net/http"
"time"
"git.runcible.io/learning/pulley/internal/logging"
"github.com/google/uuid"
)
// recoverPanic is middleware that recovers from a panic raised while serving a
// request in the same goroutine that executes this middleware. It marks the
// connection for closing and reports the panic value through
// app.serverErrorResponse.
//
// Panics in goroutines started by handlers are NOT caught here: recover only
// works in the goroutine that panicked, so such goroutines must recover on
// their own. Failing to do so crashes the whole process and takes the server
// down with it.
//
// A panic with http.ErrAbortHandler is re-raised unchanged so that net/http can
// abort the response as that sentinel intends, instead of turning it into a 500.
func (app *application) recoverPanic(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Deferred functions still run while Go unwinds the stack after a
		// panic, which is the only point where the panic can be intercepted.
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}

			// http.ErrAbortHandler is the documented way for a handler to
			// abort a response; let net/http see it.
			if rec == http.ErrAbortHandler {
				panic(rec)
			}

			// "Connection: close" makes Go's HTTP server close the current
			// connection once the response has been sent.
			w.Header().Set("Connection", "close")

			// Report the actual panic value. Keep it as-is when it already is
			// an error so its identity survives for logging/inspection.
			err, ok := rec.(error)
			if !ok {
				err = fmt.Errorf("%v", rec)
			}
			app.serverErrorResponse(w, r, err)
		}()

		next.ServeHTTP(w, r)
	})
}
// RquestLoggingMiddleware was the first attempt at request logging middleware.
// (The exported name is misspelled; renaming it would break existing callers.)
//
// For every request it takes the ID from the X-Request-Id header, or generates
// a UUID when the header is empty. It then builds a logger from slog.Default()
// with an "http" attribute group (request_id, ip, proto, method, uri). It stores
// that logger in the request context under logging.CtxKeyLogger and the ID under
// logging.CtxKeyTraceID, echoes the ID in the X-Request-Id response header, and
// logs one line before handing off to next.
//
// Design note: the original plan was for a LoggerFromContext helper to pull this
// prebuilt logger out of the context at every layer. On closer inspection that
// repeats information that is irrelevant further down. A database model, for
// example, does not care which HTTP method reached the handler; the only thing
// that matters there is the request ID that triggered the query.
func (app *application) RquestLoggingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := r.Header.Get("X-Request-Id")
		if len(id) == 0 {
			id = uuid.New().String()
		}

		// TODO figure out if there is more
		logger := slog.Default().With(slog.Group("http",
			slog.String("request_id", id),
			slog.String("ip", r.RemoteAddr),
			slog.String("proto", r.Proto),
			slog.String("method", r.Method),
			slog.String("uri", r.URL.RequestURI()),
		))

		w.Header().Set("X-Request-Id", id)
		logger.Info("Recieved a request")

		ctx := r.Context()
		ctx = context.WithValue(ctx, logging.CtxKeyLogger, logger)
		ctx = context.WithValue(ctx, logging.CtxKeyTraceID, id)

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// responseWriter is a minimal wrapper around http.ResponseWriter that records
// the HTTP status code sent to the client so it can be logged.
//
// See https://blog.questionable.services/article/guide-logging-middleware-go/
type responseWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool
}

// wrapResponseWriter returns a responseWriter that forwards to w.
func wrapResponseWriter(w http.ResponseWriter) *responseWriter {
	return &responseWriter{ResponseWriter: w}
}

// Status returns the final status code written to the client. If no final
// status has been written yet, it returns http.StatusOK, which is what net/http
// sends when a handler returns without setting one.
func (rw *responseWriter) Status() int {
	if !rw.wroteHeader {
		return http.StatusOK
	}
	return rw.status
}

// WriteHeader forwards code to the wrapped writer and records it.
//
// Only the first final (non-1xx) status is recorded and forwarded; later calls
// are dropped. Informational 1xx codes other than 101 Switching Protocols are
// forwarded without being recorded, because net/http allows a final status to
// follow them.
func (rw *responseWriter) WriteHeader(code int) {
	if rw.wroteHeader {
		return
	}
	if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols {
		rw.ResponseWriter.WriteHeader(code)
		return
	}
	rw.status = code
	rw.ResponseWriter.WriteHeader(code)
	rw.wroteHeader = true
}

// Write sends an explicit 200 OK header if no final status has been written
// yet (mirroring net/http's implicit behaviour, but recorded here), then
// forwards b to the wrapped writer.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.WriteHeader(http.StatusOK)
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped http.ResponseWriter. This lets
// http.ResponseController reach optional interfaces (http.Flusher,
// http.Hijacker, ...) that this wrapper does not implement itself.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// BetterRequestLoggingMiddleware logs a "start" line when a request arrives and
// an "end" line when next returns.
//
// The request ID comes from the X-Request-Id header, or a new UUID is generated
// when the header is empty. The ID is echoed in the X-Request-Id response header
// and stored in the request context under logging.CtxKeyTraceID.
//
// The logger stored under logging.CtxKeyLogger carries only a "correlation_id"
// attribute. The intent is that downstream layers log the ID without repeating
// HTTP details. Those details (request_id, ip, proto, method, uri) are attached,
// in an "http" group, only to this middleware's own start/end lines. The end
// line additionally carries the response status and the duration in ms.
//
// The end line is not written if next panics, because the panic unwinds past it.
func (app *application) BetterRequestLoggingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requestID := r.Header.Get("X-Request-Id")
		if requestID == "" {
			requestID = uuid.New().String()
		}

		var (
			ip     = r.RemoteAddr
			proto  = r.Proto
			method = r.Method
			uri    = r.URL.RequestURI()
		)

		reqLogger := slog.Default().With(
			slog.String("correlation_id", requestID),
		)

		// httpBaseGroup builds the "http" attribute group shared by the start
		// and end lines; extra attributes are appended inside the group.
		httpBaseGroup := func(extra ...any) slog.Attr {
			return slog.Group("http",
				append([]any{
					slog.String("request_id", requestID),
					slog.String("ip", ip),
					slog.String("proto", proto),
					slog.String("method", method),
					slog.String("uri", uri),
				}, extra...)...,
			)
		}

		ctx := context.WithValue(r.Context(), logging.CtxKeyLogger, reqLogger)
		ctx = context.WithValue(ctx, logging.CtxKeyTraceID, requestID)
		w.Header().Set("X-Request-Id", requestID)

		start := time.Now()
		wrapped := wrapResponseWriter(w)
		reqLogger.Info("Recieved a new request", slog.String("event", "start"), httpBaseGroup())

		next.ServeHTTP(wrapped, r.WithContext(ctx))

		status := wrapped.Status()
		if status == 0 {
			// No status was recorded by the wrapper (e.g. the handler never
			// called WriteHeader). net/http answers 200 OK in that case, so
			// log what the client actually received rather than 0.
			status = http.StatusOK
		}
		reqLogger.Info("Completed request",
			slog.String("event", "end"),
			httpBaseGroup(
				slog.Int("status", status),
				slog.Int64("duration_ms", time.Since(start).Milliseconds()),
			),
		)
	})
}