From da23ea767ee822d22e8dbccfdf7a6c4e42ca16b1 Mon Sep 17 00:00:00 2001 From: Drew Bednar Date: Mon, 9 Oct 2023 17:10:51 -0400 Subject: [PATCH] Adding user route tests --- htmx_contact/user.py | 1 - tests/conftest.py | 10 +++++++++- tests/test_routes.py | 38 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/htmx_contact/user.py b/htmx_contact/user.py index 70f01d0..3fbda38 100644 --- a/htmx_contact/user.py +++ b/htmx_contact/user.py @@ -39,7 +39,6 @@ def user_login(): logger.info(f"Received login request from {data.email}") data = LoginValidator.model_validate(dict(request.form)) - with Session() as session: select_user_stmt = select(User).where(User.primary_email == data.email) user = session.scalar(select_user_stmt) diff --git a/tests/conftest.py b/tests/conftest.py index 4b4e0a2..1b4cf69 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from flask_login.test_client import FlaskLoginClient @@ -13,7 +15,7 @@ def test_config(): @pytest.fixture() def test_user(): - return User(id=1, primary_email="test", password="password", username="test") + return User(id=1, primary_email="test", password="test", username="test") @pytest.fixture() @@ -46,3 +48,9 @@ def client(app, test_user): @pytest.fixture() def runner(app): return app.test_cli_runner() + + +@pytest.fixture() +def sqla_session(): + with patch("htmx_contact.Session") as ms: + yield ms diff --git a/tests/test_routes.py b/tests/test_routes.py index 9624211..30ab975 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -1,3 +1,7 @@ +from unittest.mock import Mock +from unittest.mock import patch + + def test_index_redirect_to_contacts(client): response = client.get("/") assert response.status_code == 302 @@ -10,5 +14,35 @@ def test_login_get_contacts(anynomous_client): def test_get_contacts(client): - response = client.get("/contacts", follow_redirects=True) - assert response.status_code == 200 + with client.session_transaction() as session: + response = client.get("/contacts", follow_redirects=True) + assert response.status_code == 200 + assert session["_user_id"] == 1 + + +def test_user_logout(client): + # need an active request context to check session. + with client.session_transaction() as session: + assert session["_user_id"] == 1 + + response = client.get("/user/logout", follow_redirects=False) + assert response.status_code == 302 + + with client.session_transaction() as session: + assert session.get("_user_id") is None + + +def test_user_login(anynomous_client, test_user): + with patch("htmx_contact.user.Session") as sqla_session: + mock_session_instance = Mock(name="slqasession") + mock_session_instance.configure_mock(**{"scalar.return_value": test_user}) + sqla_session.configure_mock(**{"return_value.__enter__.return_value": mock_session_instance}) + + with anynomous_client.session_transaction() as session: + assert session.get("_user_id") is None + + response = anynomous_client.post("/user/login", data={"email": "test", "password": "test"}) + assert response.status_code == 302 + + with anynomous_client.session_transaction() as session: + assert session.get("_user_id") == test_user.id