Compare commits
	
		
			37 Commits 
		
	
	
		
			drew/lets-
			...
			main
		
	
	@ -0,0 +1,15 @@
 | 
				
			|||||||
 | 
					kind: pipeline
 | 
				
			||||||
 | 
					type: docker
 | 
				
			||||||
 | 
					name: CI Test Pipeline
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					steps:
 | 
				
			||||||
 | 
					- name: Unit Tests
 | 
				
			||||||
 | 
					  image: golang:1.23
 | 
				
			||||||
 | 
					  privileged: true
 | 
				
			||||||
 | 
					  commands:
 | 
				
			||||||
 | 
					    - go test -v ./...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					trigger:
 | 
				
			||||||
 | 
					  event:
 | 
				
			||||||
 | 
					    - pull_request
 | 
				
			||||||
 | 
					    - push
 | 
				
			||||||
@ -1 +1,113 @@
 | 
				
			|||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.runcible.io/learning/ratchet/internal/assert"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// func testingLogger() *slog.Logger {
 | 
				
			||||||
 | 
					// 	return slog.New(slog.NewTextHandler(io.Discard, nil))
 | 
				
			||||||
 | 
					// }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WITHOUT testutils_test.go helpers
 | 
				
			||||||
 | 
					// func TestPingIntegration(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 	rs := server.NewRatchetApp(testingLogger(), nil, nil, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 	// We then use the httptest.NewTLSServer() function to create a new test
 | 
				
			||||||
 | 
					// 	// server, passing in the value returned by our app.routes() method as the
 | 
				
			||||||
 | 
					// 	// handler for the server. This starts up a HTTPS server which listens on a
 | 
				
			||||||
 | 
					// 	// randomly-chosen port of your local machine for the duration of the test.
 | 
				
			||||||
 | 
					// 	// Notice that we defer a call to ts.Close() so that the server is shutdown
 | 
				
			||||||
 | 
					// 	// when the test finishes.
 | 
				
			||||||
 | 
					// 	ts := httptest.NewTLSServer(rs)
 | 
				
			||||||
 | 
					// 	defer ts.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 	resp, err := ts.Client().Get(ts.URL + "/ping")
 | 
				
			||||||
 | 
					// 	if err != nil {
 | 
				
			||||||
 | 
					// 		t.Fatal(err)
 | 
				
			||||||
 | 
					// 	}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 	assert.Equal(t, resp.StatusCode, http.StatusOK)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 	defer resp.Body.Close()
 | 
				
			||||||
 | 
					// 	body, err := io.ReadAll(resp.Body)
 | 
				
			||||||
 | 
					// 	if err != nil {
 | 
				
			||||||
 | 
					// 		t.Fatal(err)
 | 
				
			||||||
 | 
					// 	}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 	body = bytes.TrimSpace(body)
 | 
				
			||||||
 | 
					// 	assert.Equal(t, string(body), "OK")
 | 
				
			||||||
 | 
					// }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestPingIntegration(t *testing.T) {
 | 
				
			||||||
 | 
						// Tests marked using t.Parallel() will be run in parallel with — and only with — other parallel tests.
 | 
				
			||||||
 | 
						// By default, the maximum number of tests that will be run simultaneously is the current value of
 | 
				
			||||||
 | 
						// GOMAXPROCS. You can override this by setting a specific value via the -parallel flag.
 | 
				
			||||||
 | 
						t.Parallel()
 | 
				
			||||||
 | 
						app := newTestApplication(t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ts := newTestServer(t, app.Routes())
 | 
				
			||||||
 | 
						defer ts.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						code, _, body := ts.get(t, "/ping")
 | 
				
			||||||
 | 
						assert.Equal(t, code, http.StatusOK)
 | 
				
			||||||
 | 
						assert.Equal(t, body, "OK")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSnippetView(t *testing.T) {
 | 
				
			||||||
 | 
						t.Parallel()
 | 
				
			||||||
 | 
						app := newTestApplication(t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ts := newTestServer(t, app.Routes())
 | 
				
			||||||
 | 
						defer ts.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name     string
 | 
				
			||||||
 | 
							urlPath  string
 | 
				
			||||||
 | 
							wantCode int
 | 
				
			||||||
 | 
							wantBody string
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:     "Valid ID",
 | 
				
			||||||
 | 
								urlPath:  "/snippet/view/1",
 | 
				
			||||||
 | 
								wantCode: 200,
 | 
				
			||||||
 | 
								wantBody: "Hello golang mocking",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:     "Nonexistent ID",
 | 
				
			||||||
 | 
								urlPath:  "/snippet/view/2",
 | 
				
			||||||
 | 
								wantCode: 404,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:     "Negative ID",
 | 
				
			||||||
 | 
								urlPath:  "/snippet/view/-1",
 | 
				
			||||||
 | 
								wantCode: 404,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:     "Decimal ID",
 | 
				
			||||||
 | 
								urlPath:  "/snippet/view/1.23",
 | 
				
			||||||
 | 
								wantCode: 404,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:     "string ID",
 | 
				
			||||||
 | 
								urlPath:  "/snippet/view/foo",
 | 
				
			||||||
 | 
								wantCode: 404,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:     "emptry ID",
 | 
				
			||||||
 | 
								urlPath:  "/snippet/view/",
 | 
				
			||||||
 | 
								wantCode: 404,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, test := range tests {
 | 
				
			||||||
 | 
							t.Run(test.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								code, _, body := ts.get(t, test.urlPath)
 | 
				
			||||||
 | 
								assert.Equal(t, code, test.wantCode)
 | 
				
			||||||
 | 
								assert.StringContains(t, body, test.wantBody)
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,74 @@
 | 
				
			|||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"log/slog"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/cookiejar"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.runcible.io/learning/ratchet/internal/model/mock"
 | 
				
			||||||
 | 
						"git.runcible.io/learning/ratchet/internal/server"
 | 
				
			||||||
 | 
						"github.com/alexedwards/scs/v2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Create a newTestApplication helper which returns an instance of our
 | 
				
			||||||
 | 
					// application struct containing mocked dependencies.
 | 
				
			||||||
 | 
					func newTestApplication(t *testing.T) *server.RatchetApp {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						//tc, err := server.InitTemplateCache()
 | 
				
			||||||
 | 
						tc, err := server.InitFSTemplateCache()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sessionManager := scs.New()
 | 
				
			||||||
 | 
						sessionManager.Lifetime = 12 * time.Hour
 | 
				
			||||||
 | 
						sessionManager.Cookie.Secure = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rs := server.NewRatchetApp(slog.New(slog.NewTextHandler(io.Discard, nil)), tc, &mock.SnippetService{}, &mock.UserService{}, sessionManager)
 | 
				
			||||||
 | 
						return rs
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// create out own test server with additional receiver functions for ease of testing
 | 
				
			||||||
 | 
					type testServer struct {
 | 
				
			||||||
 | 
						*httptest.Server
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newTestServer(t *testing.T, h http.Handler) *testServer {
 | 
				
			||||||
 | 
						ts := httptest.NewTLSServer(h)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						jar, err := cookiejar.New(nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Any response cookies will be stored and sent on subsequent requests
 | 
				
			||||||
 | 
						ts.Client().Jar = jar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Disable redirect-following for test server client by setting custom
 | 
				
			||||||
 | 
						// CheckRedirect function. Called whenever 3xx response. By returning a
 | 
				
			||||||
 | 
						// http.ErrUseLastResponse error it forces the client to immediately return
 | 
				
			||||||
 | 
						// the received response
 | 
				
			||||||
 | 
						ts.Client().CheckRedirect = func(req *http.Request, via []*http.Request) error {
 | 
				
			||||||
 | 
							return http.ErrUseLastResponse
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &testServer{ts}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (ts *testServer) get(t *testing.T, urlPath string) (int, http.Header, string) {
 | 
				
			||||||
 | 
						resp, err := ts.Client().Get(ts.URL + urlPath)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer resp.Body.Close()
 | 
				
			||||||
 | 
						body, err := io.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return resp.StatusCode, resp.Header, string(body)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -1,2 +1,37 @@
 | 
				
			|||||||
 | 
					github.com/alexedwards/scs/sqlite3store v0.0.0-20240316134038-7e11d57e8885 h1:+DCxWg/ojncqS+TGAuRUoV7OfG/S4doh0pcpAwEcow0=
 | 
				
			||||||
 | 
					github.com/alexedwards/scs/sqlite3store v0.0.0-20240316134038-7e11d57e8885/go.mod h1:Iyk7S76cxGaiEX/mSYmTZzYehp4KfyylcLaV3OnToss=
 | 
				
			||||||
 | 
					github.com/alexedwards/scs/v2 v2.8.0 h1:h31yUYoycPuL0zt14c0gd+oqxfRwIj6SOjHdKRZxhEw=
 | 
				
			||||||
 | 
					github.com/alexedwards/scs/v2 v2.8.0/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8=
 | 
				
			||||||
 | 
					github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 | 
				
			||||||
 | 
					github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 | 
				
			||||||
 | 
					github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 | 
				
			||||||
 | 
					github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
 | 
				
			||||||
 | 
					github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
 | 
				
			||||||
 | 
					github.com/go-playground/form/v4 v4.2.1 h1:HjdRDKO0fftVMU5epjPW2SOREcZ6/wLUzEobqUGJuPw=
 | 
				
			||||||
 | 
					github.com/go-playground/form/v4 v4.2.1/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U=
 | 
				
			||||||
 | 
					github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4CV3uAuvHGC+Y=
 | 
				
			||||||
 | 
					github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks=
 | 
				
			||||||
 | 
					github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
 | 
				
			||||||
 | 
					github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
 | 
				
			||||||
 | 
					github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
 | 
				
			||||||
 | 
					github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
 | 
				
			||||||
 | 
					github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
 | 
				
			||||||
 | 
					github.com/justinas/nosurf v1.1.1 h1:92Aw44hjSK4MxJeMSyDa7jwuI9GR2J/JCQiaKvXXSlk=
 | 
				
			||||||
 | 
					github.com/justinas/nosurf v1.1.1/go.mod h1:ALpWdSbuNGy2lZWtyXdjkYv4edL23oSEgfBT1gPJ5BQ=
 | 
				
			||||||
 | 
					github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
 | 
				
			||||||
 | 
					github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
 | 
				
			||||||
 | 
					github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
 | 
				
			||||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
 | 
					github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
 | 
				
			||||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
 | 
					github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
 | 
				
			||||||
 | 
					github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 | 
				
			||||||
 | 
					github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 | 
				
			||||||
 | 
					github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 | 
				
			||||||
 | 
					github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
 | 
				
			||||||
 | 
					github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
 | 
				
			||||||
 | 
					github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
 | 
				
			||||||
 | 
					go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
 | 
				
			||||||
 | 
					go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
 | 
				
			||||||
 | 
					golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
 | 
				
			||||||
 | 
					golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
 | 
				
			||||||
 | 
					gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 | 
				
			||||||
 | 
					gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,9 @@
 | 
				
			|||||||
 | 
					INSERT INTO users VALUES(
 | 
				
			||||||
 | 
					    1337,
 | 
				
			||||||
 | 
					    'tester',
 | 
				
			||||||
 | 
					    'tester@example.com',
 | 
				
			||||||
 | 
					    /* thisisinsecure */
 | 
				
			||||||
 | 
					    '$2a$12$M51w5lWkveAOhwoanoCxO.hJe3s1m8qJuCzbzdETt0SThjpq4BPRq',
 | 
				
			||||||
 | 
					    '2025-02-25 18:58:44',
 | 
				
			||||||
 | 
					    '2025-02-25 18:58:44'
 | 
				
			||||||
 | 
					);
 | 
				
			||||||
@ -0,0 +1,52 @@
 | 
				
			|||||||
 | 
					package integration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.runcible.io/learning/ratchet/migrations"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const testerPasswd = "thisisinsecure"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newTestDB(t *testing.T) *sql.DB {
 | 
				
			||||||
 | 
						dbFile, err := os.CreateTemp("", "ratchet-*.db")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						db, err := sql.Open("sqlite3", dbFile.Name())
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							db.Close()
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = migrations.Migrate(db)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							db.Close()
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						script, err := os.ReadFile("./testdata/seed.sql")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							db.Close()
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = db.Exec(string(script))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							db.Close()
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Cleanup(func() {
 | 
				
			||||||
 | 
							db.Close()
 | 
				
			||||||
 | 
							err := os.Remove(dbFile.Name())
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return db
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,51 @@
 | 
				
			|||||||
 | 
					package integration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.runcible.io/learning/ratchet/internal/assert"
 | 
				
			||||||
 | 
						"git.runcible.io/learning/ratchet/internal/model"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestUserModelExists(t *testing.T) {
 | 
				
			||||||
 | 
						// Skip the test if the "-short" flag is provided when running the test.
 | 
				
			||||||
 | 
						if testing.Short() {
 | 
				
			||||||
 | 
							t.Skip("models: skipping model integration test")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name   string
 | 
				
			||||||
 | 
							userID int
 | 
				
			||||||
 | 
							want   bool
 | 
				
			||||||
 | 
						}{{
 | 
				
			||||||
 | 
							name:   "Valid ID",
 | 
				
			||||||
 | 
							userID: 1337,
 | 
				
			||||||
 | 
							want:   true,
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:   "Zero ID",
 | 
				
			||||||
 | 
								userID: 0,
 | 
				
			||||||
 | 
								want:   false,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:   "Zero ID",
 | 
				
			||||||
 | 
								userID: 2,
 | 
				
			||||||
 | 
								want:   false,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, test := range tests {
 | 
				
			||||||
 | 
							t.Run(test.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								db := newTestDB(t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								userService := model.UserService{db}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								exists, err := userService.Exists(test.userID)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								assert.Equal(t, exists, test.want)
 | 
				
			||||||
 | 
								assert.NilError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,35 @@
 | 
				
			|||||||
 | 
					package mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.runcible.io/learning/ratchet/internal/model"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var mockSnippet = model.Snippet{
 | 
				
			||||||
 | 
						ID:        1,
 | 
				
			||||||
 | 
						Title:     sql.NullString{String: "Hello golang mocking", Valid: true},
 | 
				
			||||||
 | 
						Content:   sql.NullString{String: "Hello golang mocking", Valid: true},
 | 
				
			||||||
 | 
						CreatedAt: time.Now(),
 | 
				
			||||||
 | 
						UpdatedAt: time.Now(),
 | 
				
			||||||
 | 
						ExpiresAt: time.Now(),
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type SnippetService struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *SnippetService) Insert(title, content string, expiresAt int) (int, error) {
 | 
				
			||||||
 | 
						return 2, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *SnippetService) Get(id int) (model.Snippet, error) {
 | 
				
			||||||
 | 
						if id == mockSnippet.ID {
 | 
				
			||||||
 | 
							return mockSnippet, nil
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							return model.Snippet{}, model.ErrNoRecord
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *SnippetService) Lastest() ([]model.Snippet, error) {
 | 
				
			||||||
 | 
						return []model.Snippet{mockSnippet}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,34 @@
 | 
				
			|||||||
 | 
					package mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "git.runcible.io/learning/ratchet/internal/model"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var MockEmail = "drew@fake.com"
 | 
				
			||||||
 | 
					var MockPassword = "thisisinsecure"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type UserService struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (u *UserService) Insert(name, email, password string) (int, error) {
 | 
				
			||||||
 | 
						switch email {
 | 
				
			||||||
 | 
						case "dupe@example.com":
 | 
				
			||||||
 | 
							return 0, model.ErrDuplicateEmail
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return 1, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (u *UserService) Authenticate(email, password string) (int, error) {
 | 
				
			||||||
 | 
						if email == MockEmail && password == MockPassword {
 | 
				
			||||||
 | 
							return 1, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return 0, model.ErrInvalidCredentials
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (u *UserService) Exists(id int) (bool, error) {
 | 
				
			||||||
 | 
						switch id {
 | 
				
			||||||
 | 
						case 1:
 | 
				
			||||||
 | 
							return true, nil
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return false, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -1,75 +1,99 @@
 | 
				
			|||||||
package model
 | 
					package model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"log/slog"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.runcible.io/learning/ratchet/internal/apperror"
 | 
						"github.com/mattn/go-sqlite3"
 | 
				
			||||||
 | 
						"golang.org/x/crypto/bcrypt"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type User struct {
 | 
					type User struct {
 | 
				
			||||||
	ID int `json:"id"`
 | 
						ID             int
 | 
				
			||||||
 | 
						Name           string
 | 
				
			||||||
	// User prefered name and email
 | 
						Email          string
 | 
				
			||||||
	Name  string `json:"name"`
 | 
						HashedPassword []byte
 | 
				
			||||||
	Email string `json:"email"`
 | 
						CreatedAt      time.Time
 | 
				
			||||||
 | 
						UpdatedAt      time.Time
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Randomly generated API key for use with the API
 | 
					type UserServiceInterface interface {
 | 
				
			||||||
	// "-" omits the key from serialization
 | 
						Insert(name, email, password string) (int, error)
 | 
				
			||||||
	APIKey string `json:"-"`
 | 
						Authenticate(email, password string) (int, error)
 | 
				
			||||||
 | 
						Exists(id int) (bool, error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Timestamps for user creatation and last update
 | 
					// TODD add logger to service
 | 
				
			||||||
	CreatedAt time.Time `json:"createdAt"`
 | 
					type UserService struct {
 | 
				
			||||||
	UpdatedAt time.Time `json:"updatedAt"`
 | 
						DB *sql.DB
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// List of associated Oauth authentication Objects
 | 
					func (u *UserService) Insert(name, email, password string) (int, error) {
 | 
				
			||||||
	// Not yet implemented
 | 
						hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), 12)
 | 
				
			||||||
	// Auths []*Auth `json:"auths"`
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return 0, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (u *User) Validate() error {
 | 
						stmt := `INSERT INTO users (name, email, hashed_password) 
 | 
				
			||||||
	if u.Name == "" {
 | 
						VALUES (?,?,?)`
 | 
				
			||||||
		return apperror.Errorf(apperror.EINVALID, "User name required.")
 | 
					
 | 
				
			||||||
 | 
						result, err := u.DB.Exec(stmt, name, email, string(hashedPassword))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							slog.Debug(fmt.Sprintf("Error encounters on insert: %s", err.Error()))
 | 
				
			||||||
 | 
							// This is a assertion that err is of type sqlite3.Error. If it is ok is true.
 | 
				
			||||||
 | 
							if serr, ok := err.(sqlite3.Error); ok {
 | 
				
			||||||
 | 
								slog.Debug("Error is sqlite3.Error type.")
 | 
				
			||||||
 | 
								if serr.ExtendedCode == sqlite3.ErrConstraintUnique {
 | 
				
			||||||
 | 
									slog.Debug("Error is a unique contraint violation.")
 | 
				
			||||||
 | 
									return 0, ErrDuplicateEmail
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
	return nil
 | 
							}
 | 
				
			||||||
 | 
							return 0, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// UserService represents a service for managing users.
 | 
						lastId, err := result.LastInsertId()
 | 
				
			||||||
type UserService interface {
 | 
						if err != nil {
 | 
				
			||||||
	// Retrieves a user by ID along with their associated auth objects
 | 
							slog.Debug("An error occured when retrieving insert result id.")
 | 
				
			||||||
	// Returns ENOTFOUND if user does not exist.
 | 
							return 0, err
 | 
				
			||||||
	FindUserByID(ctx context.Context, id int) (*User, error)
 | 
						}
 | 
				
			||||||
 | 
						slog.Debug(fmt.Sprintf("Inserted new user. User pk: %d", int(lastId)))
 | 
				
			||||||
	// Retrieves a list of users by filter. Also returns total count of matching users
 | 
						return int(lastId), nil
 | 
				
			||||||
	// which may differ from retruned results if filter.Limit is specified.
 | 
					}
 | 
				
			||||||
	FindUsers(ctc context.Context, filter UserFilter) ([]*User, int, error)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Creates a new use. This is only used for testing since users are typically
 | 
					func (u *UserService) Authenticate(email, password string) (int, error) {
 | 
				
			||||||
	// cretaed during the OAuth creation process in the AuthService.CreateAuth().
 | 
						var id int
 | 
				
			||||||
	CreateUser(ctx context.Context, user *User) error
 | 
						var hashedPassword []byte
 | 
				
			||||||
 | 
						stmt := `SELECT id, hashed_password FROM users WHERE email == ?`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := u.DB.QueryRow(stmt, email).Scan(&id, &hashedPassword)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							if errors.Is(err, sql.ErrNoRows) {
 | 
				
			||||||
 | 
								return 0, ErrInvalidCredentials
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return 0, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Updates a user object. Returns EUNAUTHORIZED if current user is not
 | 
						err = bcrypt.CompareHashAndPassword(hashedPassword, []byte(password))
 | 
				
			||||||
	// the user that is being updated. Returns ENOTFOUND if the user does not exist.
 | 
						if err != nil {
 | 
				
			||||||
	UpdateUser(ctx context.Context, id int, upd UserUpdate) (*User, error)
 | 
							if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
 | 
				
			||||||
 | 
								return 0, ErrInvalidCredentials
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return 0, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Permanently deletes a user and all owned application resources. Returns EUNAUTHORIZED
 | 
						return id, nil
 | 
				
			||||||
	// if current user is not the user being deleted. Returns ENOTFOUND if the user
 | 
					 | 
				
			||||||
	// does not exist.
 | 
					 | 
				
			||||||
	DeleteUser(ctx context.Context, id int) error
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// UserFilter respresents a filter passed to FindUsers().
 | 
					func (u *UserService) Exists(id int) (bool, error) {
 | 
				
			||||||
type UserFilter struct {
 | 
						var exists bool
 | 
				
			||||||
	ID     *int    `json:"id"`
 | 
					 | 
				
			||||||
	Email  *string `json:"email"`
 | 
					 | 
				
			||||||
	APIKey *string `json:"apiKey"`
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Restrict to subset of results
 | 
						stmt := "SELECT EXISTS(SELECT true FROM users WHERE id = ?)"
 | 
				
			||||||
	Offset int `json:"offset"`
 | 
					 | 
				
			||||||
	Limit  int `json:"limit"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
type UserUpdate struct {
 | 
						err := u.DB.QueryRow(stmt, id).Scan(&exists)
 | 
				
			||||||
	Name  *string `json:"name"`
 | 
						return exists, err
 | 
				
			||||||
	Email *string `json:"email"`
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,75 @@
 | 
				
			|||||||
 | 
					package model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.runcible.io/learning/ratchet/internal/apperror"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Userwtf struct {
 | 
				
			||||||
 | 
						ID int `json:"id"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// User prefered name and email
 | 
				
			||||||
 | 
						Name  string `json:"name"`
 | 
				
			||||||
 | 
						Email string `json:"email"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Randomly generated API key for use with the API
 | 
				
			||||||
 | 
						// "-" omits the key from serialization
 | 
				
			||||||
 | 
						APIKey string `json:"-"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Timestamps for user creatation and last update
 | 
				
			||||||
 | 
						CreatedAt time.Time `json:"createdAt"`
 | 
				
			||||||
 | 
						UpdatedAt time.Time `json:"updatedAt"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// List of associated Oauth authentication Objects
 | 
				
			||||||
 | 
						// Not yet implemented
 | 
				
			||||||
 | 
						// Auths []*Auth `json:"auths"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (u *Userwtf) Validate() error {
 | 
				
			||||||
 | 
						if u.Name == "" {
 | 
				
			||||||
 | 
							return apperror.Errorf(apperror.EINVALID, "User name required.")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// UserService represents a service for managing users.
 | 
				
			||||||
 | 
					type UserServicewtf interface {
 | 
				
			||||||
 | 
						// Retrieves a user by ID along with their associated auth objects
 | 
				
			||||||
 | 
						// Returns ENOTFOUND if user does not exist.
 | 
				
			||||||
 | 
						FindUserByID(ctx context.Context, id int) (*Userwtf, error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Retrieves a list of users by filter. Also returns total count of matching users
 | 
				
			||||||
 | 
						// which may differ from retruned results if filter.Limit is specified.
 | 
				
			||||||
 | 
						FindUsers(ctc context.Context, filter UserFilter) ([]*Userwtf, int, error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Creates a new use. This is only used for testing since users are typically
 | 
				
			||||||
 | 
						// cretaed during the OAuth creation process in the AuthService.CreateAuth().
 | 
				
			||||||
 | 
						CreateUser(ctx context.Context, user *Userwtf) error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Updates a user object. Returns EUNAUTHORIZED if current user is not
 | 
				
			||||||
 | 
						// the user that is being updated. Returns ENOTFOUND if the user does not exist.
 | 
				
			||||||
 | 
						UpdateUser(ctx context.Context, id int, upd UserUpdate) (*Userwtf, error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Permanently deletes a user and all owned application resources. Returns EUNAUTHORIZED
 | 
				
			||||||
 | 
						// if current user is not the user being deleted. Returns ENOTFOUND if the user
 | 
				
			||||||
 | 
						// does not exist.
 | 
				
			||||||
 | 
						DeleteUser(ctx context.Context, id int) error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// UserFilter respresents a filter passed to FindUsers().
 | 
				
			||||||
 | 
					type UserFilter struct {
 | 
				
			||||||
 | 
						ID     *int    `json:"id"`
 | 
				
			||||||
 | 
						Email  *string `json:"email"`
 | 
				
			||||||
 | 
						APIKey *string `json:"apiKey"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Restrict to subset of results
 | 
				
			||||||
 | 
						Offset int `json:"offset"`
 | 
				
			||||||
 | 
						Limit  int `json:"limit"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type UserUpdate struct {
 | 
				
			||||||
 | 
						Name  *string `json:"name"`
 | 
				
			||||||
 | 
						Email *string `json:"email"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					package server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type contextKey string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const isAuthenticatedKey = contextKey("isAuthenticated")
 | 
				
			||||||
@ -0,0 +1,66 @@
 | 
				
			|||||||
 | 
					// test
 | 
				
			||||||
 | 
					package server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.runcible.io/learning/ratchet/internal/validator"
 | 
				
			||||||
 | 
						"github.com/go-playground/form/v4"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Define a snippetCreateForm struct to represent the form data and validation
 | 
				
			||||||
 | 
					// errors for the form fields. Note that all the struct fields are deliberately
 | 
				
			||||||
 | 
					// exported (i.e. start with a capital letter). This is because struct fields
 | 
				
			||||||
 | 
					// must be exported in order to be read by the html/template package when
 | 
				
			||||||
 | 
					// rendering the template.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Remove the explicit FieldErrors struct field and instead embed the Validator
 | 
				
			||||||
 | 
					// struct. Embedding this means that our snippetCreateForm "inherits" all the
 | 
				
			||||||
 | 
					// fields and methods of our Validator struct (including the FieldErrors field).
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// MOVING TO go-playground/form
 | 
				
			||||||
 | 
					// Update our snippetCreateForm struct to include struct tags which tell the
 | 
				
			||||||
 | 
					// decoder how to map HTML form values into the different struct fields. So, for
 | 
				
			||||||
 | 
					// example, here we're telling the decoder to store the value from the HTML form
 | 
				
			||||||
 | 
					// input with the name "title" in the Title field. The struct tag `form:"-"`
 | 
				
			||||||
 | 
					// tells the decoder to completely ignore a field during decoding.
 | 
				
			||||||
 | 
					type snippetCreateForm struct {
 | 
				
			||||||
 | 
						Title               string `form:"title"`
 | 
				
			||||||
 | 
						Content             string `form:"content"`
 | 
				
			||||||
 | 
						Expires             int    `form:"expires"`
 | 
				
			||||||
 | 
						validator.Validator `form:"-"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type userSignupForm struct {
 | 
				
			||||||
 | 
						Name                string `form:"name"`
 | 
				
			||||||
 | 
						Email               string `form:"email"`
 | 
				
			||||||
 | 
						Password            string `form:"password"`
 | 
				
			||||||
 | 
						validator.Validator `form:"-"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type userLoginForm struct {
 | 
				
			||||||
 | 
						Email               string `form:"email"`
 | 
				
			||||||
 | 
						Password            string `form:"password"`
 | 
				
			||||||
 | 
						validator.Validator `form:"-"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func decodePostForm(r *http.Request, fd *form.Decoder, dst any) error {
 | 
				
			||||||
 | 
						err := r.ParseForm()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = fd.Decode(dst, r.PostForm)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							var invalidDecoderError *form.InvalidDecoderError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if errors.As(err, &invalidDecoderError) {
 | 
				
			||||||
 | 
								// if called in the handler, recovery middleware
 | 
				
			||||||
 | 
								// will log and send 500 response back
 | 
				
			||||||
 | 
								panic(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,77 @@
 | 
				
			|||||||
 | 
					package server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.runcible.io/learning/ratchet/internal/assert"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCommonHeadersMiddleware(t *testing.T) {
 | 
				
			||||||
 | 
						rr := httptest.NewRecorder()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						r, err := http.NewRequest(http.MethodGet, "/", nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// mock http.Handler
 | 
				
			||||||
 | 
						next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							w.Write([]byte("OK"))
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Pass the mock HTTP handler to our commonHeaders middleware. Because
 | 
				
			||||||
 | 
						// commonHeaders *returns* a http.Handler we can call its ServeHTTP()
 | 
				
			||||||
 | 
						// method, passing in the http.ResponseRecorder and dummy http.Request to
 | 
				
			||||||
 | 
						// execute it.
 | 
				
			||||||
 | 
						CommonHeaderMiddleware(next).ServeHTTP(rr, r)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp := rr.Result()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check that the middleware has correctly set the Content-Security-Policy
 | 
				
			||||||
 | 
						// header on the response.
 | 
				
			||||||
 | 
						expectedValue := "default-src 'self'; style-src 'self' fonts.googleapis.com; font-src fonts.gstatic.com"
 | 
				
			||||||
 | 
						assert.Equal(t, resp.Header.Get("Content-Security-Policy"), expectedValue)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check that the middleware has correctly set the Referrer-Policy
 | 
				
			||||||
 | 
						// header on the response.
 | 
				
			||||||
 | 
						expectedValue = "origin-when-cross-origin"
 | 
				
			||||||
 | 
						assert.Equal(t, resp.Header.Get("Referrer-Policy"), expectedValue)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check that the middleware has correctly set the X-Content-Type-Options
 | 
				
			||||||
 | 
						// header on the response.
 | 
				
			||||||
 | 
						expectedValue = "nosniff"
 | 
				
			||||||
 | 
						assert.Equal(t, resp.Header.Get("X-Content-Type-Options"), expectedValue)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check that the middleware has correctly set the X-Frame-Options header
 | 
				
			||||||
 | 
						// on the response.
 | 
				
			||||||
 | 
						expectedValue = "deny"
 | 
				
			||||||
 | 
						assert.Equal(t, resp.Header.Get("X-Frame-Options"), expectedValue)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check that the middleware has correctly set the X-XSS-Protection header
 | 
				
			||||||
 | 
						// on the response
 | 
				
			||||||
 | 
						expectedValue = "0"
 | 
				
			||||||
 | 
						assert.Equal(t, resp.Header.Get("X-XSS-Protection"), expectedValue)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check that the middleware has correctly set the Server header on the
 | 
				
			||||||
 | 
						// response.
 | 
				
			||||||
 | 
						expectedValue = "Go"
 | 
				
			||||||
 | 
						assert.Equal(t, resp.Header.Get("Server"), expectedValue)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check that the middleware has correctly called the next handler in line
 | 
				
			||||||
 | 
						// and the response status code and body are as expected.
 | 
				
			||||||
 | 
						assert.Equal(t, resp.StatusCode, http.StatusOK)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer resp.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, err := io.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						body = bytes.TrimSpace(body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.Equal(t, string(body), "OK")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -1,29 +1,55 @@
 | 
				
			|||||||
package server
 | 
					package server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"database/sql"
 | 
					 | 
				
			||||||
	"log/slog"
 | 
					 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.runcible.io/learning/ratchet/internal/model"
 | 
						"git.runcible.io/learning/ratchet/ui"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func addRoutes(mux *http.ServeMux,
 | 
					func addBaseMiddleware(app *RatchetApp, next http.Handler, requireAuth bool) http.Handler {
 | 
				
			||||||
	logger *slog.Logger,
 | 
						var h http.Handler
 | 
				
			||||||
	tc *TemplateCache,
 | 
						h = next
 | 
				
			||||||
	db *sql.DB,
 | 
						if requireAuth {
 | 
				
			||||||
	snippetService *model.SnippetService) http.Handler {
 | 
							h = RequireAuthenticationMiddleware(h, app.sessionManager)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h = AuthenticateMiddleware(h, app.sessionManager, app.userService)
 | 
				
			||||||
 | 
						h = NoSurfMiddleware(h)
 | 
				
			||||||
 | 
						h = app.sessionManager.LoadAndSave(h)
 | 
				
			||||||
 | 
						return h
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (a *RatchetApp) Routes() http.Handler {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO implement middleware that disables directory listings
 | 
				
			||||||
 | 
						// This line was superceded by using the embedded filesystem
 | 
				
			||||||
 | 
						// fileServer := http.FileServer(http.Dir("./ui/static/"))
 | 
				
			||||||
 | 
						router := http.NewServeMux()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Subtree pattern for static assets
 | 
				
			||||||
 | 
						// This line was superceded by using the embedded filesystem
 | 
				
			||||||
 | 
						// router.Handle("GET /static/", http.StripPrefix("/static/", fileServer))
 | 
				
			||||||
 | 
						router.Handle("GET /static/", CacheHeaders(http.FileServerFS(ui.Files)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						router.Handle("GET /ping", PingHandler())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// /{$} is used to prevent subtree path patterns from acting like a wildcard
 | 
						// /{$} is used to prevent subtree path patterns from acting like a wildcard
 | 
				
			||||||
	// resulting in this route requiring an exact match on "/" only
 | 
						// resulting in this route requiring an exact match on "/" only
 | 
				
			||||||
	// You can only include one HTTP method in a route pattern if you choose
 | 
						// You can only include one HTTP method in a route pattern if you choose
 | 
				
			||||||
	// GET will match GET & HEAD http request methods
 | 
						// GET will match GET & HEAD http request methods
 | 
				
			||||||
	mux.Handle("GET /{$}", handleHome(logger, tc, snippetService))
 | 
						router.Handle("GET /{$}", addBaseMiddleware(a, handleHome(a.logger, a.templateCache, a.sessionManager, a.snippetService), false)) // might be time to swith to github.com/justinas/alice dynamic chain
 | 
				
			||||||
	mux.Handle("GET /snippet/view/{id}", handleSnippetView(logger, tc, snippetService))
 | 
						router.Handle("GET /snippet/view/{id}", addBaseMiddleware(a, handleSnippetView(a.logger, a.templateCache, a.sessionManager, a.snippetService), false))
 | 
				
			||||||
	mux.Handle("GET /snippet/create", handleSnippetCreateGet())
 | 
						router.Handle("GET /snippet/create", addBaseMiddleware(a, handleSnippetCreateGet(a.templateCache, a.sessionManager), true))
 | 
				
			||||||
	mux.Handle("POST /snippet/create", handleSnippetCreatePost(logger, tc, snippetService))
 | 
						router.Handle("POST /snippet/create", addBaseMiddleware(a, handleSnippetCreatePost(a.logger, a.templateCache, a.formDecoder, a.sessionManager, a.snippetService), true))
 | 
				
			||||||
	// mux.Handle("/something", handleSomething(logger, config))
 | 
						// mux.Handle("/something", handleSomething(logger, config))
 | 
				
			||||||
	// mux.Handle("/healthz", handleHealthzPlease(logger))
 | 
						// mux.Handle("/healthz", handleHealthzPlease(logger))
 | 
				
			||||||
	// mux.Handle("/", http.NotFoundHandler())
 | 
						// mux.Handle("/", http.NotFoundHandler())
 | 
				
			||||||
	return mux
 | 
					
 | 
				
			||||||
 | 
						router.Handle("GET /user/signup", addBaseMiddleware(a, handleUserSignupGet(a.templateCache, a.sessionManager), false))
 | 
				
			||||||
 | 
						router.Handle("POST /user/signup", addBaseMiddleware(a, handleUserSignupPost(a.logger, a.templateCache, a.formDecoder, a.sessionManager, a.userService), false))
 | 
				
			||||||
 | 
						router.Handle("GET /user/login", addBaseMiddleware(a, handleUserLoginGet(a.templateCache, a.sessionManager), false))
 | 
				
			||||||
 | 
						router.Handle("POST /user/login", addBaseMiddleware(a, handleUserLoginPost(a.logger, a.templateCache, a.sessionManager, a.formDecoder, a.userService), false))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Requires auth
 | 
				
			||||||
 | 
						router.Handle("POST /user/logout", addBaseMiddleware(a, handleUserLogoutPost(a.logger, a.sessionManager), true))
 | 
				
			||||||
 | 
						return RecoveryMiddleware(RequestLoggingMiddleware(CommonHeaderMiddleware(router), a.logger))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,41 +1,32 @@
 | 
				
			|||||||
package server
 | 
					package server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"database/sql"
 | 
					 | 
				
			||||||
	"log/slog"
 | 
						"log/slog"
 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.runcible.io/learning/ratchet/internal/model"
 | 
						"git.runcible.io/learning/ratchet/internal/model"
 | 
				
			||||||
 | 
						"github.com/alexedwards/scs/v2"
 | 
				
			||||||
 | 
						"github.com/go-playground/form/v4"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type RatchetServer struct {
 | 
					type RatchetApp struct {
 | 
				
			||||||
	http.Handler
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	logger        *slog.Logger
 | 
						logger        *slog.Logger
 | 
				
			||||||
	templateCache *TemplateCache
 | 
						templateCache *TemplateCache
 | 
				
			||||||
	//Services used by HTTP routes
 | 
						//Services used by HTTP routes
 | 
				
			||||||
	snippetService *model.SnippetService
 | 
						snippetService model.SnippetServiceInterface
 | 
				
			||||||
	UserService    model.UserService
 | 
						userService    model.UserServiceInterface
 | 
				
			||||||
 | 
						formDecoder    *form.Decoder
 | 
				
			||||||
 | 
						sessionManager *scs.SessionManager
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewRatchetServer(logger *slog.Logger, tc *TemplateCache, db *sql.DB) *RatchetServer {
 | 
					// TODO this function presents some challenges because it both instantiates new data objects
 | 
				
			||||||
	rs := new(RatchetServer)
 | 
					// and configures route / middleware setup
 | 
				
			||||||
 | 
					func NewRatchetApp(logger *slog.Logger, tc *TemplateCache, snippetService model.SnippetServiceInterface, userService model.UserServiceInterface, sm *scs.SessionManager) *RatchetApp {
 | 
				
			||||||
 | 
						rs := new(RatchetApp)
 | 
				
			||||||
	rs.logger = logger
 | 
						rs.logger = logger
 | 
				
			||||||
	rs.snippetService = &model.SnippetService{DB: db}
 | 
						rs.snippetService = snippetService
 | 
				
			||||||
 | 
						rs.userService = userService
 | 
				
			||||||
 | 
						rs.formDecoder = form.NewDecoder()
 | 
				
			||||||
	rs.templateCache = tc
 | 
						rs.templateCache = tc
 | 
				
			||||||
	// TODO implement middleware that disables directory listings
 | 
						rs.sessionManager = sm
 | 
				
			||||||
	fileServer := http.FileServer(http.Dir("./ui/static/"))
 | 
					 | 
				
			||||||
	router := http.NewServeMux()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Subtree pattern for static assets
 | 
					 | 
				
			||||||
	router.Handle("GET /static/", http.StripPrefix("/static/", fileServer))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Mux Router implements the Handler interface. AKA it has a ServeHTTP receiver.
 | 
					 | 
				
			||||||
	// SEE we can really clean things up by moving this into routes.go and handlers.go
 | 
					 | 
				
			||||||
	wrappedMux := addRoutes(router, rs.logger, rs.templateCache, db, rs.snippetService)
 | 
					 | 
				
			||||||
	rs.Handler = CommonHeaderMiddleware(wrappedMux)
 | 
					 | 
				
			||||||
	rs.Handler = RequestLoggingMiddleware(rs.Handler, logger)
 | 
					 | 
				
			||||||
	rs.Handler = RecoveryMiddleware(rs.Handler)
 | 
					 | 
				
			||||||
	return rs
 | 
						return rs
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,49 @@
 | 
				
			|||||||
 | 
					package server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.runcible.io/learning/ratchet/internal/assert"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestHumanDate(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Table driven test
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name string
 | 
				
			||||||
 | 
							tm   time.Time
 | 
				
			||||||
 | 
							want string
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "UTC",
 | 
				
			||||||
 | 
								tm:   time.Date(2077, time.April, 12, 23, 0, 0, 0, time.UTC),
 | 
				
			||||||
 | 
								want: "12 Apr 2077 at 23:00",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "UTC",
 | 
				
			||||||
 | 
								tm:   time.Date(2025, time.March, 3, 2, 31, 0, 0, time.UTC),
 | 
				
			||||||
 | 
								want: "03 Mar 2025 at 02:31",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "Empty",
 | 
				
			||||||
 | 
								tm:   time.Time{},
 | 
				
			||||||
 | 
								want: "",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							// CET is one hour ahead of UTC but we print in UTC
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "CET",
 | 
				
			||||||
 | 
								tm:   time.Date(2024, 3, 17, 10, 15, 0, 0, time.FixedZone("CET", 1*60*60)),
 | 
				
			||||||
 | 
								want: "17 Mar 2024 at 09:15",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, test := range tests {
 | 
				
			||||||
 | 
							t.Run(test.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								got := humanDate(test.tm)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								assert.Equal(t, got, test.want)
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,86 @@
 | 
				
			|||||||
 | 
					// See also https://www.alexedwards.net/blog/validation-snippets-for-go
 | 
				
			||||||
 | 
					// for more validation snippets
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// A NonFieldError for example would be "Your email or password is incorrect".
 | 
				
			||||||
 | 
					// more secure because it does not leak which field was in error. Used in the login
 | 
				
			||||||
 | 
					// form
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package validator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"regexp"
 | 
				
			||||||
 | 
						"slices"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"unicode/utf8"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Validator struct {
 | 
				
			||||||
 | 
						NonFieldErrors []string
 | 
				
			||||||
 | 
						FieldErrors    map[string]string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Valid() returns true if the FieldErrors map doesn't contain any entries.
 | 
				
			||||||
 | 
					func (v *Validator) Valid() bool {
 | 
				
			||||||
 | 
						return len(v.FieldErrors) == 0 && len(v.NonFieldErrors) == 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (v *Validator) AddNonFieldError(message string) {
 | 
				
			||||||
 | 
						v.NonFieldErrors = append(v.NonFieldErrors, message)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AddFieldError() adds an error message to the FieldErrors map (so long as no
 | 
				
			||||||
 | 
					// entry already exists for the given key).
 | 
				
			||||||
 | 
					func (v *Validator) AddFieldError(key, message string) {
 | 
				
			||||||
 | 
						if v.FieldErrors == nil {
 | 
				
			||||||
 | 
							v.FieldErrors = make(map[string]string)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, exists := v.FieldErrors[key]; !exists {
 | 
				
			||||||
 | 
							v.FieldErrors[key] = message
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CheckField() adds an error message to the FieldErrors map only if a
 | 
				
			||||||
 | 
					// validation check is not 'ok'.
 | 
				
			||||||
 | 
					func (v *Validator) CheckField(ok bool, key, message string) {
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							v.AddFieldError(key, message)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NotBlank() returns true if a value is not an empty string.
 | 
				
			||||||
 | 
					func NotBlank(value string) bool {
 | 
				
			||||||
 | 
						return strings.TrimSpace(value) != ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MinChars() returns true if the value contains equal to or greater than n characters
 | 
				
			||||||
 | 
					func MinChars(value string, n int) bool {
 | 
				
			||||||
 | 
						return utf8.RuneCountInString(value) >= n
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MaxChars() returns true if a value contains no more than n characters.
 | 
				
			||||||
 | 
					func MaxChars(value string, n int) bool {
 | 
				
			||||||
 | 
						return utf8.RuneCountInString(value) <= n
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// PermittedValue() returns true if a value is in a list of specific permitted
 | 
				
			||||||
 | 
					// values.
 | 
				
			||||||
 | 
					func PermittedValue[T comparable](value T, permittedValues ...T) bool {
 | 
				
			||||||
 | 
						return slices.Contains(permittedValues, value)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Use the regexp.MustCompile() function to parse a regular expression pattern
 | 
				
			||||||
 | 
					// for sanity checking the format of an email address. This returns a pointer to
 | 
				
			||||||
 | 
					// a 'compiled' regexp.Regexp type, or panics in the event of an error. Parsing
 | 
				
			||||||
 | 
					// this pattern once at startup and storing the compiled *regexp.Regexp in a
 | 
				
			||||||
 | 
					// variable is more performant than re-parsing the pattern each time we need it.
 | 
				
			||||||
 | 
					// This pattern is recommended by the W3C and Web Hypertext Application Technology Working Group for validating email addresses
 | 
				
			||||||
 | 
					// https://html.spec.whatwg.org/multipage/input.html#valid-e-mail-address
 | 
				
			||||||
 | 
					// Because the EmailRX regexp pattern is written as an interpreted string literal, we need to double-escape special characters in the regexp with \\ for it to work correctly
 | 
				
			||||||
 | 
					var EmailRX = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Matches() returns true if a value matches a provided compiled regular
 | 
				
			||||||
 | 
					// expression pattern.
 | 
				
			||||||
 | 
					func Matches(value string, rx *regexp.Regexp) bool {
 | 
				
			||||||
 | 
						return rx.MatchString(value)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1 @@
 | 
				
			|||||||
 | 
					DROP TABLE IF EXISTS sessions;
 | 
				
			||||||
@ -0,0 +1,7 @@
 | 
				
			|||||||
 | 
					CREATE TABLE sessions (
 | 
				
			||||||
 | 
						token TEXT PRIMARY KEY,
 | 
				
			||||||
 | 
						data BLOB NOT NULL,
 | 
				
			||||||
 | 
						expiry REAL NOT NULL
 | 
				
			||||||
 | 
					);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					CREATE INDEX sessions_expiry_idx ON sessions(expiry);
 | 
				
			||||||
@ -0,0 +1 @@
 | 
				
			|||||||
 | 
					DROP TABLE IF EXISTS users;
 | 
				
			||||||
@ -0,0 +1,16 @@
 | 
				
			|||||||
 | 
					CREATE TABLE users (
 | 
				
			||||||
 | 
					    id INTEGER PRIMARY KEY AUTOINCREMENT,
 | 
				
			||||||
 | 
					    name TEXT NOT NULL,
 | 
				
			||||||
 | 
					    email TEXT NOT NULL UNIQUE,
 | 
				
			||||||
 | 
					    hashed_password TEXT NOT NULL,
 | 
				
			||||||
 | 
					    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
 | 
				
			||||||
 | 
					    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
 | 
				
			||||||
 | 
					);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					-- Add a trigger to keep timestamp updated.
 | 
				
			||||||
 | 
					CREATE TRIGGER users_update_timestamp
 | 
				
			||||||
 | 
					AFTER UPDATE ON users
 | 
				
			||||||
 | 
					FOR EACH ROW
 | 
				
			||||||
 | 
					BEGIN
 | 
				
			||||||
 | 
					    UPDATE users SET updated_at = CURRENT_TIMESTAMP WHERE id = OLD.id;
 | 
				
			||||||
 | 
					END;
 | 
				
			||||||
@ -0,0 +1,48 @@
 | 
				
			|||||||
 | 
					package migrations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"embed"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/golang-migrate/migrate/v4"
 | 
				
			||||||
 | 
						"github.com/golang-migrate/migrate/v4/database/sqlite3"
 | 
				
			||||||
 | 
						_ "github.com/golang-migrate/migrate/v4/source/file"
 | 
				
			||||||
 | 
						"github.com/golang-migrate/migrate/v4/source/iofs"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//go:embed *.sql
 | 
				
			||||||
 | 
					var migrationFiles embed.FS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func Migrate(db *sql.DB) error {
 | 
				
			||||||
 | 
						// Create a database driver for the specific database type
 | 
				
			||||||
 | 
						driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("failed to create database driver: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create an IFS source from the embedded files
 | 
				
			||||||
 | 
						source, err := iofs.New(migrationFiles, ".")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("failed to create migration source: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create a new migrate instance
 | 
				
			||||||
 | 
						m, err := migrate.NewWithInstance("iofs", source, "sqlite3", driver)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("failed to create migrate instance: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Run migrations
 | 
				
			||||||
 | 
						if err := m.Up(); err != nil && err != migrate.ErrNoChange {
 | 
				
			||||||
 | 
							return fmt.Errorf("failed to run migrations: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// set WAL mode
 | 
				
			||||||
 | 
						_, err = db.Exec("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("failed to set wall mode: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,17 @@
 | 
				
			|||||||
 | 
					package ui
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "embed"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The below is actually a comment directive. Comment directives must be
 | 
				
			||||||
 | 
					// placed immediately above the variable used. The path provided is relative
 | 
				
			||||||
 | 
					// to the .go file it is in. You can only embed on global variables at a package
 | 
				
			||||||
 | 
					// level. Paths cannot contain . or .. or begin with / So you are effectively
 | 
				
			||||||
 | 
					// restricted to embedding files within the same directory as the .go file.
 | 
				
			||||||
 | 
					// You can provide multiple passes "static/css" "static/img" "static/js" to help with
 | 
				
			||||||
 | 
					// avoiding shipping things like a PostCSS(tailwind) css file.
 | 
				
			||||||
 | 
					// Can also us static/css/*.css. "all:static" will embed . and _ files too.
 | 
				
			||||||
 | 
					// Lastly the embedded file system is always rooted in the directory that contains
 | 
				
			||||||
 | 
					// the embed directive. So in this example the root is in our ui dir.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//go:embed "static" "html"
 | 
				
			||||||
 | 
					var Files embed.FS
 | 
				
			||||||
@ -0,0 +1,39 @@
 | 
				
			|||||||
 | 
					{{define "title"}}Create a New Snippet{{end}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					{{define "main"}}
 | 
				
			||||||
 | 
					<form action='/snippet/create' method='POST'>
 | 
				
			||||||
 | 
					    <!-- Include the CSRF token -->
 | 
				
			||||||
 | 
					    <input type='hidden' name='csrf_token' value='{{.CSRFToken}}'>
 | 
				
			||||||
 | 
					    <div>
 | 
				
			||||||
 | 
					        <label>Title:</label>
 | 
				
			||||||
 | 
					        <!-- Use the `with` action to render the value of .Form.FieldErrors.title
 | 
				
			||||||
 | 
					        if it is not empty. -->
 | 
				
			||||||
 | 
					        {{with .Form.FieldErrors.title}}
 | 
				
			||||||
 | 
					            <label class='error'>{{.}}</label>
 | 
				
			||||||
 | 
					        {{end}}
 | 
				
			||||||
 | 
					        <input type='text' name='title' value="{{.Form.Title}}">
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					    <div>
 | 
				
			||||||
 | 
					        <label>Content:</label>
 | 
				
			||||||
 | 
					        <!-- Likewise render the value of .Form.FieldErrors.content if it is not
 | 
				
			||||||
 | 
					        empty. -->
 | 
				
			||||||
 | 
					        {{with .Form.FieldErrors.content}}
 | 
				
			||||||
 | 
					            <label class='error'>{{.}}</label>
 | 
				
			||||||
 | 
					        {{end}}
 | 
				
			||||||
 | 
					        <textarea name='content'>{{.Form.Content}}</textarea>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					    <div>
 | 
				
			||||||
 | 
					        <label>Delete in:</label>
 | 
				
			||||||
 | 
					        <!-- And render the value of .Form.FieldErrors.expires if it is not empty. -->
 | 
				
			||||||
 | 
					        {{with .Form.FieldErrors.expires}}
 | 
				
			||||||
 | 
					            <label class='error'>{{.}}</label>
 | 
				
			||||||
 | 
					        {{end}}
 | 
				
			||||||
 | 
					        <input type='radio' name='expires' value='365' {{if (eq .Form.Expires 365)}}checked{{end}}> One Year
 | 
				
			||||||
 | 
					        <input type='radio' name='expires' value='7' {{if (eq .Form.Expires 7)}}checked{{end}}> One Week
 | 
				
			||||||
 | 
					        <input type='radio' name='expires' value='1' {{if (eq .Form.Expires 1)}}checked{{end}}> One Day
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					    <div>
 | 
				
			||||||
 | 
					        <input type='submit' value='Publish snippet'>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					</form>
 | 
				
			||||||
 | 
					{{end}}
 | 
				
			||||||
@ -0,0 +1,30 @@
 | 
				
			|||||||
 | 
					{{define "title"}}Login{{end}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					{{define "main"}}
 | 
				
			||||||
 | 
					<form action='/user/login' method='POST' novalidate>
 | 
				
			||||||
 | 
					    <!-- Include the CSRF token -->
 | 
				
			||||||
 | 
					    <input type='hidden' name='csrf_token' value='{{.CSRFToken}}'>
 | 
				
			||||||
 | 
					    <!-- Notice that here we are looping over the NonFieldErrors and displaying
 | 
				
			||||||
 | 
					    them, if any exist -->
 | 
				
			||||||
 | 
					    {{range .Form.NonFieldErrors}}
 | 
				
			||||||
 | 
					        <div class='error'>{{.}}</div>
 | 
				
			||||||
 | 
					    {{end}}
 | 
				
			||||||
 | 
					    <div>
 | 
				
			||||||
 | 
					        <label>Email:</label>
 | 
				
			||||||
 | 
					        {{with .Form.FieldErrors.email}}
 | 
				
			||||||
 | 
					            <label class='error'>{{.}}</label>
 | 
				
			||||||
 | 
					        {{end}}
 | 
				
			||||||
 | 
					        <input type='email' name='email' value='{{.Form.Email}}'>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					    <div>
 | 
				
			||||||
 | 
					        <label>Password:</label>
 | 
				
			||||||
 | 
					        {{with .Form.FieldErrors.password}}
 | 
				
			||||||
 | 
					            <label class='error'>{{.}}</label>
 | 
				
			||||||
 | 
					        {{end}}
 | 
				
			||||||
 | 
					        <input type='password' name='password'>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					    <div>
 | 
				
			||||||
 | 
					        <input type='submit' value='Login'>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					</form>
 | 
				
			||||||
 | 
					{{end}}
 | 
				
			||||||
@ -0,0 +1,32 @@
 | 
				
			|||||||
 | 
					{{ define "title"}}Signup{{end}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					{{define "main"}}
 | 
				
			||||||
 | 
					<form action='/user/signup' method='POST' novalidate>
 | 
				
			||||||
 | 
					    <!-- Include the CSRF token -->
 | 
				
			||||||
 | 
					    <input type='hidden' name='csrf_token' value='{{.CSRFToken}}'>
 | 
				
			||||||
 | 
					    <div>
 | 
				
			||||||
 | 
					        <label>Name:</label>
 | 
				
			||||||
 | 
					        {{with .Form.FieldErrors.name}}
 | 
				
			||||||
 | 
					            <label class='error'>{{.}}</label>
 | 
				
			||||||
 | 
					        {{end}}
 | 
				
			||||||
 | 
					        <input type='text' name='name' value='{{.Form.Name}}'>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					    <div>
 | 
				
			||||||
 | 
					        <label>Email:</label>
 | 
				
			||||||
 | 
					        {{with .Form.FieldErrors.email}}
 | 
				
			||||||
 | 
					            <label class='error'>{{.}}</label>
 | 
				
			||||||
 | 
					        {{end}}
 | 
				
			||||||
 | 
					        <input type='email' name='email' value='{{.Form.Email}}'>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					    <div>
 | 
				
			||||||
 | 
					        <label>Password:</label>
 | 
				
			||||||
 | 
					        {{with .Form.FieldErrors.password}}
 | 
				
			||||||
 | 
					            <label class='error'>{{.}}</label>
 | 
				
			||||||
 | 
					        {{end}}
 | 
				
			||||||
 | 
					        <input type='password' name='password'>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					    <div>
 | 
				
			||||||
 | 
					        <input type='submit' value='Signup'>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					</form>
 | 
				
			||||||
 | 
					{{end}}
 | 
				
			||||||
@ -1,5 +1,22 @@
 | 
				
			|||||||
{{define "nav" -}}
 | 
					{{define "nav" -}}
 | 
				
			||||||
        <nav>
 | 
					        <nav>
 | 
				
			||||||
 | 
					            <div>
 | 
				
			||||||
                <a href='/'>Home</a>
 | 
					                <a href='/'>Home</a>
 | 
				
			||||||
 | 
					                {{ if .IsAuthenticated }}
 | 
				
			||||||
 | 
					                <a href='/snippet/create'>Create snippet</a>
 | 
				
			||||||
 | 
					                {{ end }}
 | 
				
			||||||
 | 
					            </div>
 | 
				
			||||||
 | 
					            <div>
 | 
				
			||||||
 | 
					                {{ if .IsAuthenticated }}
 | 
				
			||||||
 | 
					                <form action='/user/logout' method='POST'>
 | 
				
			||||||
 | 
					                        <!-- Include the CSRF token -->
 | 
				
			||||||
 | 
					                    <input type='hidden' name='csrf_token' value='{{.CSRFToken}}'>
 | 
				
			||||||
 | 
					                    <button>Logout</button>
 | 
				
			||||||
 | 
					                </form>
 | 
				
			||||||
 | 
					                {{ else }}
 | 
				
			||||||
 | 
					                <a href='/user/signup'>Signup</a>
 | 
				
			||||||
 | 
					                <a href='/user/login'>Login</a>
 | 
				
			||||||
 | 
					                {{ end }}
 | 
				
			||||||
 | 
					            </div>
 | 
				
			||||||
        </nav>
 | 
					        </nav>
 | 
				
			||||||
{{- end}}
 | 
					{{- end}}
 | 
				
			||||||
					Loading…
					
					
				
		Reference in New Issue