package main import ( "context" "fmt" "log/slog" "net/http" "time" "git.runcible.io/learning/pulley/internal/logging" "github.com/google/uuid" ) // recoverPanic provides middleware for recovering from panics in the same goroutine that // executred the recoverPanic middleware. // // This means that panics from goroutines created in handler operations will still need // to be handled separately.Failure to do so will cause the application to exit and bring // down the server. func (app *application) recoverPanic(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Create a deferred function (which will always be run in the event of a panic // as Go unwinds the stack). defer func() { if err := recover(); err != nil { // If there was a panic, set a "Connection: close" header on the // response. This acts as a trigger to make Go's HTTP server // automatically close the current connection after a response has been // sent. w.Header().Set("Connection", "close") app.serverErrorResponse(w, r, fmt.Errorf("%s", http.ErrAbortHandler)) } }() next.ServeHTTP(w, r) }) } // RequestLoggingMiddleware // So this was my first crack at a request logging middleware. Originally I intended a LoggerFromContext function would // extract this prebuilt logger from the context at each layer it is passed to, but on futher examination I think it // would repeat information that is superfluous to that context of logging. Take for example if I use this in the Database Model // Service why does it matter what http.method was used in the http.Handler. The only important thing in that context would be // the Request ID that triggered the Database call. func (app *application) RquestLoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestId := r.Header.Get("X-Request-Id") if requestId == "" { requestId = uuid.New().String() } var ( ip = r.RemoteAddr proto = r.Proto method = r.Method uri = r.URL.RequestURI() ) reqLogger := slog.Default().With( // TODO figure out if there is more slog.Group("http", slog.String("request_id", requestId), slog.String("ip", ip), slog.String("proto", proto), slog.String("method", method), slog.String("uri", uri), ), ) ctx := context.WithValue(r.Context(), logging.CtxKeyLogger, reqLogger) ctx = context.WithValue(ctx, logging.CtxKeyTraceID, requestId) w.Header().Set("X-Request-Id", requestId) reqLogger.Info("Recieved a request") next.ServeHTTP(w, r.WithContext(ctx)) }) } // See https://blog.questionable.services/article/guide-logging-middleware-go/ // responseWriter is a minimal wrapper for http.ResponseWriter that allows the // written HTTP status code to be captured for logging. type responseWriter struct { http.ResponseWriter status int wroteHeader bool } func wrapResponseWriter(w http.ResponseWriter) *responseWriter { return &responseWriter{ResponseWriter: w} } func (rw *responseWriter) Status() int { return rw.status } func (rw *responseWriter) WriteHeader(code int) { if rw.wroteHeader { return } rw.status = code rw.ResponseWriter.WriteHeader(code) rw.wroteHeader = true return } func (app *application) BetterRequestLoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestId := r.Header.Get("X-Request-Id") if requestId == "" { requestId = uuid.New().String() } var ( ip = r.RemoteAddr proto = r.Proto method = r.Method uri = r.URL.RequestURI() ) reqLogger := slog.Default().With( slog.String("correlation_id", requestId), ) httpBaseGroup := func(kv ...any) slog.Attr { return slog.Group("http", append([]any{ slog.String("request_id", requestId), slog.String("ip", ip), slog.String("proto", proto), slog.String("method", method), slog.String("uri", uri), }, kv...)..., ) } ctx := context.WithValue(r.Context(), logging.CtxKeyLogger, reqLogger) ctx = context.WithValue(ctx, logging.CtxKeyTraceID, requestId) w.Header().Set("X-Request-Id", requestId) start := time.Now() wrapped := wrapResponseWriter(w) reqLogger.Info("Recieved a new request", slog.String("event", "start"), httpBaseGroup()) next.ServeHTTP(wrapped, r.WithContext(ctx)) reqLogger.Info("Completed request", slog.String("event", "end"), httpBaseGroup(slog.Int("status", wrapped.status), slog.Int64("duration_ms", time.Since(start).Milliseconds()))) }) }