diff --git a/internal/server/handler_test.go b/internal/server/handler_test.go new file mode 100644 index 0000000..fcd9887 --- /dev/null +++ b/internal/server/handler_test.go @@ -0,0 +1,49 @@ +package server + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + + "git.runcible.io/learning/ratchet/internal/assert" +) + +func TestPing(t *testing.T) { + // This is essentially an implementation of http.ResponseWriter + // which records the response status code, headers and body instead + // of actually writing them to a HTTP connection. + rr := httptest.NewRecorder() + + // Initialize a dummy request + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + // When called, t.Fatal() will mark the test as failed, log the error, + // and then completely stop execution of the current test + // (or sub-test). + + // Typically you should call t.Fatal() in situations where it doesn’t + // make sense to continue the current test — such as an error during + // a setup step, or where an unexpected error from a Go standard + // library function means you can’t proceed with the test. + t.Fatal(err) + } + ping := PingHandler() + ping.ServeHTTP(rr, r) + + resp := rr.Result() + + assert.Equal(t, resp.StatusCode, http.StatusOK) + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + body = bytes.TrimSpace(body) + + assert.Equal(t, string(body), "OK") + +} diff --git a/internal/server/handlers.go b/internal/server/handlers.go index 9ee098e..1bbd8e0 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -444,3 +444,9 @@ func handleUserLogoutPost(logger *slog.Logger, sm *scs.SessionManager) http.Hand http.Redirect(w, r, "/", http.StatusSeeOther) }) } + +func PingHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + }) +} diff --git a/internal/server/middleware_test.go b/internal/server/middleware_test.go new file mode 100644 index 0000000..9f51b84 --- /dev/null +++ b/internal/server/middleware_test.go @@ -0,0 +1,77 @@ +package server + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + + "git.runcible.io/learning/ratchet/internal/assert" +) + +func TestCommonHeadersMiddleware(t *testing.T) { + rr := httptest.NewRecorder() + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + + // mock http.Handler + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + }) + + // Pass the mock HTTP handler to our commonHeaders middleware. Because + // commonHeaders *returns* a http.Handler we can call its ServeHTTP() + // method, passing in the http.ResponseRecorder and dummy http.Request to + // execute it. + CommonHeaderMiddleware(next).ServeHTTP(rr, r) + + resp := rr.Result() + + // Check that the middleware has correctly set the Content-Security-Policy + // header on the response. + expectedValue := "default-src 'self'; style-src 'self' fonts.googleapis.com; font-src fonts.gstatic.com" + assert.Equal(t, resp.Header.Get("Content-Security-Policy"), expectedValue) + + // Check that the middleware has correctly set the Referrer-Policy + // header on the response. + expectedValue = "origin-when-cross-origin" + assert.Equal(t, resp.Header.Get("Referrer-Policy"), expectedValue) + + // Check that the middleware has correctly set the X-Content-Type-Options + // header on the response. + expectedValue = "nosniff" + assert.Equal(t, resp.Header.Get("X-Content-Type-Options"), expectedValue) + + // Check that the middleware has correctly set the X-Frame-Options header + // on the response. + expectedValue = "deny" + assert.Equal(t, resp.Header.Get("X-Frame-Options"), expectedValue) + + // Check that the middleware has correctly set the X-XSS-Protection header + // on the response + expectedValue = "0" + assert.Equal(t, resp.Header.Get("X-XSS-Protection"), expectedValue) + + // Check that the middleware has correctly set the Server header on the + // response. + expectedValue = "Go" + assert.Equal(t, resp.Header.Get("Server"), expectedValue) + + // Check that the middleware has correctly called the next handler in line + // and the response status code and body are as expected. + assert.Equal(t, resp.StatusCode, http.StatusOK) + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + body = bytes.TrimSpace(body) + + assert.Equal(t, string(body), "OK") +}