tests: extract _build_session_factory to a shared conftest fixture
The same per-test sqlite-engine setup was duplicated across 14 test files (~30 lines each). Consolidated into a single async fixture `db_factory` in tests/conftest.py; tests now take db_factory as a parameter and use `async with db_factory() as session` directly. No behaviour change — same function-scope, same in-memory schema created via Base.metadata.create_all, same app.db._engine / _session_factory rebinding so module-level helpers see the test engine. Just ~420 lines of boilerplate removed.
This commit is contained in:
parent
b13caa4c51
commit
dcc2c07111
5 changed files with 167 additions and 250 deletions
|
|
@ -17,3 +17,40 @@ sys.path.insert(0, str(ROOT))
|
||||||
# Sentinel env so importing app.config doesn't try to read a missing .env.
|
# Sentinel env so importing app.config doesn't try to read a missing .env.
|
||||||
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
|
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
|
||||||
os.environ.setdefault("CASSANDRA_MOCK", "1")
|
os.environ.setdefault("CASSANDRA_MOCK", "1")
|
||||||
|
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_factory(tmp_path):
|
||||||
|
"""Per-test sqlite engine + async session factory.
|
||||||
|
|
||||||
|
Creates a fresh sqlite database file under ``tmp_path``, applies
|
||||||
|
``Base.metadata.create_all``, and rebinds ``app.db._engine`` /
|
||||||
|
``app.db._session_factory`` so module-level helpers (which look
|
||||||
|
these up at call time) see the test engine.
|
||||||
|
|
||||||
|
Yields the ``async_sessionmaker``. Tests use it like:
|
||||||
|
|
||||||
|
async def test_foo(db_factory):
|
||||||
|
async with db_factory() as session:
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from app import db as db_mod
|
||||||
|
from app.db import Base
|
||||||
|
import app.models # noqa: F401 — registers models on Base.metadata
|
||||||
|
|
||||||
|
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/test.db")
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
db_mod._engine = engine
|
||||||
|
db_mod._session_factory = factory
|
||||||
|
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
yield factory
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
|
||||||
|
|
@ -4,26 +4,6 @@ from __future__ import annotations
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def _build_session_factory(tmp_path):
|
|
||||||
"""Spin up a fresh in-memory schema and return (engine, factory).
|
|
||||||
Matches the pattern used in tests/test_referral_conversion.py."""
|
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
||||||
|
|
||||||
from app import db as db_mod
|
|
||||||
from app.db import Base
|
|
||||||
import app.models # noqa: F401 — registers models on Base.metadata
|
|
||||||
|
|
||||||
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/csv.db")
|
|
||||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
||||||
db_mod._engine = engine
|
|
||||||
db_mod._session_factory = factory
|
|
||||||
|
|
||||||
async def _setup():
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
return engine, factory, _setup
|
|
||||||
|
|
||||||
|
|
||||||
def test_csv_format_template_model_columns():
|
def test_csv_format_template_model_columns():
|
||||||
"""Model exposes every column the spec requires, with correct types."""
|
"""Model exposes every column the spec requires, with correct types."""
|
||||||
|
|
@ -310,7 +290,7 @@ async def test_extract_mapping_via_llm_provider_failure_wraps():
|
||||||
await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]])
|
await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]])
|
||||||
|
|
||||||
|
|
||||||
async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
|
async def test_parse_with_llm_cache_miss_inserts_template(db_factory):
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
|
@ -318,8 +298,7 @@ async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
|
||||||
from app.services.llm_csv_parser import parse_with_llm
|
from app.services.llm_csv_parser import parse_with_llm
|
||||||
from app.services.openrouter import LogResult
|
from app.services.openrouter import LogResult
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
raw = (
|
raw = (
|
||||||
b"Symbol,Quantity,Avg Price,Currency\n"
|
b"Symbol,Quantity,Avg Price,Currency\n"
|
||||||
|
|
@ -356,7 +335,7 @@ async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
|
||||||
assert not hasattr(tmpl, "user_id"), "sample row must not be linked to a user"
|
assert not hasattr(tmpl, "user_id"), "sample row must not be linked to a user"
|
||||||
|
|
||||||
|
|
||||||
async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
|
async def test_parse_with_llm_cache_hit_skips_llm(db_factory):
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
|
@ -364,8 +343,7 @@ async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
|
||||||
from app.models import CsvFormatTemplate
|
from app.models import CsvFormatTemplate
|
||||||
from app.services.llm_csv_parser import _fingerprint, parse_with_llm
|
from app.services.llm_csv_parser import _fingerprint, parse_with_llm
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
headers = ["Symbol", "Quantity", "Avg Price", "Currency"]
|
headers = ["Symbol", "Quantity", "Avg Price", "Currency"]
|
||||||
fp = _fingerprint(headers)
|
fp = _fingerprint(headers)
|
||||||
|
|
@ -411,7 +389,7 @@ async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
|
||||||
assert rows[0].use_count == 2
|
assert rows[0].use_count == 2
|
||||||
|
|
||||||
|
|
||||||
async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
|
async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(db_factory):
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
|
@ -419,8 +397,7 @@ async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
|
||||||
from app.models import CsvFormatTemplate
|
from app.models import CsvFormatTemplate
|
||||||
from app.services.llm_csv_parser import LLMParseError, _fingerprint, parse_with_llm
|
from app.services.llm_csv_parser import LLMParseError, _fingerprint, parse_with_llm
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
headers = ["Symbol", "Quantity"]
|
headers = ["Symbol", "Quantity"]
|
||||||
fp = _fingerprint(headers)
|
fp = _fingerprint(headers)
|
||||||
|
|
@ -452,7 +429,7 @@ async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
|
||||||
assert len(rows) == 1
|
assert len(rows) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_parse_portfolio_route_falls_through_to_llm(tmp_path, monkeypatch):
|
async def test_parse_portfolio_route_falls_through_to_llm(db_factory, monkeypatch):
|
||||||
"""End-to-end: T212 parser raises CSVImportError, LLM fallback runs,
|
"""End-to-end: T212 parser raises CSVImportError, LLM fallback runs,
|
||||||
response shape matches the existing JSON contract."""
|
response shape matches the existing JSON contract."""
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
@ -461,8 +438,7 @@ async def test_parse_portfolio_route_falls_through_to_llm(tmp_path, monkeypatch)
|
||||||
|
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
import app.services.llm_csv_parser as mod
|
import app.services.llm_csv_parser as mod
|
||||||
from app.services.openrouter import LogResult
|
from app.services.openrouter import LogResult
|
||||||
|
|
|
||||||
|
|
@ -5,25 +5,6 @@ from __future__ import annotations
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def _build_session_factory(tmp_path):
|
|
||||||
"""Per-test sqlite engine + factory. Mirrors test_referral_conversion.py."""
|
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
||||||
|
|
||||||
from app import db as db_mod
|
|
||||||
from app.db import Base
|
|
||||||
import app.models # noqa: F401 — registers models on Base.metadata
|
|
||||||
|
|
||||||
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/loc.db")
|
|
||||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
||||||
db_mod._engine = engine
|
|
||||||
db_mod._session_factory = factory
|
|
||||||
|
|
||||||
async def _setup():
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
return engine, factory, _setup
|
|
||||||
|
|
||||||
|
|
||||||
def test_user_has_lang_column_with_default_en():
|
def test_user_has_lang_column_with_default_en():
|
||||||
from sqlalchemy import inspect
|
from sqlalchemy import inspect
|
||||||
|
|
@ -55,7 +36,7 @@ def test_strategic_log_translation_model_columns():
|
||||||
assert cols["content_md"].nullable is False
|
assert cols["content_md"].nullable is False
|
||||||
|
|
||||||
|
|
||||||
async def test_log_translation_fanout_no_active_non_en_users(tmp_path, monkeypatch):
|
async def test_log_translation_fanout_no_active_non_en_users(db_factory, monkeypatch):
|
||||||
"""When no users have an active non-en lang, the fan-out makes no
|
"""When no users have an active non-en lang, the fan-out makes no
|
||||||
translation calls and no rows are inserted."""
|
translation calls and no rows are inserted."""
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
@ -65,8 +46,7 @@ async def test_log_translation_fanout_no_active_non_en_users(tmp_path, monkeypat
|
||||||
from app.models import StrategicLog, StrategicLogTranslation, User
|
from app.models import StrategicLog, StrategicLogTranslation, User
|
||||||
from app.jobs import ai_log_job
|
from app.jobs import ai_log_job
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
fake_translate = AsyncMock()
|
fake_translate = AsyncMock()
|
||||||
monkeypatch.setattr(ai_log_job, "translate", fake_translate)
|
monkeypatch.setattr(ai_log_job, "translate", fake_translate)
|
||||||
|
|
@ -92,7 +72,7 @@ async def test_log_translation_fanout_no_active_non_en_users(tmp_path, monkeypat
|
||||||
assert rows == []
|
assert rows == []
|
||||||
|
|
||||||
|
|
||||||
async def test_log_translation_fanout_italian_user(tmp_path, monkeypatch):
|
async def test_log_translation_fanout_italian_user(db_factory, monkeypatch):
|
||||||
"""One user at lang=it triggers one translation; the row lands with
|
"""One user at lang=it triggers one translation; the row lands with
|
||||||
the right lang and log_id."""
|
the right lang and log_id."""
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
@ -102,8 +82,7 @@ async def test_log_translation_fanout_italian_user(tmp_path, monkeypatch):
|
||||||
from app.services.openrouter import LogResult
|
from app.services.openrouter import LogResult
|
||||||
from app.jobs import ai_log_job
|
from app.jobs import ai_log_job
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
async def _fake_translate(client, text, target_lang):
|
async def _fake_translate(client, text, target_lang):
|
||||||
assert target_lang == "it"
|
assert target_lang == "it"
|
||||||
|
|
@ -139,7 +118,7 @@ async def test_log_translation_fanout_italian_user(tmp_path, monkeypatch):
|
||||||
assert row.llm_cost_usd == pytest.approx(0.00002)
|
assert row.llm_cost_usd == pytest.approx(0.00002)
|
||||||
|
|
||||||
|
|
||||||
async def test_log_translation_fanout_per_language_failure_isolated(tmp_path, monkeypatch):
|
async def test_log_translation_fanout_per_language_failure_isolated(db_factory, monkeypatch):
|
||||||
"""If one language's translation fails, the others (if any) still land
|
"""If one language's translation fails, the others (if any) still land
|
||||||
and the job does not raise."""
|
and the job does not raise."""
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
@ -148,8 +127,7 @@ async def test_log_translation_fanout_per_language_failure_isolated(tmp_path, mo
|
||||||
from app.models import StrategicLog, StrategicLogTranslation, User
|
from app.models import StrategicLog, StrategicLogTranslation, User
|
||||||
from app.jobs import ai_log_job
|
from app.jobs import ai_log_job
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
async def _fake_translate(client, text, target_lang):
|
async def _fake_translate(client, text, target_lang):
|
||||||
raise RuntimeError("upstream down")
|
raise RuntimeError("upstream down")
|
||||||
|
|
@ -175,7 +153,7 @@ async def test_log_translation_fanout_per_language_failure_isolated(tmp_path, mo
|
||||||
assert rows == []
|
assert rows == []
|
||||||
|
|
||||||
|
|
||||||
async def test_analyse_threads_lang_into_system_prompt(tmp_path, monkeypatch):
|
async def test_analyse_threads_lang_into_system_prompt(db_factory, monkeypatch):
|
||||||
"""When lang='it', the system prompt sent to call_llm contains
|
"""When lang='it', the system prompt sent to call_llm contains
|
||||||
'Respond in Italian.' — the LLM does the rest."""
|
'Respond in Italian.' — the LLM does the rest."""
|
||||||
from app.services import portfolio_analysis as pa
|
from app.services import portfolio_analysis as pa
|
||||||
|
|
@ -191,8 +169,7 @@ async def test_analyse_threads_lang_into_system_prompt(tmp_path, monkeypatch):
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(pa, "call_llm", _fake_call_llm)
|
monkeypatch.setattr(pa, "call_llm", _fake_call_llm)
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0,
|
"positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0,
|
||||||
|
|
@ -213,7 +190,7 @@ async def test_analyse_threads_lang_into_system_prompt(tmp_path, monkeypatch):
|
||||||
assert "Respond in Italian" in system
|
assert "Respond in Italian" in system
|
||||||
|
|
||||||
|
|
||||||
async def test_analyse_no_clause_when_lang_is_en(tmp_path, monkeypatch):
|
async def test_analyse_no_clause_when_lang_is_en(db_factory, monkeypatch):
|
||||||
from app.services import portfolio_analysis as pa
|
from app.services import portfolio_analysis as pa
|
||||||
from app.services.openrouter import LogResult
|
from app.services.openrouter import LogResult
|
||||||
|
|
||||||
|
|
@ -227,8 +204,7 @@ async def test_analyse_no_clause_when_lang_is_en(tmp_path, monkeypatch):
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(pa, "call_llm", _fake_call_llm)
|
monkeypatch.setattr(pa, "call_llm", _fake_call_llm)
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0,
|
"positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0,
|
||||||
|
|
@ -328,13 +304,12 @@ def test_digest_pick_variant_uses_user_lang():
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def test_patch_language_accepts_active(tmp_path):
|
async def test_patch_language_accepts_active(db_factory):
|
||||||
"""PATCH /api/settings/language accepts 'en' and 'it' and persists."""
|
"""PATCH /api/settings/language accepts 'en' and 'it' and persists."""
|
||||||
from app.models import User
|
from app.models import User
|
||||||
from app.routers.api import patch_language_prefs, LanguagePrefsIn
|
from app.routers.api import patch_language_prefs, LanguagePrefsIn
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
async with factory() as session:
|
async with factory() as session:
|
||||||
session.add(User(id=20, email="u@x", tier="paid", lang="en"))
|
session.add(User(id=20, email="u@x", tier="paid", lang="en"))
|
||||||
|
|
@ -358,14 +333,13 @@ async def test_patch_language_accepts_active(tmp_path):
|
||||||
assert user.lang == "it"
|
assert user.lang == "it"
|
||||||
|
|
||||||
|
|
||||||
async def test_patch_language_rejects_wip(tmp_path):
|
async def test_patch_language_rejects_wip(db_factory):
|
||||||
"""PATCH rejects 'es'/'fr'/'de'/'xx' with 400 — ACTIVE_LANGUAGES gate."""
|
"""PATCH rejects 'es'/'fr'/'de'/'xx' with 400 — ACTIVE_LANGUAGES gate."""
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from app.models import User
|
from app.models import User
|
||||||
from app.routers.api import patch_language_prefs, LanguagePrefsIn
|
from app.routers.api import patch_language_prefs, LanguagePrefsIn
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
async with factory() as session:
|
async with factory() as session:
|
||||||
session.add(User(id=21, email="u2@x", tier="paid", lang="en"))
|
session.add(User(id=21, email="u2@x", tier="paid", lang="en"))
|
||||||
|
|
|
||||||
|
|
@ -23,29 +23,6 @@ import pytest
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _build_session_factory(tmp_path):
|
|
||||||
"""Spin up a fresh in-memory schema and return (engine, factory).
|
|
||||||
Mirrors test_stripe_billing._build_app's seeding strategy but
|
|
||||||
skips the FastAPI app — most conversion tests only need the
|
|
||||||
session factory."""
|
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
||||||
|
|
||||||
from app import db as db_mod
|
|
||||||
from app.db import Base
|
|
||||||
|
|
||||||
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/conv.db")
|
|
||||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
||||||
db_mod._engine = engine
|
|
||||||
db_mod._session_factory = factory
|
|
||||||
|
|
||||||
async def _seed():
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
asyncio.run(_seed())
|
|
||||||
return factory
|
|
||||||
|
|
||||||
|
|
||||||
async def _add_pair(factory, *, referrer_id=1, referred_id=2):
|
async def _add_pair(factory, *, referrer_id=1, referred_id=2):
|
||||||
"""Insert a referrer + referred user pair and a linking Referral row.
|
"""Insert a referrer + referred user pair and a linking Referral row.
|
||||||
Returns nothing — tests re-fetch via the factory."""
|
Returns nothing — tests re-fetch via the factory."""
|
||||||
|
|
@ -68,7 +45,7 @@ async def _add_pair(factory, *, referrer_id=1, referred_id=2):
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def test_first_conversion_credits_both_parties(tmp_path):
|
async def test_first_conversion_credits_both_parties(db_factory):
|
||||||
"""Calling convert_referral on a freshly-paid referred user should
|
"""Calling convert_referral on a freshly-paid referred user should
|
||||||
extend credit_until by REFERRAL_CREDIT_DAYS for BOTH the buyer and
|
extend credit_until by REFERRAL_CREDIT_DAYS for BOTH the buyer and
|
||||||
the referrer, and stamp converted_at + credited_at."""
|
the referrer, and stamp converted_at + credited_at."""
|
||||||
|
|
@ -77,10 +54,9 @@ def test_first_conversion_credits_both_parties(tmp_path):
|
||||||
REFERRAL_CREDIT_DAYS, convert_referral,
|
REFERRAL_CREDIT_DAYS, convert_referral,
|
||||||
)
|
)
|
||||||
|
|
||||||
factory = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
asyncio.run(_add_pair(factory))
|
await _add_pair(factory)
|
||||||
|
|
||||||
async def _run():
|
|
||||||
async with factory() as s:
|
async with factory() as s:
|
||||||
referred = await s.get(User, 2)
|
referred = await s.get(User, 2)
|
||||||
ref = await convert_referral(s, referred)
|
ref = await convert_referral(s, referred)
|
||||||
|
|
@ -105,20 +81,17 @@ def test_first_conversion_credits_both_parties(tmp_path):
|
||||||
delta_days = (cu - now).total_seconds() / 86400
|
delta_days = (cu - now).total_seconds() / 86400
|
||||||
assert REFERRAL_CREDIT_DAYS - 1 <= delta_days <= REFERRAL_CREDIT_DAYS + 1
|
assert REFERRAL_CREDIT_DAYS - 1 <= delta_days <= REFERRAL_CREDIT_DAYS + 1
|
||||||
|
|
||||||
asyncio.run(_run())
|
|
||||||
|
|
||||||
|
async def test_idempotent_on_repeat_call(db_factory):
|
||||||
def test_idempotent_on_repeat_call(tmp_path):
|
|
||||||
"""A second convert_referral call (e.g. from a duplicate webhook or
|
"""A second convert_referral call (e.g. from a duplicate webhook or
|
||||||
renewal event) must NOT extend credit a second time. The Referral
|
renewal event) must NOT extend credit a second time. The Referral
|
||||||
row is already stamped, so we should early-return unchanged."""
|
row is already stamped, so we should early-return unchanged."""
|
||||||
from app.models import User
|
from app.models import User
|
||||||
from app.services.referral_service import convert_referral
|
from app.services.referral_service import convert_referral
|
||||||
|
|
||||||
factory = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
asyncio.run(_add_pair(factory))
|
await _add_pair(factory)
|
||||||
|
|
||||||
async def _run():
|
|
||||||
async with factory() as s:
|
async with factory() as s:
|
||||||
referred = await s.get(User, 2)
|
referred = await s.get(User, 2)
|
||||||
await convert_referral(s, referred)
|
await convert_referral(s, referred)
|
||||||
|
|
@ -142,35 +115,27 @@ def test_idempotent_on_repeat_call(tmp_path):
|
||||||
assert referrer.credit_until == first_referrer_credit
|
assert referrer.credit_until == first_referrer_credit
|
||||||
assert referred.credit_until == first_referred_credit
|
assert referred.credit_until == first_referred_credit
|
||||||
|
|
||||||
asyncio.run(_run())
|
|
||||||
|
|
||||||
|
async def test_no_referral_row_returns_none(db_factory):
|
||||||
def test_no_referral_row_returns_none(tmp_path):
|
|
||||||
"""A user signing up directly (no inviter) has no Referral row.
|
"""A user signing up directly (no inviter) has no Referral row.
|
||||||
convert_referral must return None and touch nothing."""
|
convert_referral must return None and touch nothing."""
|
||||||
from app.models import User
|
from app.models import User
|
||||||
from app.services.referral_service import convert_referral
|
from app.services.referral_service import convert_referral
|
||||||
|
|
||||||
factory = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
|
|
||||||
async def _seed_orphan():
|
|
||||||
async with factory() as s:
|
async with factory() as s:
|
||||||
s.add(User(id=9, email="lone@x", tier="free"))
|
s.add(User(id=9, email="lone@x", tier="free"))
|
||||||
await s.commit()
|
await s.commit()
|
||||||
|
|
||||||
asyncio.run(_seed_orphan())
|
|
||||||
|
|
||||||
async def _run():
|
|
||||||
async with factory() as s:
|
async with factory() as s:
|
||||||
user = await s.get(User, 9)
|
user = await s.get(User, 9)
|
||||||
result = await convert_referral(s, user)
|
result = await convert_referral(s, user)
|
||||||
assert result is None
|
assert result is None
|
||||||
assert user.credit_until is None
|
assert user.credit_until is None
|
||||||
|
|
||||||
asyncio.run(_run())
|
|
||||||
|
|
||||||
|
async def test_credit_stacks_from_existing_window(db_factory):
|
||||||
def test_credit_stacks_from_existing_window(tmp_path):
|
|
||||||
"""If the user already has a future credit_until (admin grant, prior
|
"""If the user already has a future credit_until (admin grant, prior
|
||||||
referral), the new credit should extend from THAT anchor — not from
|
referral), the new credit should extend from THAT anchor — not from
|
||||||
now. Mirrors cli.grant_credit's stacking semantics."""
|
now. Mirrors cli.grant_credit's stacking semantics."""
|
||||||
|
|
@ -179,21 +144,17 @@ def test_credit_stacks_from_existing_window(tmp_path):
|
||||||
REFERRAL_CREDIT_DAYS, convert_referral,
|
REFERRAL_CREDIT_DAYS, convert_referral,
|
||||||
)
|
)
|
||||||
|
|
||||||
factory = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
asyncio.run(_add_pair(factory))
|
await _add_pair(factory)
|
||||||
|
|
||||||
# Pre-load 30 days of credit on the referred user.
|
# Pre-load 30 days of credit on the referred user.
|
||||||
existing = datetime.now(timezone.utc) + timedelta(days=30)
|
existing = datetime.now(timezone.utc) + timedelta(days=30)
|
||||||
|
|
||||||
async def _preload():
|
|
||||||
async with factory() as s:
|
async with factory() as s:
|
||||||
u = await s.get(User, 2)
|
u = await s.get(User, 2)
|
||||||
u.credit_until = existing
|
u.credit_until = existing
|
||||||
await s.commit()
|
await s.commit()
|
||||||
|
|
||||||
asyncio.run(_preload())
|
|
||||||
|
|
||||||
async def _run():
|
|
||||||
async with factory() as s:
|
async with factory() as s:
|
||||||
referred = await s.get(User, 2)
|
referred = await s.get(User, 2)
|
||||||
await convert_referral(s, referred)
|
await convert_referral(s, referred)
|
||||||
|
|
@ -211,19 +172,16 @@ def test_credit_stacks_from_existing_window(tmp_path):
|
||||||
f"got {cu}, expected ~{expected}"
|
f"got {cu}, expected ~{expected}"
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(_run())
|
|
||||||
|
|
||||||
|
async def test_deleted_referrer_does_not_crash(db_factory):
|
||||||
def test_deleted_referrer_does_not_crash(tmp_path):
|
|
||||||
"""If the referrer's User row has been deleted, the referred user
|
"""If the referrer's User row has been deleted, the referred user
|
||||||
should still be credited and the Referral still stamped — we just
|
should still be credited and the Referral still stamped — we just
|
||||||
skip the missing referrer."""
|
skip the missing referrer."""
|
||||||
from app.models import Referral, User
|
from app.models import Referral, User
|
||||||
from app.services.referral_service import convert_referral
|
from app.services.referral_service import convert_referral
|
||||||
|
|
||||||
factory = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
|
|
||||||
async def _seed():
|
|
||||||
from app.db import utcnow
|
from app.db import utcnow
|
||||||
async with factory() as s:
|
async with factory() as s:
|
||||||
# Referrer with FK SET NULL — we don't delete the row, we
|
# Referrer with FK SET NULL — we don't delete the row, we
|
||||||
|
|
@ -235,9 +193,6 @@ def test_deleted_referrer_does_not_crash(tmp_path):
|
||||||
created_at=utcnow()))
|
created_at=utcnow()))
|
||||||
await s.commit()
|
await s.commit()
|
||||||
|
|
||||||
asyncio.run(_seed())
|
|
||||||
|
|
||||||
async def _run():
|
|
||||||
async with factory() as s:
|
async with factory() as s:
|
||||||
referred = await s.get(User, 2)
|
referred = await s.get(User, 2)
|
||||||
ref = await convert_referral(s, referred)
|
ref = await convert_referral(s, referred)
|
||||||
|
|
@ -247,8 +202,6 @@ def test_deleted_referrer_does_not_crash(tmp_path):
|
||||||
# Referred still got their credit even though referrer is gone.
|
# Referred still got their credit even though referrer is gone.
|
||||||
assert referred.credit_until is not None
|
assert referred.credit_until is not None
|
||||||
|
|
||||||
asyncio.run(_run())
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Stripe-webhook integration
|
# Stripe-webhook integration
|
||||||
|
|
|
||||||
|
|
@ -9,34 +9,13 @@ from unittest.mock import AsyncMock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def _build_session_factory(tmp_path):
|
|
||||||
"""Spin up a fresh in-memory schema and return (engine, factory, setup).
|
|
||||||
Mirrors tests/test_llm_csv_parser.py / tests/test_referral_conversion.py."""
|
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
||||||
|
|
||||||
from app import db as db_mod
|
async def test_validate_happy_path(db_factory, monkeypatch):
|
||||||
from app.db import Base
|
|
||||||
import app.models # noqa: F401
|
|
||||||
|
|
||||||
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/tv.db")
|
|
||||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
||||||
db_mod._engine = engine
|
|
||||||
db_mod._session_factory = factory
|
|
||||||
|
|
||||||
async def _setup():
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
return engine, factory, _setup
|
|
||||||
|
|
||||||
|
|
||||||
async def test_validate_happy_path(tmp_path, monkeypatch):
|
|
||||||
from app.routers.ticker_validate import validate_ticker
|
from app.routers.ticker_validate import validate_ticker
|
||||||
from app.services.market import Quote
|
from app.services.market import Quote
|
||||||
import app.routers.ticker_validate as mod
|
import app.routers.ticker_validate as mod
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
# Mock fetch_yahoo to return a successful quote.
|
# Mock fetch_yahoo to return a successful quote.
|
||||||
async def _fake_yahoo(client, symbol, label, note, anchor=None):
|
async def _fake_yahoo(client, symbol, label, note, anchor=None):
|
||||||
|
|
@ -61,13 +40,12 @@ async def test_validate_happy_path(tmp_path, monkeypatch):
|
||||||
assert result["as_of"] == "2026-05-27"
|
assert result["as_of"] == "2026-05-27"
|
||||||
|
|
||||||
|
|
||||||
async def test_validate_unknown_symbol(tmp_path, monkeypatch):
|
async def test_validate_unknown_symbol(db_factory, monkeypatch):
|
||||||
from app.routers.ticker_validate import validate_ticker
|
from app.routers.ticker_validate import validate_ticker
|
||||||
from app.services.market import Quote
|
from app.services.market import Quote
|
||||||
import app.routers.ticker_validate as mod
|
import app.routers.ticker_validate as mod
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
# Mock fetch_yahoo to return a Quote with error and no price.
|
# Mock fetch_yahoo to return a Quote with error and no price.
|
||||||
async def _fake_yahoo(client, symbol, label, note, anchor=None):
|
async def _fake_yahoo(client, symbol, label, note, anchor=None):
|
||||||
|
|
@ -92,7 +70,7 @@ async def test_validate_empty_symbol_rejects():
|
||||||
assert "required" in result["error"].lower()
|
assert "required" in result["error"].lower()
|
||||||
|
|
||||||
|
|
||||||
async def test_validate_seeds_universe_and_quote(tmp_path, monkeypatch):
|
async def test_validate_seeds_universe_and_quote(db_factory, monkeypatch):
|
||||||
"""Side-effect check: on success, the symbol is upserted into the
|
"""Side-effect check: on success, the symbol is upserted into the
|
||||||
universe and a Quote row is written."""
|
universe and a Quote row is written."""
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
@ -102,8 +80,7 @@ async def test_validate_seeds_universe_and_quote(tmp_path, monkeypatch):
|
||||||
from app.services.market import Quote
|
from app.services.market import Quote
|
||||||
import app.routers.ticker_validate as mod
|
import app.routers.ticker_validate as mod
|
||||||
|
|
||||||
_, factory, setup = _build_session_factory(tmp_path)
|
factory = db_factory
|
||||||
await setup()
|
|
||||||
|
|
||||||
upsert_calls: list[list[str]] = []
|
upsert_calls: list[list[str]] = []
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue