Fixed unit tests for new filtering query

main
Drew Bednar 1 week ago
parent 04489650ad
commit 6eb3709c00

@ -19,7 +19,21 @@ import (
) )
var defaultFixedTime = time.Date(2026, 1, 25, 10, 10, 40, 0, time.UTC) var defaultFixedTime = time.Date(2026, 1, 25, 10, 10, 40, 0, time.UTC)
var getAllReturnColumns = []string{"id", "created_at", "title", "year", "runtime", "genres", "version"} var getAllReturnColumns = []string{"count", "id", "created_at", "title", "year", "runtime", "genres", "version"}
// This regex matches the SQL query structure while allowing flexible whitespace and any values substituted into the ORDER BY fmt.Sprintf("%s %s") fields.
// The ORDER BY column and direction are intentionally matched with broad patterns since pgxmock only needs to validate the overall query shape.
// const getAllQueryRegex = `(?is)
// SELECT\s+count\(\*\)\s+OVER\(\),\s+id,\s+created_at,\s+title,\s+year,\s+runtime,\s+genres,\s+version\s+
// FROM\s+movies\s+
// WHERE\s+\(to_tsvector\('simple',\s+title\)\s+@@\s+plainto_tsquery\('simple',\s+\$1\)\s+OR\s+\$1\s+=\s+''\)\s+
// AND\s+\(genres\s+@>\s+\$2\s+OR\s+\$2\s+=\s+'\{\}'\)\s+
// ORDER\s+BY\s+.+\s+.+,\s+id\s+ASC\s+
// LIMIT\s+\$3\s+OFFSET\s+\$4`
// This regex matches the SQL query structure while tolerating whitespace normalization performed by pgx.
// The ORDER BY column and direction generated by fmt.Sprintf("%s %s") are matched generically so tests do not depend on exact formatting.
const getAllQueryRegex = `(?is)^SELECT\s+count\(\*\)\s+OVER\(\),\s+id,\s+created_at,\s+title,\s+year,\s+runtime,\s+genres,\s+version\s+FROM\s+movies\s+WHERE\s+\(to_tsvector\('simple',\s+title\)\s+@@\s+plainto_tsquery\('simple',\s+\$1\)\s+OR\s+\$1\s+=\s+''\)\s+AND\s+\(genres\s+@>\s+\$2\s+OR\s+\$2\s+=\s+'\{\}'\)\s+ORDER\s+BY\s+\S+\s+\S+,\s+id\s+ASC\s+LIMIT\s+\$3\s+OFFSET\s+\$4$`
func TestHealthRoute(t *testing.T) { func TestHealthRoute(t *testing.T) {
respRec := httptest.NewRecorder() respRec := httptest.NewRecorder()
@ -220,7 +234,8 @@ func TestGetMovieHandler(t *testing.T) {
}, },
} }
mockPool.ExpectQuery("SELECT id, created_at, title, year, runtime, genres, version FROM movies"). //(?i) This is a flag modifier that turns on case-insensitive matching
mockPool.ExpectQuery(`(?i)SELECT\s+id,\s+created_at,\s+title,\s+year,\s+runtime,\s+genres,\s+version\s+FROM\s+movies\s+WHERE\s+id`).
WithArgs(int64(1337)).WillReturnRows( WithArgs(int64(1337)).WillReturnRows(
pgxmock.NewRows([]string{"id", "created_at", "title", "year", "runtime", "genres", "version"}). pgxmock.NewRows([]string{"id", "created_at", "title", "year", "runtime", "genres", "version"}).
AddRow(int64(1337), time.Now(), "a laura is born", 1990, 36, []string{"family", "wife"}, 1), // These values will be scanned into the struct AddRow(int64(1337), time.Now(), "a laura is born", 1990, 36, []string{"family", "wife"}, 1), // These values will be scanned into the struct
@ -321,14 +336,20 @@ func TestListMovieHandler(t *testing.T) {
}, },
} }
rows := pgxmock.NewRows([]string{"id", "created_at", "title", "year", "runtime", "genres", "version"}) rows := pgxmock.NewRows([]string{"count", "id", "created_at", "title", "year", "runtime", "genres", "version"})
for _, m := range movies { for _, m := range movies {
rows.AddRow(m.ID, m.CreatedAt, m.Title, m.Year, m.Runtime, m.Genres, m.Version) rows.AddRow(2,
m.ID, m.CreatedAt, m.Title, m.Year, m.Runtime, m.Genres, m.Version)
} }
mockPool.ExpectQuery(`SELECT id, created_at, title, year, runtime, genres, version mockPool.ExpectQuery(getAllQueryRegex).
FROM movies WithArgs(
ORDER BY id ASC`).WillReturnRows(rows) pgxmock.AnyArg(),
pgxmock.AnyArg(),
pgxmock.AnyArg(),
pgxmock.AnyArg(),
).
WillReturnRows(rows)
r, err := http.NewRequest(http.MethodGet, "/v1/movies", nil) r, err := http.NewRequest(http.MethodGet, "/v1/movies", nil)
assert.NilError(t, err) assert.NilError(t, err)
@ -363,9 +384,16 @@ func TestListHandlerServerError(t *testing.T) {
respRec := httptest.NewRecorder() respRec := httptest.NewRecorder()
mockPool, err := pgxmock.NewPool() mockPool, err := pgxmock.NewPool()
assert.NilError(t, err) assert.NilError(t, err)
errorRows := pgxmock.NewRows(getAllReturnColumns).AddRow(1, time.Now(), "will error", 2026, 120, []string{}, 1).RowError(0, fmt.Errorf("network connection lost")) errorRows := pgxmock.NewRows(getAllReturnColumns).AddRow(1, 1, time.Now(), "will error", 2026, 120, []string{}, 1).RowError(0, fmt.Errorf("network connection lost"))
mockPool.ExpectQuery("SELECT").WillReturnRows(errorRows) mockPool.ExpectQuery("SELECT").
WithArgs(
pgxmock.AnyArg(),
pgxmock.AnyArg(),
pgxmock.AnyArg(),
pgxmock.AnyArg(),
).
WillReturnRows(errorRows)
r, err := http.NewRequest(http.MethodGet, "/v1/movies", nil) r, err := http.NewRequest(http.MethodGet, "/v1/movies", nil)
app := newTestApplication(mockPool) app := newTestApplication(mockPool)
@ -381,11 +409,11 @@ func TestListHandlerValidation(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
defer mockPool.Close() defer mockPool.Close()
getAllQuery := ` // getAllQuery := `
SELECT id, created_at, title, year, runtime, genres, version // SELECT count, id, created_at, title, year, runtime, genres, version
FROM movies // FROM movies
ORDER BY id ASC // ORDER BY id ASC
` // `
testTable := []struct { testTable := []struct {
name string name string
@ -487,7 +515,14 @@ func TestListHandlerValidation(t *testing.T) {
if test.wantCode == http.StatusOK { if test.wantCode == http.StatusOK {
// TODO expand return values and parameterize // TODO expand return values and parameterize
// empty return is fine for validator test // empty return is fine for validator test
mockPool.ExpectQuery(getAllQuery).WillReturnRows(mockPool.NewRows(getAllReturnColumns)) mockPool.ExpectQuery(getAllQueryRegex).
WithArgs(
pgxmock.AnyArg(),
pgxmock.AnyArg(),
pgxmock.AnyArg(),
pgxmock.AnyArg(),
).
WillReturnRows(mockPool.NewRows(getAllReturnColumns))
} }
r, err := http.NewRequest(http.MethodGet, test.query, nil) r, err := http.NewRequest(http.MethodGet, test.query, nil)
@ -587,10 +622,16 @@ func TestListMoviesFilters(t *testing.T) {
rows := pgxmock.NewRows(getAllReturnColumns) rows := pgxmock.NewRows(getAllReturnColumns)
for _, m := range test.expextedMovies { for _, m := range test.expextedMovies {
rows.AddRow(m.ID, m.CreatedAt, m.Title, m.Year, m.Runtime, m.Genres, m.Version) rows.AddRow(3, m.ID, m.CreatedAt, m.Title, m.Year, m.Runtime, m.Genres, m.Version)
} }
mockPool.ExpectQuery("SELECT"). mockPool.ExpectQuery("SELECT").
WithArgs(
pgxmock.AnyArg(),
pgxmock.AnyArg(),
pgxmock.AnyArg(),
pgxmock.AnyArg(),
).
WillReturnRows(rows) WillReturnRows(rows)
assert.NilError(t, err) assert.NilError(t, err)

@ -36,31 +36,41 @@ type MovieModel struct {
logger *slog.Logger logger *slog.Logger
} }
// So in regards to pulling the queries out into constants. I think it makes sense
// in a world where the query gets really gnarly, but here it ends up violating code locality
// and honestly doesn't provide that much utility in pgmock, because a regex is probably a better
// fit anyways for testing and mocks.
// If you the whitespace becomes a problem pgxmocks should use a real regex
// or a strings.Join(strings.Fields(q), " ") helper function could be used to normalizes
// all the sql strings. Postgres doesn't care about whitespace though.
const InsertMovieQuery = `
INSERT INTO movies (title, year, runtime, genres)
VALUES ($1, $2, $3, $4)
RETURNING id, created_at, version`
func (m MovieModel) Insert(ctx context.Context, movie *Movie) error { func (m MovieModel) Insert(ctx context.Context, movie *Movie) error {
query := `
INSERT INTO movies (title, year, runtime, genres)
VALUES ($1, $2, $3, $4)
RETURNING id, created_at, version`
args := []any{movie.Title, movie.Year, movie.Runtime, movie.Genres} args := []any{movie.Title, movie.Year, movie.Runtime, movie.Genres}
row := m.db.QueryRow(ctx, query, args...) row := m.db.QueryRow(ctx, InsertMovieQuery, args...)
// Insert is mutating the Movie struct // Insert is mutating the Movie struct
err := row.Scan(&movie.ID, &movie.CreatedAt, &movie.Version) err := row.Scan(&movie.ID, &movie.CreatedAt, &movie.Version)
return err return err
} }
const GetMovieQuery = `
SELECT id, created_at, title, year, runtime, genres, version
FROM movies
WHERE id = $1`
func (m MovieModel) Get(ctx context.Context, id int64) (*Movie, error) { func (m MovieModel) Get(ctx context.Context, id int64) (*Movie, error) {
// safety validation // safety validation
if id < 1 { if id < 1 {
return nil, ErrRecordNotFound return nil, ErrRecordNotFound
} }
query := `
SELECT id, created_at, title, year, runtime, genres, version
FROM movies
WHERE id = $1
`
// Mimicking a long running query. FYI since this changes the number of returned // Mimicking a long running query. FYI since this changes the number of returned
// fields, so I implmented the throwaway variable to take care of that // fields, so I implmented the throwaway variable to take care of that
// query := ` // query := `
@ -71,7 +81,7 @@ func (m MovieModel) Get(ctx context.Context, id int64) (*Movie, error) {
var movie Movie var movie Movie
err := m.db.QueryRow(ctx, query, id).Scan( err := m.db.QueryRow(ctx, GetMovieQuery, id).Scan(
// &[]byte{}, // throwaway the pg_sleep value // &[]byte{}, // throwaway the pg_sleep value
&movie.ID, &movie.ID,
&movie.CreatedAt, &movie.CreatedAt,
@ -93,6 +103,12 @@ func (m MovieModel) Get(ctx context.Context, id int64) (*Movie, error) {
return &movie, nil return &movie, nil
} }
const UpdateMovieQuery = `
UPDATE movies
SET title = $1, year = $2, runtime = $3, genres = $4, version = version + 1
WHERE id = $5 and version = $6
RETURNING version`
func (m MovieModel) Update(ctx context.Context, movie *Movie) error { func (m MovieModel) Update(ctx context.Context, movie *Movie) error {
// Using version here as an optimistic lock. Look up optimistic vs pessimistic locking // Using version here as an optimistic lock. Look up optimistic vs pessimistic locking
// https://stackoverflow.com/questions/129329/optimistic-vs-pessimistic-locking/129397#129397 // https://stackoverflow.com/questions/129329/optimistic-vs-pessimistic-locking/129397#129397
@ -101,13 +117,6 @@ func (m MovieModel) Update(ctx context.Context, movie *Movie) error {
// cause issues. If you don't want version to be guessable then a UUID generated by the DB is suitable. // cause issues. If you don't want version to be guessable then a UUID generated by the DB is suitable.
// Example: SET ... version = uuid_generate_v4() // Example: SET ... version = uuid_generate_v4()
query := `
UPDATE movies
SET title = $1, year = $2, runtime = $3, genres = $4, version = version + 1
WHERE id = $5 and version = $6
RETURNING version
`
args := []any{ args := []any{
movie.Title, movie.Title,
movie.Year, movie.Year,
@ -118,7 +127,7 @@ func (m MovieModel) Update(ctx context.Context, movie *Movie) error {
} }
// Will not return any rows if the version number has already changed. // Will not return any rows if the version number has already changed.
err := m.db.QueryRow(ctx, query, args...).Scan(&movie.Version) err := m.db.QueryRow(ctx, UpdateMovieQuery, args...).Scan(&movie.Version)
if err != nil { if err != nil {
switch { switch {
case errors.Is(err, pgx.ErrNoRows): case errors.Is(err, pgx.ErrNoRows):
@ -137,14 +146,15 @@ func (m MovieModel) Update(ctx context.Context, movie *Movie) error {
// none the less. If using sqlc you'd probably not even worry about implementing // none the less. If using sqlc you'd probably not even worry about implementing
// it this way, but I am here to learn so. // it this way, but I am here to learn so.
const DeleteMovieQuery = `
DELETE FROM movies WHERE id = $1`
func (m MovieModel) Delete(ctx context.Context, id int64) (err error) { func (m MovieModel) Delete(ctx context.Context, id int64) (err error) {
if id < 1 { if id < 1 {
return ErrRecordNotFound return ErrRecordNotFound
} }
query := `DELETE FROM movies WHERE id = $1`
tx, err := m.db.BeginTx(ctx, pgx.TxOptions{}) tx, err := m.db.BeginTx(ctx, pgx.TxOptions{})
if err != nil { if err != nil {
return err return err
@ -160,7 +170,7 @@ func (m MovieModel) Delete(ctx context.Context, id int64) (err error) {
}() }()
var cmd pgconn.CommandTag var cmd pgconn.CommandTag
cmd, err = tx.Exec(ctx, query, id) cmd, err = tx.Exec(ctx, DeleteMovieQuery, id)
if err != nil { if err != nil {
return err return err
} }
@ -172,6 +182,14 @@ func (m MovieModel) Delete(ctx context.Context, id int64) (err error) {
return nil return nil
} }
const GetAllMoviesQueryTemplate = `
SELECT count(*) OVER(), id, created_at, title, year, runtime, genres, version
FROM movies
WHERE (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '')
AND (genres @> $2 OR $2 = '{}')
ORDER BY %s %s, id ASC
LIMIT $3 OFFSET $4`
func (m MovieModel) GetAll(ctx context.Context, title string, genres []string, filters Filters) ([]*Movie, Metadata, error) { func (m MovieModel) GetAll(ctx context.Context, title string, genres []string, filters Filters) ([]*Movie, Metadata, error) {
// OLD query // OLD query
// query := ` // query := `
@ -201,13 +219,7 @@ func (m MovieModel) GetAll(ctx context.Context, title string, genres []string, f
// Using a window function to produce a totalRecords count using the WHERE parameters of // Using a window function to produce a totalRecords count using the WHERE parameters of
// the query // the query
query := fmt.Sprintf(` query := fmt.Sprintf(GetAllMoviesQueryTemplate, filters.sortColumn(), filters.sortDirection())
SELECT count(*) OVER(), id, created_at, title, year, runtime, genres, version
FROM movies
WHERE (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '')
AND (genres @> $2 OR $2 = '{}')
ORDER BY %s %s, id ASC
LIMIT $3 OFFSET $4`, filters.sortColumn(), filters.sortDirection())
// ctx want some timeout for queries. When used in the handler the context passed should // ctx want some timeout for queries. When used in the handler the context passed should
// be the r.Context. Since cancel functions are inherited it will cancel on client // be the r.Context. Since cancel functions are inherited it will cancel on client

Loading…
Cancel
Save