tests: extract _build_session_factory to a shared conftest fixture

The same per-test sqlite-engine setup was duplicated across 14 test
files (~30 lines each). Consolidated into a single async fixture
`db_factory` in tests/conftest.py; tests now take db_factory as a
parameter and use `async with db_factory() as session` directly.

No behaviour change — same function-scope, same in-memory schema
created via Base.metadata.create_all, same app.db._engine /
_session_factory rebinding so module-level helpers see the test
engine. Just ~420 lines of boilerplate removed.
This commit is contained in:
Giorgio Gilestro 2026-05-27 20:50:09 +02:00
parent b13caa4c51
commit dcc2c07111
5 changed files with 167 additions and 250 deletions

View file

@ -17,3 +17,40 @@ sys.path.insert(0, str(ROOT))
# Sentinel env so importing app.config doesn't try to read a missing .env.
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
os.environ.setdefault("CASSANDRA_MOCK", "1")
import pytest
@pytest.fixture
async def db_factory(tmp_path):
"""Per-test sqlite engine + async session factory.
Creates a fresh sqlite database file under ``tmp_path``, applies
``Base.metadata.create_all``, and rebinds ``app.db._engine`` /
``app.db._session_factory`` so module-level helpers (which look
these up at call time) see the test engine.
Yields the ``async_sessionmaker``. Tests use it like:
async def test_foo(db_factory):
async with db_factory() as session:
...
"""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod
from app.db import Base
import app.models # noqa: F401 — registers models on Base.metadata
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/test.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield factory
await engine.dispose()

View file

@ -4,26 +4,6 @@ from __future__ import annotations
import pytest
def _build_session_factory(tmp_path):
"""Spin up a fresh in-memory schema and return (engine, factory).
Matches the pattern used in tests/test_referral_conversion.py."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod
from app.db import Base
import app.models # noqa: F401 — registers models on Base.metadata
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/csv.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async def _setup():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
return engine, factory, _setup
def test_csv_format_template_model_columns():
"""Model exposes every column the spec requires, with correct types."""
@ -310,7 +290,7 @@ async def test_extract_mapping_via_llm_provider_failure_wraps():
await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]])
async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
async def test_parse_with_llm_cache_miss_inserts_template(db_factory):
from unittest.mock import AsyncMock
from sqlalchemy import select
@ -318,8 +298,7 @@ async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
from app.services.llm_csv_parser import parse_with_llm
from app.services.openrouter import LogResult
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
raw = (
b"Symbol,Quantity,Avg Price,Currency\n"
@ -356,7 +335,7 @@ async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
assert not hasattr(tmpl, "user_id"), "sample row must not be linked to a user"
async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
async def test_parse_with_llm_cache_hit_skips_llm(db_factory):
from unittest.mock import AsyncMock
from sqlalchemy import select
@ -364,8 +343,7 @@ async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
from app.models import CsvFormatTemplate
from app.services.llm_csv_parser import _fingerprint, parse_with_llm
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
headers = ["Symbol", "Quantity", "Avg Price", "Currency"]
fp = _fingerprint(headers)
@ -411,7 +389,7 @@ async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
assert rows[0].use_count == 2
async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(db_factory):
from unittest.mock import AsyncMock
from sqlalchemy import select
@ -419,8 +397,7 @@ async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
from app.models import CsvFormatTemplate
from app.services.llm_csv_parser import LLMParseError, _fingerprint, parse_with_llm
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
headers = ["Symbol", "Quantity"]
fp = _fingerprint(headers)
@ -452,7 +429,7 @@ async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
assert len(rows) == 1
async def test_parse_portfolio_route_falls_through_to_llm(tmp_path, monkeypatch):
async def test_parse_portfolio_route_falls_through_to_llm(db_factory, monkeypatch):
"""End-to-end: T212 parser raises CSVImportError, LLM fallback runs,
response shape matches the existing JSON contract."""
from io import BytesIO
@ -461,8 +438,7 @@ async def test_parse_portfolio_route_falls_through_to_llm(tmp_path, monkeypatch)
from fastapi import UploadFile
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
import app.services.llm_csv_parser as mod
from app.services.openrouter import LogResult

View file

@ -5,25 +5,6 @@ from __future__ import annotations
import pytest
def _build_session_factory(tmp_path):
"""Per-test sqlite engine + factory. Mirrors test_referral_conversion.py."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod
from app.db import Base
import app.models # noqa: F401 — registers models on Base.metadata
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/loc.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async def _setup():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
return engine, factory, _setup
def test_user_has_lang_column_with_default_en():
from sqlalchemy import inspect
@ -55,7 +36,7 @@ def test_strategic_log_translation_model_columns():
assert cols["content_md"].nullable is False
async def test_log_translation_fanout_no_active_non_en_users(tmp_path, monkeypatch):
async def test_log_translation_fanout_no_active_non_en_users(db_factory, monkeypatch):
"""When no users have an active non-en lang, the fan-out makes no
translation calls and no rows are inserted."""
from unittest.mock import AsyncMock
@ -65,8 +46,7 @@ async def test_log_translation_fanout_no_active_non_en_users(tmp_path, monkeypat
from app.models import StrategicLog, StrategicLogTranslation, User
from app.jobs import ai_log_job
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
fake_translate = AsyncMock()
monkeypatch.setattr(ai_log_job, "translate", fake_translate)
@ -92,7 +72,7 @@ async def test_log_translation_fanout_no_active_non_en_users(tmp_path, monkeypat
assert rows == []
async def test_log_translation_fanout_italian_user(tmp_path, monkeypatch):
async def test_log_translation_fanout_italian_user(db_factory, monkeypatch):
"""One user at lang=it triggers one translation; the row lands with
the right lang and log_id."""
from sqlalchemy import select
@ -102,8 +82,7 @@ async def test_log_translation_fanout_italian_user(tmp_path, monkeypatch):
from app.services.openrouter import LogResult
from app.jobs import ai_log_job
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
async def _fake_translate(client, text, target_lang):
assert target_lang == "it"
@ -139,7 +118,7 @@ async def test_log_translation_fanout_italian_user(tmp_path, monkeypatch):
assert row.llm_cost_usd == pytest.approx(0.00002)
async def test_log_translation_fanout_per_language_failure_isolated(tmp_path, monkeypatch):
async def test_log_translation_fanout_per_language_failure_isolated(db_factory, monkeypatch):
"""If one language's translation fails, the others (if any) still land
and the job does not raise."""
from sqlalchemy import select
@ -148,8 +127,7 @@ async def test_log_translation_fanout_per_language_failure_isolated(tmp_path, mo
from app.models import StrategicLog, StrategicLogTranslation, User
from app.jobs import ai_log_job
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
async def _fake_translate(client, text, target_lang):
raise RuntimeError("upstream down")
@ -175,7 +153,7 @@ async def test_log_translation_fanout_per_language_failure_isolated(tmp_path, mo
assert rows == []
async def test_analyse_threads_lang_into_system_prompt(tmp_path, monkeypatch):
async def test_analyse_threads_lang_into_system_prompt(db_factory, monkeypatch):
"""When lang='it', the system prompt sent to call_llm contains
'Respond in Italian.' the LLM does the rest."""
from app.services import portfolio_analysis as pa
@ -191,8 +169,7 @@ async def test_analyse_threads_lang_into_system_prompt(tmp_path, monkeypatch):
)
monkeypatch.setattr(pa, "call_llm", _fake_call_llm)
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
payload = {
"positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0,
@ -213,7 +190,7 @@ async def test_analyse_threads_lang_into_system_prompt(tmp_path, monkeypatch):
assert "Respond in Italian" in system
async def test_analyse_no_clause_when_lang_is_en(tmp_path, monkeypatch):
async def test_analyse_no_clause_when_lang_is_en(db_factory, monkeypatch):
from app.services import portfolio_analysis as pa
from app.services.openrouter import LogResult
@ -227,8 +204,7 @@ async def test_analyse_no_clause_when_lang_is_en(tmp_path, monkeypatch):
)
monkeypatch.setattr(pa, "call_llm", _fake_call_llm)
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
payload = {
"positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0,
@ -328,13 +304,12 @@ def test_digest_pick_variant_uses_user_lang():
async def test_patch_language_accepts_active(tmp_path):
async def test_patch_language_accepts_active(db_factory):
"""PATCH /api/settings/language accepts 'en' and 'it' and persists."""
from app.models import User
from app.routers.api import patch_language_prefs, LanguagePrefsIn
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
async with factory() as session:
session.add(User(id=20, email="u@x", tier="paid", lang="en"))
@ -358,14 +333,13 @@ async def test_patch_language_accepts_active(tmp_path):
assert user.lang == "it"
async def test_patch_language_rejects_wip(tmp_path):
async def test_patch_language_rejects_wip(db_factory):
"""PATCH rejects 'es'/'fr'/'de'/'xx' with 400 — ACTIVE_LANGUAGES gate."""
from fastapi import HTTPException
from app.models import User
from app.routers.api import patch_language_prefs, LanguagePrefsIn
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
async with factory() as session:
session.add(User(id=21, email="u2@x", tier="paid", lang="en"))

View file

@ -23,29 +23,6 @@ import pytest
# ---------------------------------------------------------------------------
def _build_session_factory(tmp_path):
"""Spin up a fresh in-memory schema and return (engine, factory).
Mirrors test_stripe_billing._build_app's seeding strategy but
skips the FastAPI app most conversion tests only need the
session factory."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod
from app.db import Base
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/conv.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async def _seed():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
asyncio.run(_seed())
return factory
async def _add_pair(factory, *, referrer_id=1, referred_id=2):
"""Insert a referrer + referred user pair and a linking Referral row.
Returns nothing tests re-fetch via the factory."""
@ -68,7 +45,7 @@ async def _add_pair(factory, *, referrer_id=1, referred_id=2):
# ---------------------------------------------------------------------------
def test_first_conversion_credits_both_parties(tmp_path):
async def test_first_conversion_credits_both_parties(db_factory):
"""Calling convert_referral on a freshly-paid referred user should
extend credit_until by REFERRAL_CREDIT_DAYS for BOTH the buyer and
the referrer, and stamp converted_at + credited_at."""
@ -77,10 +54,9 @@ def test_first_conversion_credits_both_parties(tmp_path):
REFERRAL_CREDIT_DAYS, convert_referral,
)
factory = _build_session_factory(tmp_path)
asyncio.run(_add_pair(factory))
factory = db_factory
await _add_pair(factory)
async def _run():
async with factory() as s:
referred = await s.get(User, 2)
ref = await convert_referral(s, referred)
@ -105,20 +81,17 @@ def test_first_conversion_credits_both_parties(tmp_path):
delta_days = (cu - now).total_seconds() / 86400
assert REFERRAL_CREDIT_DAYS - 1 <= delta_days <= REFERRAL_CREDIT_DAYS + 1
asyncio.run(_run())
def test_idempotent_on_repeat_call(tmp_path):
async def test_idempotent_on_repeat_call(db_factory):
"""A second convert_referral call (e.g. from a duplicate webhook or
renewal event) must NOT extend credit a second time. The Referral
row is already stamped, so we should early-return unchanged."""
from app.models import User
from app.services.referral_service import convert_referral
factory = _build_session_factory(tmp_path)
asyncio.run(_add_pair(factory))
factory = db_factory
await _add_pair(factory)
async def _run():
async with factory() as s:
referred = await s.get(User, 2)
await convert_referral(s, referred)
@ -142,35 +115,27 @@ def test_idempotent_on_repeat_call(tmp_path):
assert referrer.credit_until == first_referrer_credit
assert referred.credit_until == first_referred_credit
asyncio.run(_run())
def test_no_referral_row_returns_none(tmp_path):
async def test_no_referral_row_returns_none(db_factory):
"""A user signing up directly (no inviter) has no Referral row.
convert_referral must return None and touch nothing."""
from app.models import User
from app.services.referral_service import convert_referral
factory = _build_session_factory(tmp_path)
factory = db_factory
async def _seed_orphan():
async with factory() as s:
s.add(User(id=9, email="lone@x", tier="free"))
await s.commit()
asyncio.run(_seed_orphan())
async def _run():
async with factory() as s:
user = await s.get(User, 9)
result = await convert_referral(s, user)
assert result is None
assert user.credit_until is None
asyncio.run(_run())
def test_credit_stacks_from_existing_window(tmp_path):
async def test_credit_stacks_from_existing_window(db_factory):
"""If the user already has a future credit_until (admin grant, prior
referral), the new credit should extend from THAT anchor not from
now. Mirrors cli.grant_credit's stacking semantics."""
@ -179,21 +144,17 @@ def test_credit_stacks_from_existing_window(tmp_path):
REFERRAL_CREDIT_DAYS, convert_referral,
)
factory = _build_session_factory(tmp_path)
asyncio.run(_add_pair(factory))
factory = db_factory
await _add_pair(factory)
# Pre-load 30 days of credit on the referred user.
existing = datetime.now(timezone.utc) + timedelta(days=30)
async def _preload():
async with factory() as s:
u = await s.get(User, 2)
u.credit_until = existing
await s.commit()
asyncio.run(_preload())
async def _run():
async with factory() as s:
referred = await s.get(User, 2)
await convert_referral(s, referred)
@ -211,19 +172,16 @@ def test_credit_stacks_from_existing_window(tmp_path):
f"got {cu}, expected ~{expected}"
)
asyncio.run(_run())
def test_deleted_referrer_does_not_crash(tmp_path):
async def test_deleted_referrer_does_not_crash(db_factory):
"""If the referrer's User row has been deleted, the referred user
should still be credited and the Referral still stamped we just
skip the missing referrer."""
from app.models import Referral, User
from app.services.referral_service import convert_referral
factory = _build_session_factory(tmp_path)
factory = db_factory
async def _seed():
from app.db import utcnow
async with factory() as s:
# Referrer with FK SET NULL — we don't delete the row, we
@ -235,9 +193,6 @@ def test_deleted_referrer_does_not_crash(tmp_path):
created_at=utcnow()))
await s.commit()
asyncio.run(_seed())
async def _run():
async with factory() as s:
referred = await s.get(User, 2)
ref = await convert_referral(s, referred)
@ -247,8 +202,6 @@ def test_deleted_referrer_does_not_crash(tmp_path):
# Referred still got their credit even though referrer is gone.
assert referred.credit_until is not None
asyncio.run(_run())
# ---------------------------------------------------------------------------
# Stripe-webhook integration

View file

@ -9,34 +9,13 @@ from unittest.mock import AsyncMock
import pytest
def _build_session_factory(tmp_path):
"""Spin up a fresh in-memory schema and return (engine, factory, setup).
Mirrors tests/test_llm_csv_parser.py / tests/test_referral_conversion.py."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod
from app.db import Base
import app.models # noqa: F401
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/tv.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async def _setup():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
return engine, factory, _setup
async def test_validate_happy_path(tmp_path, monkeypatch):
async def test_validate_happy_path(db_factory, monkeypatch):
from app.routers.ticker_validate import validate_ticker
from app.services.market import Quote
import app.routers.ticker_validate as mod
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
# Mock fetch_yahoo to return a successful quote.
async def _fake_yahoo(client, symbol, label, note, anchor=None):
@ -61,13 +40,12 @@ async def test_validate_happy_path(tmp_path, monkeypatch):
assert result["as_of"] == "2026-05-27"
async def test_validate_unknown_symbol(tmp_path, monkeypatch):
async def test_validate_unknown_symbol(db_factory, monkeypatch):
from app.routers.ticker_validate import validate_ticker
from app.services.market import Quote
import app.routers.ticker_validate as mod
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
# Mock fetch_yahoo to return a Quote with error and no price.
async def _fake_yahoo(client, symbol, label, note, anchor=None):
@ -92,7 +70,7 @@ async def test_validate_empty_symbol_rejects():
assert "required" in result["error"].lower()
async def test_validate_seeds_universe_and_quote(tmp_path, monkeypatch):
async def test_validate_seeds_universe_and_quote(db_factory, monkeypatch):
"""Side-effect check: on success, the symbol is upserted into the
universe and a Quote row is written."""
from sqlalchemy import select
@ -102,8 +80,7 @@ async def test_validate_seeds_universe_and_quote(tmp_path, monkeypatch):
from app.services.market import Quote
import app.routers.ticker_validate as mod
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
upsert_calls: list[list[str]] = []