Use context for authentication information

main
Drew Bednar 2 months ago
parent 00ec29d696
commit c5aab74159

@ -84,5 +84,10 @@ func (u *UserService) Authenticate(email, password string) (int, error) {
}
func (u *UserService) Exists(id int) (bool, error) {
return false, nil
var exists bool
stmt := "SELECT EXISTS(SELECT true FROM users WHERE id = ?)"
err := u.DB.QueryRow(stmt, id).Scan(&exists)
return exists, err
}

@ -0,0 +1,5 @@
package server
type contextKey string
const isAuthenticatedKey = contextKey("isAuthenticated")

@ -6,8 +6,6 @@ import (
"runtime/debug"
"strconv"
"strings"
"github.com/alexedwards/scs/v2"
)
// serverError helper writes a log entry at Error level (including the request
@ -66,6 +64,11 @@ func logAuthSuccess(logger *slog.Logger, r *http.Request, email string, userId i
}
// is Authenticated returns true if an authenticated user ID has been set in the session
func isAuthenticated(r *http.Request, sm *scs.SessionManager) bool {
return sm.Exists(r.Context(), "authenticatedUserID")
func isAuthenticated(r *http.Request) bool {
// return sm.Exists(r.Context(), "authenticatedUserID")
isAuthenticated, ok := r.Context().Value(isAuthenticatedKey).(bool)
if !ok {
return false
}
return isAuthenticated
}

@ -1,10 +1,12 @@
package server
import (
"context"
"fmt"
"log/slog"
"net/http"
"git.runcible.io/learning/ratchet/internal/model"
"github.com/alexedwards/scs/v2"
"github.com/justinas/nosurf"
)
@ -110,7 +112,7 @@ func RequireAuthenticationMiddleware(next http.Handler, sm *scs.SessionManager)
// If the user is not authenticated, redirect them to the login page and
// return from the middleware chain so that no subsequent handlers in
// the chain are executed.
if !isAuthenticated(r, sm) {
if !isAuthenticated(r) {
http.Redirect(w, r, "/user/login", http.StatusSeeOther)
return
}
@ -136,3 +138,32 @@ func NoSurfMiddleware(next http.Handler) http.Handler {
return csrfHandler
}
func AuthenticateMiddleware(next http.Handler, sm *scs.SessionManager, userService *model.UserService) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := sm.GetInt(r.Context(), "authenticatedUserID")
if id == 0 {
// no authenticated user
next.ServeHTTP(w, r)
return
}
exists, err := userService.Exists(id)
if err != nil {
serverError(w, r, err)
return
}
// If a matching user is found, we know that the request is
// coming from an authenticated user who exists in our database. We
// create a new copy of the request (with an isAuthenticatedContextKey
// value of true in the request context) and assign it to r.
if exists {
ctx := context.WithValue(r.Context(), isAuthenticatedKey, true)
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
})
}

@ -23,21 +23,21 @@ func addRoutes(mux *http.ServeMux,
// resulting in this route requiring an exact match on "/" only
// You can only include one HTTP method in a route pattern if you choose
// GET will match GET & HEAD http request methods
mux.Handle("GET /{$}", sm.LoadAndSave(NoSurfMiddleware(handleHome(logger, tc, sm, snippetService)))) // might be time to swith to github.com/justinas/alice dynamic chain
mux.Handle("GET /snippet/view/{id}", sm.LoadAndSave(NoSurfMiddleware(handleSnippetView(logger, tc, sm, snippetService))))
mux.Handle("GET /snippet/create", sm.LoadAndSave(NoSurfMiddleware(RequireAuthenticationMiddleware(handleSnippetCreateGet(tc, sm), sm))))
mux.Handle("POST /snippet/create", sm.LoadAndSave(NoSurfMiddleware(RequireAuthenticationMiddleware(handleSnippetCreatePost(logger, tc, fd, sm, snippetService), sm))))
mux.Handle("GET /{$}", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleHome(logger, tc, sm, snippetService), sm, userService)))) // might be time to swith to github.com/justinas/alice dynamic chain
mux.Handle("GET /snippet/view/{id}", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleSnippetView(logger, tc, sm, snippetService), sm, userService))))
mux.Handle("GET /snippet/create", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(RequireAuthenticationMiddleware(handleSnippetCreateGet(tc, sm), sm), sm, userService))))
mux.Handle("POST /snippet/create", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(RequireAuthenticationMiddleware(handleSnippetCreatePost(logger, tc, fd, sm, snippetService), sm), sm, userService))))
// mux.Handle("/something", handleSomething(logger, config))
// mux.Handle("/healthz", handleHealthzPlease(logger))
// mux.Handle("/", http.NotFoundHandler())
mux.Handle("GET /user/signup", sm.LoadAndSave(NoSurfMiddleware(handleUserSignupGet(tc, sm))))
mux.Handle("POST /user/signup", sm.LoadAndSave(NoSurfMiddleware(handleUserSignupPost(logger, tc, fd, sm, userService))))
mux.Handle("GET /user/login", sm.LoadAndSave(NoSurfMiddleware(handleUserLoginGet(tc, sm))))
mux.Handle("POST /user/login", sm.LoadAndSave(NoSurfMiddleware(handleUserLoginPost(logger, tc, sm, fd, userService))))
mux.Handle("GET /user/signup", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleUserSignupGet(tc, sm), sm, userService))))
mux.Handle("POST /user/signup", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleUserSignupPost(logger, tc, fd, sm, userService), sm, userService))))
mux.Handle("GET /user/login", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleUserLoginGet(tc, sm), sm, userService))))
mux.Handle("POST /user/login", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(handleUserLoginPost(logger, tc, sm, fd, userService), sm, userService))))
// Requires auth
mux.Handle("POST /user/logout", sm.LoadAndSave(NoSurfMiddleware(RequireAuthenticationMiddleware(handleUserLogoutPost(logger, sm), sm))))
mux.Handle("POST /user/logout", sm.LoadAndSave(NoSurfMiddleware(AuthenticateMiddleware(RequireAuthenticationMiddleware(handleUserLogoutPost(logger, sm), sm), sm, userService))))
return mux
}

@ -32,7 +32,7 @@ type templateData struct {
func newTemplateData(r *http.Request, sm *scs.SessionManager) templateData {
return templateData{CurrentYear: time.Now().Year(),
Flash: sm.PopString(r.Context(), "flash"),
IsAuthenticated: isAuthenticated(r, sm),
IsAuthenticated: isAuthenticated(r),
// added to every page because the form for logout can appear on every page
CSRFToken: nosurf.Token(r),
}

Loading…
Cancel
Save