From 12865ca5dd1ad05086721d18b0372809aebce3fa Mon Sep 17 00:00:00 2001 From: Drew Bednar Date: Sun, 9 Feb 2025 11:20:56 -0500 Subject: [PATCH] Require authentication for certain routes --- internal/server/helpers.go | 9 +++++++++ internal/server/middleware.go | 21 +++++++++++++++++++++ internal/server/routes.go | 8 +++++--- internal/server/templates.go | 15 +++++++++------ ui/html/partials/nav.go.tmpl | 9 +++++++-- 5 files changed, 51 insertions(+), 11 deletions(-) diff --git a/internal/server/helpers.go b/internal/server/helpers.go index 4a8136e..c81f883 100644 --- a/internal/server/helpers.go +++ b/internal/server/helpers.go @@ -6,6 +6,8 @@ import ( "runtime/debug" "strconv" "strings" + + "github.com/alexedwards/scs/v2" ) // serverError helper writes a log entry at Error level (including the request @@ -44,6 +46,8 @@ func getClientIP(r *http.Request) string { return strings.Split(r.RemoteAddr, ":")[0] } +// TODO we probably want to distinguish between invalid email and in valid password + func logAuthFailure(logger *slog.Logger, r *http.Request, email string) { logger.Info("authentication attempt failed", slog.String("event_type", "authentication_failure"), @@ -60,3 +64,8 @@ func logAuthSuccess(logger *slog.Logger, r *http.Request, email string, userId i slog.String("ip_address", getClientIP(r)), slog.String("user_agent", r.Header.Get("User-Agent"))) } + +// is Authenticated returns true if an authenticated user ID has been set in the session +func isAuthenticated(r *http.Request, sm *scs.SessionManager) bool { + return sm.Exists(r.Context(), "authenticatedUserID") +} diff --git a/internal/server/middleware.go b/internal/server/middleware.go index 814cbed..e3541b1 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -4,6 +4,8 @@ import ( "fmt" "log/slog" "net/http" + + "github.com/alexedwards/scs/v2" ) // https://owasp.org/www-project-secure-headers/ guidance @@ -101,3 +103,22 @@ func RecoveryMiddleware(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +func RequireAuthenticationMiddleware(next http.Handler, sm *scs.SessionManager) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // If the user is not authenticated, redirect them to the login page and + // return from the middleware chain so that no subsequent handlers in + // the chain are executed. + if !isAuthenticated(r, sm) { + http.Redirect(w, r, "/user/login", http.StatusSeeOther) + return + } + + // Otherwise set the "Cache-Control: no-store" header so that pages + // require authentication are not stored in the users browser cache (or + // other intermediary cache). + w.Header().Add("Cache-Control", "no-store") + + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/routes.go b/internal/server/routes.go index 54cd8b1..6c0bf40 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -25,8 +25,8 @@ func addRoutes(mux *http.ServeMux, // GET will match GET & HEAD http request methods mux.Handle("GET /{$}", sm.LoadAndSave(handleHome(logger, tc, sm, snippetService))) // might be time to swith to github.com/justinas/alice dynamic chain mux.Handle("GET /snippet/view/{id}", sm.LoadAndSave(handleSnippetView(logger, tc, sm, snippetService))) - mux.Handle("GET /snippet/create", sm.LoadAndSave(handleSnippetCreateGet(tc, sm))) - mux.Handle("POST /snippet/create", sm.LoadAndSave(handleSnippetCreatePost(logger, tc, fd, sm, snippetService))) + mux.Handle("GET /snippet/create", sm.LoadAndSave(RequireAuthenticationMiddleware(handleSnippetCreateGet(tc, sm), sm))) + mux.Handle("POST /snippet/create", sm.LoadAndSave(RequireAuthenticationMiddleware(handleSnippetCreatePost(logger, tc, fd, sm, snippetService), sm))) // mux.Handle("/something", handleSomething(logger, config)) // mux.Handle("/healthz", handleHealthzPlease(logger)) // mux.Handle("/", http.NotFoundHandler()) @@ -35,7 +35,9 @@ func addRoutes(mux *http.ServeMux, mux.Handle("POST /user/signup", sm.LoadAndSave(handleUserSignupPost(logger, tc, fd, sm, userService))) mux.Handle("GET /user/login", sm.LoadAndSave(handleUserLoginGet(tc, sm))) mux.Handle("POST /user/login", sm.LoadAndSave(handleUserLoginPost(logger, tc, sm, fd, userService))) - mux.Handle("POST /user/logout", sm.LoadAndSave(handleUserLogoutPost(logger, sm))) + + // Requires auth + mux.Handle("POST /user/logout", sm.LoadAndSave(RequireAuthenticationMiddleware(handleUserLogoutPost(logger, sm), sm))) return mux } diff --git a/internal/server/templates.go b/internal/server/templates.go index 52ab009..a9223dc 100644 --- a/internal/server/templates.go +++ b/internal/server/templates.go @@ -18,16 +18,19 @@ import ( // At the moment it only contains one field, but we'll add more // to it as the build progresses. type templateData struct { - CurrentYear int - Snippet model.Snippet - Snippets []model.Snippet - Form any - Flash string + CurrentYear int + Snippet model.Snippet + Snippets []model.Snippet + Form any + Flash string + IsAuthenticated bool } // newTemplateData is useful to inject default values. Example CSRF tokens for forms. func newTemplateData(r *http.Request, sm *scs.SessionManager) templateData { - return templateData{CurrentYear: time.Now().Year(), Flash: sm.PopString(r.Context(), "flash")} + return templateData{CurrentYear: time.Now().Year(), + Flash: sm.PopString(r.Context(), "flash"), + IsAuthenticated: isAuthenticated(r, sm)} } // TEMPLATE FUNCTIONS diff --git a/ui/html/partials/nav.go.tmpl b/ui/html/partials/nav.go.tmpl index 4c94406..ed22ab2 100644 --- a/ui/html/partials/nav.go.tmpl +++ b/ui/html/partials/nav.go.tmpl @@ -2,14 +2,19 @@ {{- end}} \ No newline at end of file