tests: extract _build_session_factory to a shared conftest fixture

The same per-test sqlite-engine setup was duplicated across 14 test
files (~30 lines each). Consolidated into a single async fixture
`db_factory` in tests/conftest.py; tests now take db_factory as a
parameter and use `async with db_factory() as session` directly.

No behaviour change — same function-scope, same in-memory schema
created via Base.metadata.create_all, same app.db._engine /
_session_factory rebinding so module-level helpers see the test
engine. Just ~420 lines of boilerplate removed.
This commit is contained in:
Giorgio Gilestro 2026-05-27 20:50:09 +02:00
parent b13caa4c51
commit dcc2c07111
5 changed files with 167 additions and 250 deletions

View file

@ -17,3 +17,40 @@ sys.path.insert(0, str(ROOT))
# Sentinel env so importing app.config doesn't try to read a missing .env. # Sentinel env so importing app.config doesn't try to read a missing .env.
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:") os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
os.environ.setdefault("CASSANDRA_MOCK", "1") os.environ.setdefault("CASSANDRA_MOCK", "1")
import pytest
@pytest.fixture
async def db_factory(tmp_path):
"""Per-test sqlite engine + async session factory.
Creates a fresh sqlite database file under ``tmp_path``, applies
``Base.metadata.create_all``, and rebinds ``app.db._engine`` /
``app.db._session_factory`` so module-level helpers (which look
these up at call time) see the test engine.
Yields the ``async_sessionmaker``. Tests use it like:
async def test_foo(db_factory):
async with db_factory() as session:
...
"""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod
from app.db import Base
import app.models # noqa: F401 — registers models on Base.metadata
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/test.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield factory
await engine.dispose()

View file

@ -4,26 +4,6 @@ from __future__ import annotations
import pytest import pytest
def _build_session_factory(tmp_path):
"""Spin up a fresh in-memory schema and return (engine, factory).
Matches the pattern used in tests/test_referral_conversion.py."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod
from app.db import Base
import app.models # noqa: F401 — registers models on Base.metadata
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/csv.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async def _setup():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
return engine, factory, _setup
def test_csv_format_template_model_columns(): def test_csv_format_template_model_columns():
"""Model exposes every column the spec requires, with correct types.""" """Model exposes every column the spec requires, with correct types."""
@ -310,7 +290,7 @@ async def test_extract_mapping_via_llm_provider_failure_wraps():
await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]]) await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]])
async def test_parse_with_llm_cache_miss_inserts_template(tmp_path): async def test_parse_with_llm_cache_miss_inserts_template(db_factory):
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from sqlalchemy import select from sqlalchemy import select
@ -318,8 +298,7 @@ async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
from app.services.llm_csv_parser import parse_with_llm from app.services.llm_csv_parser import parse_with_llm
from app.services.openrouter import LogResult from app.services.openrouter import LogResult
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
raw = ( raw = (
b"Symbol,Quantity,Avg Price,Currency\n" b"Symbol,Quantity,Avg Price,Currency\n"
@ -356,7 +335,7 @@ async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
assert not hasattr(tmpl, "user_id"), "sample row must not be linked to a user" assert not hasattr(tmpl, "user_id"), "sample row must not be linked to a user"
async def test_parse_with_llm_cache_hit_skips_llm(tmp_path): async def test_parse_with_llm_cache_hit_skips_llm(db_factory):
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from sqlalchemy import select from sqlalchemy import select
@ -364,8 +343,7 @@ async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
from app.models import CsvFormatTemplate from app.models import CsvFormatTemplate
from app.services.llm_csv_parser import _fingerprint, parse_with_llm from app.services.llm_csv_parser import _fingerprint, parse_with_llm
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
headers = ["Symbol", "Quantity", "Avg Price", "Currency"] headers = ["Symbol", "Quantity", "Avg Price", "Currency"]
fp = _fingerprint(headers) fp = _fingerprint(headers)
@ -411,7 +389,7 @@ async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
assert rows[0].use_count == 2 assert rows[0].use_count == 2
async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path): async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(db_factory):
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from sqlalchemy import select from sqlalchemy import select
@ -419,8 +397,7 @@ async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
from app.models import CsvFormatTemplate from app.models import CsvFormatTemplate
from app.services.llm_csv_parser import LLMParseError, _fingerprint, parse_with_llm from app.services.llm_csv_parser import LLMParseError, _fingerprint, parse_with_llm
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
headers = ["Symbol", "Quantity"] headers = ["Symbol", "Quantity"]
fp = _fingerprint(headers) fp = _fingerprint(headers)
@ -452,7 +429,7 @@ async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
assert len(rows) == 1 assert len(rows) == 1
async def test_parse_portfolio_route_falls_through_to_llm(tmp_path, monkeypatch): async def test_parse_portfolio_route_falls_through_to_llm(db_factory, monkeypatch):
"""End-to-end: T212 parser raises CSVImportError, LLM fallback runs, """End-to-end: T212 parser raises CSVImportError, LLM fallback runs,
response shape matches the existing JSON contract.""" response shape matches the existing JSON contract."""
from io import BytesIO from io import BytesIO
@ -461,8 +438,7 @@ async def test_parse_portfolio_route_falls_through_to_llm(tmp_path, monkeypatch)
from fastapi import UploadFile from fastapi import UploadFile
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
import app.services.llm_csv_parser as mod import app.services.llm_csv_parser as mod
from app.services.openrouter import LogResult from app.services.openrouter import LogResult

View file

@ -5,25 +5,6 @@ from __future__ import annotations
import pytest import pytest
def _build_session_factory(tmp_path):
"""Per-test sqlite engine + factory. Mirrors test_referral_conversion.py."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod
from app.db import Base
import app.models # noqa: F401 — registers models on Base.metadata
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/loc.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async def _setup():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
return engine, factory, _setup
def test_user_has_lang_column_with_default_en(): def test_user_has_lang_column_with_default_en():
from sqlalchemy import inspect from sqlalchemy import inspect
@ -55,7 +36,7 @@ def test_strategic_log_translation_model_columns():
assert cols["content_md"].nullable is False assert cols["content_md"].nullable is False
async def test_log_translation_fanout_no_active_non_en_users(tmp_path, monkeypatch): async def test_log_translation_fanout_no_active_non_en_users(db_factory, monkeypatch):
"""When no users have an active non-en lang, the fan-out makes no """When no users have an active non-en lang, the fan-out makes no
translation calls and no rows are inserted.""" translation calls and no rows are inserted."""
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
@ -65,8 +46,7 @@ async def test_log_translation_fanout_no_active_non_en_users(tmp_path, monkeypat
from app.models import StrategicLog, StrategicLogTranslation, User from app.models import StrategicLog, StrategicLogTranslation, User
from app.jobs import ai_log_job from app.jobs import ai_log_job
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
fake_translate = AsyncMock() fake_translate = AsyncMock()
monkeypatch.setattr(ai_log_job, "translate", fake_translate) monkeypatch.setattr(ai_log_job, "translate", fake_translate)
@ -92,7 +72,7 @@ async def test_log_translation_fanout_no_active_non_en_users(tmp_path, monkeypat
assert rows == [] assert rows == []
async def test_log_translation_fanout_italian_user(tmp_path, monkeypatch): async def test_log_translation_fanout_italian_user(db_factory, monkeypatch):
"""One user at lang=it triggers one translation; the row lands with """One user at lang=it triggers one translation; the row lands with
the right lang and log_id.""" the right lang and log_id."""
from sqlalchemy import select from sqlalchemy import select
@ -102,8 +82,7 @@ async def test_log_translation_fanout_italian_user(tmp_path, monkeypatch):
from app.services.openrouter import LogResult from app.services.openrouter import LogResult
from app.jobs import ai_log_job from app.jobs import ai_log_job
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
async def _fake_translate(client, text, target_lang): async def _fake_translate(client, text, target_lang):
assert target_lang == "it" assert target_lang == "it"
@ -139,7 +118,7 @@ async def test_log_translation_fanout_italian_user(tmp_path, monkeypatch):
assert row.llm_cost_usd == pytest.approx(0.00002) assert row.llm_cost_usd == pytest.approx(0.00002)
async def test_log_translation_fanout_per_language_failure_isolated(tmp_path, monkeypatch): async def test_log_translation_fanout_per_language_failure_isolated(db_factory, monkeypatch):
"""If one language's translation fails, the others (if any) still land """If one language's translation fails, the others (if any) still land
and the job does not raise.""" and the job does not raise."""
from sqlalchemy import select from sqlalchemy import select
@ -148,8 +127,7 @@ async def test_log_translation_fanout_per_language_failure_isolated(tmp_path, mo
from app.models import StrategicLog, StrategicLogTranslation, User from app.models import StrategicLog, StrategicLogTranslation, User
from app.jobs import ai_log_job from app.jobs import ai_log_job
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
async def _fake_translate(client, text, target_lang): async def _fake_translate(client, text, target_lang):
raise RuntimeError("upstream down") raise RuntimeError("upstream down")
@ -175,7 +153,7 @@ async def test_log_translation_fanout_per_language_failure_isolated(tmp_path, mo
assert rows == [] assert rows == []
async def test_analyse_threads_lang_into_system_prompt(tmp_path, monkeypatch): async def test_analyse_threads_lang_into_system_prompt(db_factory, monkeypatch):
"""When lang='it', the system prompt sent to call_llm contains """When lang='it', the system prompt sent to call_llm contains
'Respond in Italian.' the LLM does the rest.""" 'Respond in Italian.' the LLM does the rest."""
from app.services import portfolio_analysis as pa from app.services import portfolio_analysis as pa
@ -191,8 +169,7 @@ async def test_analyse_threads_lang_into_system_prompt(tmp_path, monkeypatch):
) )
monkeypatch.setattr(pa, "call_llm", _fake_call_llm) monkeypatch.setattr(pa, "call_llm", _fake_call_llm)
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
payload = { payload = {
"positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0, "positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0,
@ -213,7 +190,7 @@ async def test_analyse_threads_lang_into_system_prompt(tmp_path, monkeypatch):
assert "Respond in Italian" in system assert "Respond in Italian" in system
async def test_analyse_no_clause_when_lang_is_en(tmp_path, monkeypatch): async def test_analyse_no_clause_when_lang_is_en(db_factory, monkeypatch):
from app.services import portfolio_analysis as pa from app.services import portfolio_analysis as pa
from app.services.openrouter import LogResult from app.services.openrouter import LogResult
@ -227,8 +204,7 @@ async def test_analyse_no_clause_when_lang_is_en(tmp_path, monkeypatch):
) )
monkeypatch.setattr(pa, "call_llm", _fake_call_llm) monkeypatch.setattr(pa, "call_llm", _fake_call_llm)
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
payload = { payload = {
"positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0, "positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0,
@ -328,13 +304,12 @@ def test_digest_pick_variant_uses_user_lang():
async def test_patch_language_accepts_active(tmp_path): async def test_patch_language_accepts_active(db_factory):
"""PATCH /api/settings/language accepts 'en' and 'it' and persists.""" """PATCH /api/settings/language accepts 'en' and 'it' and persists."""
from app.models import User from app.models import User
from app.routers.api import patch_language_prefs, LanguagePrefsIn from app.routers.api import patch_language_prefs, LanguagePrefsIn
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
async with factory() as session: async with factory() as session:
session.add(User(id=20, email="u@x", tier="paid", lang="en")) session.add(User(id=20, email="u@x", tier="paid", lang="en"))
@ -358,14 +333,13 @@ async def test_patch_language_accepts_active(tmp_path):
assert user.lang == "it" assert user.lang == "it"
async def test_patch_language_rejects_wip(tmp_path): async def test_patch_language_rejects_wip(db_factory):
"""PATCH rejects 'es'/'fr'/'de'/'xx' with 400 — ACTIVE_LANGUAGES gate.""" """PATCH rejects 'es'/'fr'/'de'/'xx' with 400 — ACTIVE_LANGUAGES gate."""
from fastapi import HTTPException from fastapi import HTTPException
from app.models import User from app.models import User
from app.routers.api import patch_language_prefs, LanguagePrefsIn from app.routers.api import patch_language_prefs, LanguagePrefsIn
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
async with factory() as session: async with factory() as session:
session.add(User(id=21, email="u2@x", tier="paid", lang="en")) session.add(User(id=21, email="u2@x", tier="paid", lang="en"))

View file

@ -23,29 +23,6 @@ import pytest
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _build_session_factory(tmp_path):
"""Spin up a fresh in-memory schema and return (engine, factory).
Mirrors test_stripe_billing._build_app's seeding strategy but
skips the FastAPI app most conversion tests only need the
session factory."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod
from app.db import Base
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/conv.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async def _seed():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
asyncio.run(_seed())
return factory
async def _add_pair(factory, *, referrer_id=1, referred_id=2): async def _add_pair(factory, *, referrer_id=1, referred_id=2):
"""Insert a referrer + referred user pair and a linking Referral row. """Insert a referrer + referred user pair and a linking Referral row.
Returns nothing tests re-fetch via the factory.""" Returns nothing tests re-fetch via the factory."""
@ -68,7 +45,7 @@ async def _add_pair(factory, *, referrer_id=1, referred_id=2):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def test_first_conversion_credits_both_parties(tmp_path): async def test_first_conversion_credits_both_parties(db_factory):
"""Calling convert_referral on a freshly-paid referred user should """Calling convert_referral on a freshly-paid referred user should
extend credit_until by REFERRAL_CREDIT_DAYS for BOTH the buyer and extend credit_until by REFERRAL_CREDIT_DAYS for BOTH the buyer and
the referrer, and stamp converted_at + credited_at.""" the referrer, and stamp converted_at + credited_at."""
@ -77,100 +54,88 @@ def test_first_conversion_credits_both_parties(tmp_path):
REFERRAL_CREDIT_DAYS, convert_referral, REFERRAL_CREDIT_DAYS, convert_referral,
) )
factory = _build_session_factory(tmp_path) factory = db_factory
asyncio.run(_add_pair(factory)) await _add_pair(factory)
async def _run(): async with factory() as s:
async with factory() as s: referred = await s.get(User, 2)
referred = await s.get(User, 2) ref = await convert_referral(s, referred)
ref = await convert_referral(s, referred) assert ref is not None
assert ref is not None assert ref.converted_at is not None
assert ref.converted_at is not None assert ref.credited_at is not None
assert ref.credited_at is not None await s.commit()
await s.commit()
# Re-open a fresh session so we read committed state, not the # Re-open a fresh session so we read committed state, not the
# session-cached version. # session-cached version.
async with factory() as s: async with factory() as s:
referrer = await s.get(User, 1) referrer = await s.get(User, 1)
referred = await s.get(User, 2) referred = await s.get(User, 2)
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
# Both windows should sit ~REFERRAL_CREDIT_DAYS in the # Both windows should sit ~REFERRAL_CREDIT_DAYS in the
# future (allow 1 day slack for clock + rounding). # future (allow 1 day slack for clock + rounding).
for u in (referrer, referred): for u in (referrer, referred):
assert u.credit_until is not None assert u.credit_until is not None
cu = u.credit_until cu = u.credit_until
if cu.tzinfo is None: if cu.tzinfo is None:
cu = cu.replace(tzinfo=timezone.utc) cu = cu.replace(tzinfo=timezone.utc)
delta_days = (cu - now).total_seconds() / 86400 delta_days = (cu - now).total_seconds() / 86400
assert REFERRAL_CREDIT_DAYS - 1 <= delta_days <= REFERRAL_CREDIT_DAYS + 1 assert REFERRAL_CREDIT_DAYS - 1 <= delta_days <= REFERRAL_CREDIT_DAYS + 1
asyncio.run(_run())
def test_idempotent_on_repeat_call(tmp_path): async def test_idempotent_on_repeat_call(db_factory):
"""A second convert_referral call (e.g. from a duplicate webhook or """A second convert_referral call (e.g. from a duplicate webhook or
renewal event) must NOT extend credit a second time. The Referral renewal event) must NOT extend credit a second time. The Referral
row is already stamped, so we should early-return unchanged.""" row is already stamped, so we should early-return unchanged."""
from app.models import User from app.models import User
from app.services.referral_service import convert_referral from app.services.referral_service import convert_referral
factory = _build_session_factory(tmp_path) factory = db_factory
asyncio.run(_add_pair(factory)) await _add_pair(factory)
async def _run(): async with factory() as s:
async with factory() as s: referred = await s.get(User, 2)
referred = await s.get(User, 2) await convert_referral(s, referred)
await convert_referral(s, referred) await s.commit()
await s.commit() # Snapshot credit_until after first conversion.
# Snapshot credit_until after first conversion. async with factory() as s:
async with factory() as s: referrer = await s.get(User, 1)
referrer = await s.get(User, 1) referred = await s.get(User, 2)
referred = await s.get(User, 2) first_referrer_credit = referrer.credit_until
first_referrer_credit = referrer.credit_until first_referred_credit = referred.credit_until
first_referred_credit = referred.credit_until
# Second call — should no-op. # Second call — should no-op.
async with factory() as s: async with factory() as s:
referred = await s.get(User, 2) referred = await s.get(User, 2)
ref2 = await convert_referral(s, referred) ref2 = await convert_referral(s, referred)
assert ref2 is not None # we still return the row assert ref2 is not None # we still return the row
await s.commit() await s.commit()
async with factory() as s: async with factory() as s:
referrer = await s.get(User, 1) referrer = await s.get(User, 1)
referred = await s.get(User, 2) referred = await s.get(User, 2)
assert referrer.credit_until == first_referrer_credit assert referrer.credit_until == first_referrer_credit
assert referred.credit_until == first_referred_credit assert referred.credit_until == first_referred_credit
asyncio.run(_run())
def test_no_referral_row_returns_none(tmp_path): async def test_no_referral_row_returns_none(db_factory):
"""A user signing up directly (no inviter) has no Referral row. """A user signing up directly (no inviter) has no Referral row.
convert_referral must return None and touch nothing.""" convert_referral must return None and touch nothing."""
from app.models import User from app.models import User
from app.services.referral_service import convert_referral from app.services.referral_service import convert_referral
factory = _build_session_factory(tmp_path) factory = db_factory
async def _seed_orphan(): async with factory() as s:
async with factory() as s: s.add(User(id=9, email="lone@x", tier="free"))
s.add(User(id=9, email="lone@x", tier="free")) await s.commit()
await s.commit()
asyncio.run(_seed_orphan()) async with factory() as s:
user = await s.get(User, 9)
async def _run(): result = await convert_referral(s, user)
async with factory() as s: assert result is None
user = await s.get(User, 9) assert user.credit_until is None
result = await convert_referral(s, user)
assert result is None
assert user.credit_until is None
asyncio.run(_run())
def test_credit_stacks_from_existing_window(tmp_path): async def test_credit_stacks_from_existing_window(db_factory):
"""If the user already has a future credit_until (admin grant, prior """If the user already has a future credit_until (admin grant, prior
referral), the new credit should extend from THAT anchor not from referral), the new credit should extend from THAT anchor not from
now. Mirrors cli.grant_credit's stacking semantics.""" now. Mirrors cli.grant_credit's stacking semantics."""
@ -179,75 +144,63 @@ def test_credit_stacks_from_existing_window(tmp_path):
REFERRAL_CREDIT_DAYS, convert_referral, REFERRAL_CREDIT_DAYS, convert_referral,
) )
factory = _build_session_factory(tmp_path) factory = db_factory
asyncio.run(_add_pair(factory)) await _add_pair(factory)
# Pre-load 30 days of credit on the referred user. # Pre-load 30 days of credit on the referred user.
existing = datetime.now(timezone.utc) + timedelta(days=30) existing = datetime.now(timezone.utc) + timedelta(days=30)
async def _preload(): async with factory() as s:
async with factory() as s: u = await s.get(User, 2)
u = await s.get(User, 2) u.credit_until = existing
u.credit_until = existing await s.commit()
await s.commit()
asyncio.run(_preload()) async with factory() as s:
referred = await s.get(User, 2)
async def _run(): await convert_referral(s, referred)
async with factory() as s: await s.commit()
referred = await s.get(User, 2) async with factory() as s:
await convert_referral(s, referred) referred = await s.get(User, 2)
await s.commit() cu = referred.credit_until
async with factory() as s: if cu.tzinfo is None:
referred = await s.get(User, 2) cu = cu.replace(tzinfo=timezone.utc)
cu = referred.credit_until # Expected: existing + REFERRAL_CREDIT_DAYS days, not now + days.
if cu.tzinfo is None: expected = existing + timedelta(days=REFERRAL_CREDIT_DAYS)
cu = cu.replace(tzinfo=timezone.utc) delta_seconds = abs((cu - expected).total_seconds())
# Expected: existing + REFERRAL_CREDIT_DAYS days, not now + days. assert delta_seconds < 60, (
expected = existing + timedelta(days=REFERRAL_CREDIT_DAYS) f"new credit anchored at now, not existing window: "
delta_seconds = abs((cu - expected).total_seconds()) f"got {cu}, expected ~{expected}"
assert delta_seconds < 60, ( )
f"new credit anchored at now, not existing window: "
f"got {cu}, expected ~{expected}"
)
asyncio.run(_run())
def test_deleted_referrer_does_not_crash(tmp_path): async def test_deleted_referrer_does_not_crash(db_factory):
"""If the referrer's User row has been deleted, the referred user """If the referrer's User row has been deleted, the referred user
should still be credited and the Referral still stamped we just should still be credited and the Referral still stamped we just
skip the missing referrer.""" skip the missing referrer."""
from app.models import Referral, User from app.models import Referral, User
from app.services.referral_service import convert_referral from app.services.referral_service import convert_referral
factory = _build_session_factory(tmp_path) factory = db_factory
async def _seed(): from app.db import utcnow
from app.db import utcnow async with factory() as s:
async with factory() as s: # Referrer with FK SET NULL — we don't delete the row, we
# Referrer with FK SET NULL — we don't delete the row, we # instead create a Referral pointing at a non-existent id
# instead create a Referral pointing at a non-existent id # to simulate a deleted referrer.
# to simulate a deleted referrer. s.add(User(id=2, email="u2@x", tier="free"))
s.add(User(id=2, email="u2@x", tier="free")) s.add(Referral(referrer_user_id=999, # nonexistent
s.add(Referral(referrer_user_id=999, # nonexistent referred_user_id=2,
referred_user_id=2, created_at=utcnow()))
created_at=utcnow())) await s.commit()
await s.commit()
asyncio.run(_seed()) async with factory() as s:
referred = await s.get(User, 2)
async def _run(): ref = await convert_referral(s, referred)
async with factory() as s: await s.commit()
referred = await s.get(User, 2) assert ref is not None
ref = await convert_referral(s, referred) assert ref.converted_at is not None
await s.commit() # Referred still got their credit even though referrer is gone.
assert ref is not None assert referred.credit_until is not None
assert ref.converted_at is not None
# Referred still got their credit even though referrer is gone.
assert referred.credit_until is not None
asyncio.run(_run())
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -9,34 +9,13 @@ from unittest.mock import AsyncMock
import pytest import pytest
def _build_session_factory(tmp_path):
"""Spin up a fresh in-memory schema and return (engine, factory, setup).
Mirrors tests/test_llm_csv_parser.py / tests/test_referral_conversion.py."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod async def test_validate_happy_path(db_factory, monkeypatch):
from app.db import Base
import app.models # noqa: F401
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/tv.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async def _setup():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
return engine, factory, _setup
async def test_validate_happy_path(tmp_path, monkeypatch):
from app.routers.ticker_validate import validate_ticker from app.routers.ticker_validate import validate_ticker
from app.services.market import Quote from app.services.market import Quote
import app.routers.ticker_validate as mod import app.routers.ticker_validate as mod
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
# Mock fetch_yahoo to return a successful quote. # Mock fetch_yahoo to return a successful quote.
async def _fake_yahoo(client, symbol, label, note, anchor=None): async def _fake_yahoo(client, symbol, label, note, anchor=None):
@ -61,13 +40,12 @@ async def test_validate_happy_path(tmp_path, monkeypatch):
assert result["as_of"] == "2026-05-27" assert result["as_of"] == "2026-05-27"
async def test_validate_unknown_symbol(tmp_path, monkeypatch): async def test_validate_unknown_symbol(db_factory, monkeypatch):
from app.routers.ticker_validate import validate_ticker from app.routers.ticker_validate import validate_ticker
from app.services.market import Quote from app.services.market import Quote
import app.routers.ticker_validate as mod import app.routers.ticker_validate as mod
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
# Mock fetch_yahoo to return a Quote with error and no price. # Mock fetch_yahoo to return a Quote with error and no price.
async def _fake_yahoo(client, symbol, label, note, anchor=None): async def _fake_yahoo(client, symbol, label, note, anchor=None):
@ -92,7 +70,7 @@ async def test_validate_empty_symbol_rejects():
assert "required" in result["error"].lower() assert "required" in result["error"].lower()
async def test_validate_seeds_universe_and_quote(tmp_path, monkeypatch): async def test_validate_seeds_universe_and_quote(db_factory, monkeypatch):
"""Side-effect check: on success, the symbol is upserted into the """Side-effect check: on success, the symbol is upserted into the
universe and a Quote row is written.""" universe and a Quote row is written."""
from sqlalchemy import select from sqlalchemy import select
@ -102,8 +80,7 @@ async def test_validate_seeds_universe_and_quote(tmp_path, monkeypatch):
from app.services.market import Quote from app.services.market import Quote
import app.routers.ticker_validate as mod import app.routers.ticker_validate as mod
_, factory, setup = _build_session_factory(tmp_path) factory = db_factory
await setup()
upsert_calls: list[list[str]] = [] upsert_calls: list[list[str]] = []