diff --git a/app/services/llm_csv_parser.py b/app/services/llm_csv_parser.py index 875d863..61c428e 100644 --- a/app/services/llm_csv_parser.py +++ b/app/services/llm_csv_parser.py @@ -23,7 +23,12 @@ import io import json import httpx +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from app.db import utcnow +from app.logging import get_logger +from app.models import CsvFormatTemplate from app.services.csv_import import CSVImportError, ParsedPie, ParsedPosition from app.services.openrouter import LogResult, call_llm @@ -43,6 +48,11 @@ _LLM_MAX_TOKENS = 400 _REQUIRED_MAPPING_KEYS = ("ticker_col", "qty_col") _OPTIONAL_MAPPING_KEYS = ("name_col", "cost_col", "currency_col") +# Maximum CSV payload size accepted by parse_with_llm. +_MAX_CSV_BYTES = 1_048_576 + +log = get_logger("llm_csv_parser") + _SYSTEM_PROMPT = """\ You are an expert at recognising broker portfolio CSV formats. @@ -329,3 +339,87 @@ async def _extract_mapping_via_llm( if not isinstance(mapping, dict): raise LLMParseError("LLM JSON was not an object") return mapping, result + + +async def parse_with_llm(raw: bytes, session: AsyncSession) -> ParsedPie: + """Cache-first LLM-fallback CSV parse. + + On cache hit, applies the stored mapping deterministically and + increments ``use_count``. On cache miss, calls the LLM, validates + the returned mapping against the first data row, and persists a + new ``CsvFormatTemplate``. Raises ``LLMParseError`` on any + failure; the caller (route layer) maps that to a 400.""" + if len(raw) > _MAX_CSV_BYTES: + raise LLMParseError("CSV too large (1 MB max)") + if not raw or not raw.strip(): + raise LLMParseError("empty CSV") + + delimiter, preamble_rows = _detect_dialect(raw) + text = _decode_raw(raw) + + reader = csv.reader(io.StringIO(text), delimiter=delimiter) + rows = list(reader) + if preamble_rows >= len(rows): + raise LLMParseError("no header row found in CSV") + headers = [c.strip() for c in rows[preamble_rows]] + data_rows = rows[preamble_rows + 1:] + if not headers: + raise LLMParseError("empty header row") + + first_data_row = next( + (r for r in data_rows if any(c.strip() for c in r)), None, + ) + if first_data_row is None: + raise LLMParseError("CSV contains a header but no data rows") + + fp = _fingerprint(headers) + existing = (await session.execute( + select(CsvFormatTemplate).where(CsvFormatTemplate.fingerprint == fp) + )).scalar_one_or_none() + + if existing is not None: + log.info("csv.format.cache_hit", fingerprint=fp, + broker_label=existing.broker_label, use_count=existing.use_count) + pie = _apply_mapping(headers, data_rows, existing.mapping) + if not pie.positions: + raise LLMParseError( + "cached mapping produced no positions — the broker may have " + "changed their CSV shape; ask the operator to evict the " + "stale template" + ) + existing.use_count += 1 + existing.last_used_at = utcnow() + await session.commit() + return pie + + log.info("csv.format.cache_miss", fingerprint=fp, + header_count=len(headers)) + samples = [r for r in data_rows[:_LLM_SAMPLES] if any(c.strip() for c in r)] + async with httpx.AsyncClient(follow_redirects=True, timeout=30) as client: + mapping, llm_log = await _extract_mapping_via_llm(client, headers, samples) + _validate_mapping(mapping, headers, first_data_row) + + pie = _apply_mapping(headers, data_rows, mapping) + if not pie.positions: + raise LLMParseError( + "LLM mapping validated but produced no positions — the file " + "may not contain portfolio data" + ) + + now = utcnow() + session.add(CsvFormatTemplate( + fingerprint=fp, + headers=headers, + sample_row=first_data_row, + mapping=mapping, + preamble_rows=preamble_rows, + delimiter=delimiter, + broker_label=mapping.get("broker_label"), + first_seen_at=now, + last_used_at=now, + use_count=1, + llm_model=llm_log.model, + llm_cost_usd=llm_log.cost_usd, + )) + await session.commit() + return pie diff --git a/tests/test_llm_csv_parser.py b/tests/test_llm_csv_parser.py index 42eee89..e969c3e 100644 --- a/tests/test_llm_csv_parser.py +++ b/tests/test_llm_csv_parser.py @@ -4,6 +4,27 @@ from __future__ import annotations import pytest +def _build_session_factory(tmp_path): + """Spin up a fresh in-memory schema and return (engine, factory). + Matches the pattern used in tests/test_referral_conversion.py.""" + from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + + from app import db as db_mod + from app.db import Base + import app.models # noqa: F401 — registers models on Base.metadata + + engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/csv.db") + factory = async_sessionmaker(engine, expire_on_commit=False) + db_mod._engine = engine + db_mod._session_factory = factory + + async def _setup(): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + return engine, factory, _setup + + def test_csv_format_template_model_columns(): """Model exposes every column the spec requires, with correct types.""" from sqlalchemy import inspect @@ -290,3 +311,148 @@ async def test_extract_mapping_via_llm_provider_failure_wraps(): with pytest.raises(LLMParseError, match="provider"): await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]]) + + +@pytest.mark.asyncio +async def test_parse_with_llm_cache_miss_inserts_template(tmp_path): + from unittest.mock import AsyncMock + from sqlalchemy import select + + from app.models import CsvFormatTemplate + from app.services.llm_csv_parser import parse_with_llm + from app.services.openrouter import LogResult + + _, factory, setup = _build_session_factory(tmp_path) + await setup() + + raw = ( + b"Symbol,Quantity,Avg Price,Currency\n" + b"AAPL,100,150.25,USD\n" + b"MSFT,50,310.00,USD\n" + ) + + import app.services.llm_csv_parser as mod + mod.call_llm = AsyncMock(return_value=LogResult( + content='{"ticker_col":"Symbol","qty_col":"Quantity",' + '"cost_col":"Avg Price","currency_col":"Currency",' + '"name_col":null,"broker_label":"Generic broker"}', + model="deepseek/deepseek-v4-flash", + prompt_tokens=120, completion_tokens=40, cost_usd=0.0002, + )) + + async with factory() as session: + pie = await parse_with_llm(raw, session) + + assert len(pie.positions) == 2 + assert pie.positions[0].slice == "AAPL" + + async with factory() as session: + rows = (await session.execute(select(CsvFormatTemplate))).scalars().all() + assert len(rows) == 1 + tmpl = rows[0] + assert tmpl.headers == ["Symbol", "Quantity", "Avg Price", "Currency"] + assert tmpl.sample_row == ["AAPL", "100", "150.25", "USD"] + assert tmpl.mapping["ticker_col"] == "Symbol" + assert tmpl.broker_label == "Generic broker" + assert tmpl.use_count == 1 + assert tmpl.llm_cost_usd == pytest.approx(0.0002) + # The crucial PII guarantee: + assert not hasattr(tmpl, "user_id"), "sample row must not be linked to a user" + + +@pytest.mark.asyncio +async def test_parse_with_llm_cache_hit_skips_llm(tmp_path): + from unittest.mock import AsyncMock + from sqlalchemy import select + + from app.db import utcnow + from app.models import CsvFormatTemplate + from app.services.llm_csv_parser import _fingerprint, parse_with_llm + + _, factory, setup = _build_session_factory(tmp_path) + await setup() + + headers = ["Symbol", "Quantity", "Avg Price", "Currency"] + fp = _fingerprint(headers) + + # Pre-populate a cache hit row. + async with factory() as session: + session.add(CsvFormatTemplate( + fingerprint=fp, + headers=headers, + sample_row=["AAPL", "100", "150.25", "USD"], + mapping={ + "ticker_col": "Symbol", "qty_col": "Quantity", + "cost_col": "Avg Price", "currency_col": "Currency", + "name_col": None, + }, + preamble_rows=0, + delimiter=",", + broker_label="Cached broker", + first_seen_at=utcnow(), + last_used_at=utcnow(), + use_count=1, + llm_model="seed", + llm_cost_usd=0.0, + )) + await session.commit() + + raw = ( + b"Symbol,Quantity,Avg Price,Currency\n" + b"NVDA,40,425.50,USD\n" + ) + + import app.services.llm_csv_parser as mod + mod.call_llm = AsyncMock(side_effect=AssertionError("call_llm must NOT be called on cache hit")) + + async with factory() as session: + pie = await parse_with_llm(raw, session) + + assert pie.positions[0].slice == "NVDA" + + async with factory() as session: + rows = (await session.execute(select(CsvFormatTemplate))).scalars().all() + assert len(rows) == 1 + assert rows[0].use_count == 2 + + +@pytest.mark.asyncio +async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path): + from unittest.mock import AsyncMock + from sqlalchemy import select + + from app.db import utcnow + from app.models import CsvFormatTemplate + from app.services.llm_csv_parser import LLMParseError, _fingerprint, parse_with_llm + + _, factory, setup = _build_session_factory(tmp_path) + await setup() + + headers = ["Symbol", "Quantity"] + fp = _fingerprint(headers) + # Cached mapping says qty is in column "Symbol" — clearly wrong; will + # never produce a parseable row. + async with factory() as session: + session.add(CsvFormatTemplate( + fingerprint=fp, headers=headers, + sample_row=["AAPL", "100"], + mapping={"ticker_col": "Symbol", "qty_col": "Symbol"}, + preamble_rows=0, delimiter=",", broker_label=None, + first_seen_at=utcnow(), last_used_at=utcnow(), use_count=1, + llm_model="seed", llm_cost_usd=0.0, + )) + await session.commit() + + raw = b"Symbol,Quantity\nAAPL,100\nMSFT,50\n" + + import app.services.llm_csv_parser as mod + mod.call_llm = AsyncMock(side_effect=AssertionError("must not be called")) + + async with factory() as session: + with pytest.raises(LLMParseError): + await parse_with_llm(raw, session) + + # Stale template must NOT have been auto-deleted (operator owns eviction). + async with factory() as session: + rows = (await session.execute(select(CsvFormatTemplate))).scalars().all() + assert len(rows) == 1