tests: extract _build_session_factory to a shared conftest fixture

The same per-test sqlite-engine setup was duplicated across 14 test
files (~30 lines each). Consolidated into a single async fixture
`db_factory` in tests/conftest.py; tests now take db_factory as a
parameter and use `async with db_factory() as session` directly.

No behaviour change — same function-scope, same in-memory schema
created via Base.metadata.create_all, same app.db._engine /
_session_factory rebinding so module-level helpers see the test
engine. Just ~420 lines of boilerplate removed.
This commit is contained in:
Giorgio Gilestro 2026-05-27 20:50:09 +02:00
parent b13caa4c51
commit dcc2c07111
5 changed files with 167 additions and 250 deletions

View file

@ -4,26 +4,6 @@ from __future__ import annotations
import pytest
def _build_session_factory(tmp_path):
"""Spin up a fresh in-memory schema and return (engine, factory).
Matches the pattern used in tests/test_referral_conversion.py."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod
from app.db import Base
import app.models # noqa: F401 — registers models on Base.metadata
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/csv.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async def _setup():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
return engine, factory, _setup
def test_csv_format_template_model_columns():
"""Model exposes every column the spec requires, with correct types."""
@ -310,7 +290,7 @@ async def test_extract_mapping_via_llm_provider_failure_wraps():
await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]])
async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
async def test_parse_with_llm_cache_miss_inserts_template(db_factory):
from unittest.mock import AsyncMock
from sqlalchemy import select
@ -318,8 +298,7 @@ async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
from app.services.llm_csv_parser import parse_with_llm
from app.services.openrouter import LogResult
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
raw = (
b"Symbol,Quantity,Avg Price,Currency\n"
@ -356,7 +335,7 @@ async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
assert not hasattr(tmpl, "user_id"), "sample row must not be linked to a user"
async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
async def test_parse_with_llm_cache_hit_skips_llm(db_factory):
from unittest.mock import AsyncMock
from sqlalchemy import select
@ -364,8 +343,7 @@ async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
from app.models import CsvFormatTemplate
from app.services.llm_csv_parser import _fingerprint, parse_with_llm
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
headers = ["Symbol", "Quantity", "Avg Price", "Currency"]
fp = _fingerprint(headers)
@ -411,7 +389,7 @@ async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
assert rows[0].use_count == 2
async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(db_factory):
from unittest.mock import AsyncMock
from sqlalchemy import select
@ -419,8 +397,7 @@ async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
from app.models import CsvFormatTemplate
from app.services.llm_csv_parser import LLMParseError, _fingerprint, parse_with_llm
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
headers = ["Symbol", "Quantity"]
fp = _fingerprint(headers)
@ -452,7 +429,7 @@ async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
assert len(rows) == 1
async def test_parse_portfolio_route_falls_through_to_llm(tmp_path, monkeypatch):
async def test_parse_portfolio_route_falls_through_to_llm(db_factory, monkeypatch):
"""End-to-end: T212 parser raises CSVImportError, LLM fallback runs,
response shape matches the existing JSON contract."""
from io import BytesIO
@ -461,8 +438,7 @@ async def test_parse_portfolio_route_falls_through_to_llm(tmp_path, monkeypatch)
from fastapi import UploadFile
_, factory, setup = _build_session_factory(tmp_path)
await setup()
factory = db_factory
import app.services.llm_csv_parser as mod
from app.services.openrouter import LogResult