From c5aab74159cb7b727fbffbf0d4ddcb6bceb50a18 Mon Sep 17 00:00:00 2001 From: Drew Bednar Date: Sat, 15 Feb 2025 09:36:52 -0500 Subject: [PATCH] Use context for authentication information --- internal/model/user.go | 7 ++++++- internal/server/context.go | 5 +++++ internal/server/helpers.go | 11 +++++++---- internal/server/middleware.go | 33 ++++++++++++++++++++++++++++++++- internal/server/routes.go | 18 +++++++++--------- internal/server/templates.go | 2 +- 6 files changed, 60 insertions(+), 16 deletions(-) create mode 100644 internal/server/context.go diff --git a/internal/model/user.go b/internal/model/user.go index 5d97159..821fb29 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -84,5 +84,10 @@ func (u *UserService) Authenticate(email, password string) (int, error) { } func (u *UserService) Exists(id int) (bool, error) { - return false, nil + var exists bool + + stmt := "SELECT EXISTS(SELECT true FROM users WHERE id = ?)" + + err := u.DB.QueryRow(stmt, id).Scan(&exists) + return exists, err } diff --git a/internal/server/context.go b/internal/server/context.go new file mode 100644 index 0000000..ba06ba4 --- /dev/null +++ b/internal/server/context.go @@ -0,0 +1,5 @@ +package server + +type contextKey string + +const isAuthenticatedKey = contextKey("isAuthenticated") diff --git a/internal/server/helpers.go b/internal/server/helpers.go index c81f883..21835d2 100644 --- a/internal/server/helpers.go +++ b/internal/server/helpers.go @@ -6,8 +6,6 @@ import ( "runtime/debug" "strconv" "strings" - - "github.com/alexedwards/scs/v2" ) // serverError helper writes a log entry at Error level (including the request @@ -66,6 +64,11 @@ func logAuthSuccess(logger *slog.Logger, r *http.Request, email string, userId i } // is Authenticated returns true if an authenticated user ID has been set in the session -func isAuthenticated(r *http.Request, sm *scs.SessionManager) bool { - return sm.Exists(r.Context(), "authenticatedUserID") +func isAuthenticated(r *http.Request) bool { + // return sm.Exists(r.Context(), "authenticatedUserID") + isAuthenticated, ok := r.Context().Value(isAuthenticatedKey).(bool) + if !ok { + return false + } + return isAuthenticated } diff --git a/internal/server/middleware.go b/internal/server/middleware.go index 57c2c2d..e98a9f6 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -1,10 +1,12 @@ package server import ( + "context" "fmt" "log/slog" "net/http" + "git.runcible.io/learning/ratchet/internal/model" "github.com/alexedwards/scs/v2" "github.com/justinas/nosurf" ) @@ -110,7 +112,7 @@ func RequireAuthenticationMiddleware(next http.Handler, sm *scs.SessionManager) // If the user is not authenticated, redirect them to the login page and // return from the middleware chain so that no subsequent handlers in // the chain are executed. - if !isAuthenticated(r, sm) { + if !isAuthenticated(r) { http.Redirect(w, r, "/user/login", http.StatusSeeOther) return } @@ -136,3 +138,32 @@ func NoSurfMiddleware(next http.Handler) http.Handler { return csrfHandler } + +func AuthenticateMiddleware(next http.Handler, sm *scs.SessionManager, userService *model.UserService) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + id := sm.GetInt(r.Context(), "authenticatedUserID") + if id == 0 { + // no authenticated user + next.ServeHTTP(w, r) + return + } + + exists, err := userService.Exists(id) + if err != nil { + serverError(w, r, err) + return + } + // If a matching user is found, we know that the request is + // coming from an authenticated user who exists in our database. We + // create a new copy of the request (with an isAuthenticatedContextKey + // value of true in the request context) and assign it to r. + if exists { + ctx := context.WithValue(r.Context(), isAuthenticatedKey, true) + r = r.WithContext(ctx) + } + + next.ServeHTTP(w, r) + }) + +} diff --git a/internal/server/routes.go b/internal/server/routes.go index f25e6dc..9b011ab 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -23,21 +23,21 @@ func addRoutes(mux *http.ServeMux, // resulting in this route requiring an exact match on "/" only // You can only include one HTTP method in a route pattern if you choose // GET will match GET & HEAD http request methods - mux.Handle("GET /{$}", sm.LoadAndSave(NoSurfMiddleware(handleHome(logger, tc, sm, snippetService)))) // might be time to swith to github.com/justinas/alice dynamic chain - mux.Handle("GET /snippet/view/{id}", sm.LoadAndSave(NoSurfMiddleware(handleSnippetView(logger, tc, sm, snippetService)))) - mux.Handle("GET /snippet/create", sm.LoadAndSave(NoSurfMiddleware(RequireAuthenticationMiddleware(handleSnippetCreateGet(tc, sm), sm)))) - mux.Handle("POST /snippet/create", sm.LoadAndSave(NoSurfMiddleware(RequireAuthenticationMiddleware(handleSnippetCreatePost(logger, tc, fd, sm, snippetService), sm)))) + mux.Handle("GET /{$}", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleHome(logger, tc, sm, snippetService), sm, userService)))) // might be time to swith to github.com/justinas/alice dynamic chain + mux.Handle("GET /snippet/view/{id}", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleSnippetView(logger, tc, sm, snippetService), sm, userService)))) + mux.Handle("GET /snippet/create", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(RequireAuthenticationMiddleware(handleSnippetCreateGet(tc, sm), sm), sm, userService)))) + mux.Handle("POST /snippet/create", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(RequireAuthenticationMiddleware(handleSnippetCreatePost(logger, tc, fd, sm, snippetService), sm), sm, userService)))) // mux.Handle("/something", handleSomething(logger, config)) // mux.Handle("/healthz", handleHealthzPlease(logger)) // mux.Handle("/", http.NotFoundHandler()) - mux.Handle("GET /user/signup", sm.LoadAndSave(NoSurfMiddleware(handleUserSignupGet(tc, sm)))) - mux.Handle("POST /user/signup", sm.LoadAndSave(NoSurfMiddleware(handleUserSignupPost(logger, tc, fd, sm, userService)))) - mux.Handle("GET /user/login", sm.LoadAndSave(NoSurfMiddleware(handleUserLoginGet(tc, sm)))) - mux.Handle("POST /user/login", sm.LoadAndSave(NoSurfMiddleware(handleUserLoginPost(logger, tc, sm, fd, userService)))) + mux.Handle("GET /user/signup", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleUserSignupGet(tc, sm), sm, userService)))) + mux.Handle("POST /user/signup", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleUserSignupPost(logger, tc, fd, sm, userService), sm, userService)))) + mux.Handle("GET /user/login", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleUserLoginGet(tc, sm), sm, userService)))) + mux.Handle("POST /user/login", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleUserLoginPost(logger, tc, sm, fd, userService), sm, userService)))) // Requires auth - mux.Handle("POST /user/logout", sm.LoadAndSave(NoSurfMiddleware(RequireAuthenticationMiddleware(handleUserLogoutPost(logger, sm), sm)))) + mux.Handle("POST /user/logout", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(RequireAuthenticationMiddleware(handleUserLogoutPost(logger, sm), sm), sm, userService)))) return mux } diff --git a/internal/server/templates.go b/internal/server/templates.go index 8650812..23fb945 100644 --- a/internal/server/templates.go +++ b/internal/server/templates.go @@ -32,7 +32,7 @@ type templateData struct { func newTemplateData(r *http.Request, sm *scs.SessionManager) templateData { return templateData{CurrentYear: time.Now().Year(), Flash: sm.PopString(r.Context(), "flash"), - IsAuthenticated: isAuthenticated(r, sm), + IsAuthenticated: isAuthenticated(r), // added to every page because the form for logout can appear on every page CSRFToken: nosurf.Token(r), }