tests: extract _build_session_factory to a shared conftest fixture
The same per-test sqlite-engine setup was duplicated across 14 test files (~30 lines each). Consolidated into a single async fixture `db_factory` in tests/conftest.py; tests now take db_factory as a parameter and use `async with db_factory() as session` directly. No behaviour change — same function-scope, same in-memory schema created via Base.metadata.create_all, same app.db._engine / _session_factory rebinding so module-level helpers see the test engine. Just ~420 lines of boilerplate removed.
This commit is contained in:
parent
b13caa4c51
commit
dcc2c07111
5 changed files with 167 additions and 250 deletions
|
|
@ -4,26 +4,6 @@ from __future__ import annotations
|
|||
import pytest
|
||||
|
||||
|
||||
def _build_session_factory(tmp_path):
|
||||
"""Spin up a fresh in-memory schema and return (engine, factory).
|
||||
Matches the pattern used in tests/test_referral_conversion.py."""
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from app import db as db_mod
|
||||
from app.db import Base
|
||||
import app.models # noqa: F401 — registers models on Base.metadata
|
||||
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/csv.db")
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
db_mod._engine = engine
|
||||
db_mod._session_factory = factory
|
||||
|
||||
async def _setup():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
return engine, factory, _setup
|
||||
|
||||
|
||||
def test_csv_format_template_model_columns():
|
||||
"""Model exposes every column the spec requires, with correct types."""
|
||||
|
|
@ -310,7 +290,7 @@ async def test_extract_mapping_via_llm_provider_failure_wraps():
|
|||
await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]])
|
||||
|
||||
|
||||
async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
|
||||
async def test_parse_with_llm_cache_miss_inserts_template(db_factory):
|
||||
from unittest.mock import AsyncMock
|
||||
from sqlalchemy import select
|
||||
|
||||
|
|
@ -318,8 +298,7 @@ async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
|
|||
from app.services.llm_csv_parser import parse_with_llm
|
||||
from app.services.openrouter import LogResult
|
||||
|
||||
_, factory, setup = _build_session_factory(tmp_path)
|
||||
await setup()
|
||||
factory = db_factory
|
||||
|
||||
raw = (
|
||||
b"Symbol,Quantity,Avg Price,Currency\n"
|
||||
|
|
@ -356,7 +335,7 @@ async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
|
|||
assert not hasattr(tmpl, "user_id"), "sample row must not be linked to a user"
|
||||
|
||||
|
||||
async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
|
||||
async def test_parse_with_llm_cache_hit_skips_llm(db_factory):
|
||||
from unittest.mock import AsyncMock
|
||||
from sqlalchemy import select
|
||||
|
||||
|
|
@ -364,8 +343,7 @@ async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
|
|||
from app.models import CsvFormatTemplate
|
||||
from app.services.llm_csv_parser import _fingerprint, parse_with_llm
|
||||
|
||||
_, factory, setup = _build_session_factory(tmp_path)
|
||||
await setup()
|
||||
factory = db_factory
|
||||
|
||||
headers = ["Symbol", "Quantity", "Avg Price", "Currency"]
|
||||
fp = _fingerprint(headers)
|
||||
|
|
@ -411,7 +389,7 @@ async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
|
|||
assert rows[0].use_count == 2
|
||||
|
||||
|
||||
async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
|
||||
async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(db_factory):
|
||||
from unittest.mock import AsyncMock
|
||||
from sqlalchemy import select
|
||||
|
||||
|
|
@ -419,8 +397,7 @@ async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
|
|||
from app.models import CsvFormatTemplate
|
||||
from app.services.llm_csv_parser import LLMParseError, _fingerprint, parse_with_llm
|
||||
|
||||
_, factory, setup = _build_session_factory(tmp_path)
|
||||
await setup()
|
||||
factory = db_factory
|
||||
|
||||
headers = ["Symbol", "Quantity"]
|
||||
fp = _fingerprint(headers)
|
||||
|
|
@ -452,7 +429,7 @@ async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
|
|||
assert len(rows) == 1
|
||||
|
||||
|
||||
async def test_parse_portfolio_route_falls_through_to_llm(tmp_path, monkeypatch):
|
||||
async def test_parse_portfolio_route_falls_through_to_llm(db_factory, monkeypatch):
|
||||
"""End-to-end: T212 parser raises CSVImportError, LLM fallback runs,
|
||||
response shape matches the existing JSON contract."""
|
||||
from io import BytesIO
|
||||
|
|
@ -461,8 +438,7 @@ async def test_parse_portfolio_route_falls_through_to_llm(tmp_path, monkeypatch)
|
|||
|
||||
from fastapi import UploadFile
|
||||
|
||||
_, factory, setup = _build_session_factory(tmp_path)
|
||||
await setup()
|
||||
factory = db_factory
|
||||
|
||||
import app.services.llm_csv_parser as mod
|
||||
from app.services.openrouter import LogResult
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue