csv-parser: add public parse_with_llm with cache hit/miss orchestration
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
c77b3564f3
commit
59b28506df
2 changed files with 260 additions and 0 deletions
|
|
@ -23,7 +23,12 @@ import io
|
|||
import json
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import utcnow
|
||||
from app.logging import get_logger
|
||||
from app.models import CsvFormatTemplate
|
||||
from app.services.csv_import import CSVImportError, ParsedPie, ParsedPosition
|
||||
from app.services.openrouter import LogResult, call_llm
|
||||
|
||||
|
|
@ -43,6 +48,11 @@ _LLM_MAX_TOKENS = 400
|
|||
_REQUIRED_MAPPING_KEYS = ("ticker_col", "qty_col")
|
||||
_OPTIONAL_MAPPING_KEYS = ("name_col", "cost_col", "currency_col")
|
||||
|
||||
# Maximum CSV payload size accepted by parse_with_llm.
|
||||
_MAX_CSV_BYTES = 1_048_576
|
||||
|
||||
log = get_logger("llm_csv_parser")
|
||||
|
||||
|
||||
_SYSTEM_PROMPT = """\
|
||||
You are an expert at recognising broker portfolio CSV formats.
|
||||
|
|
@ -329,3 +339,87 @@ async def _extract_mapping_via_llm(
|
|||
if not isinstance(mapping, dict):
|
||||
raise LLMParseError("LLM JSON was not an object")
|
||||
return mapping, result
|
||||
|
||||
|
||||
async def parse_with_llm(raw: bytes, session: AsyncSession) -> ParsedPie:
|
||||
"""Cache-first LLM-fallback CSV parse.
|
||||
|
||||
On cache hit, applies the stored mapping deterministically and
|
||||
increments ``use_count``. On cache miss, calls the LLM, validates
|
||||
the returned mapping against the first data row, and persists a
|
||||
new ``CsvFormatTemplate``. Raises ``LLMParseError`` on any
|
||||
failure; the caller (route layer) maps that to a 400."""
|
||||
if len(raw) > _MAX_CSV_BYTES:
|
||||
raise LLMParseError("CSV too large (1 MB max)")
|
||||
if not raw or not raw.strip():
|
||||
raise LLMParseError("empty CSV")
|
||||
|
||||
delimiter, preamble_rows = _detect_dialect(raw)
|
||||
text = _decode_raw(raw)
|
||||
|
||||
reader = csv.reader(io.StringIO(text), delimiter=delimiter)
|
||||
rows = list(reader)
|
||||
if preamble_rows >= len(rows):
|
||||
raise LLMParseError("no header row found in CSV")
|
||||
headers = [c.strip() for c in rows[preamble_rows]]
|
||||
data_rows = rows[preamble_rows + 1:]
|
||||
if not headers:
|
||||
raise LLMParseError("empty header row")
|
||||
|
||||
first_data_row = next(
|
||||
(r for r in data_rows if any(c.strip() for c in r)), None,
|
||||
)
|
||||
if first_data_row is None:
|
||||
raise LLMParseError("CSV contains a header but no data rows")
|
||||
|
||||
fp = _fingerprint(headers)
|
||||
existing = (await session.execute(
|
||||
select(CsvFormatTemplate).where(CsvFormatTemplate.fingerprint == fp)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
if existing is not None:
|
||||
log.info("csv.format.cache_hit", fingerprint=fp,
|
||||
broker_label=existing.broker_label, use_count=existing.use_count)
|
||||
pie = _apply_mapping(headers, data_rows, existing.mapping)
|
||||
if not pie.positions:
|
||||
raise LLMParseError(
|
||||
"cached mapping produced no positions — the broker may have "
|
||||
"changed their CSV shape; ask the operator to evict the "
|
||||
"stale template"
|
||||
)
|
||||
existing.use_count += 1
|
||||
existing.last_used_at = utcnow()
|
||||
await session.commit()
|
||||
return pie
|
||||
|
||||
log.info("csv.format.cache_miss", fingerprint=fp,
|
||||
header_count=len(headers))
|
||||
samples = [r for r in data_rows[:_LLM_SAMPLES] if any(c.strip() for c in r)]
|
||||
async with httpx.AsyncClient(follow_redirects=True, timeout=30) as client:
|
||||
mapping, llm_log = await _extract_mapping_via_llm(client, headers, samples)
|
||||
_validate_mapping(mapping, headers, first_data_row)
|
||||
|
||||
pie = _apply_mapping(headers, data_rows, mapping)
|
||||
if not pie.positions:
|
||||
raise LLMParseError(
|
||||
"LLM mapping validated but produced no positions — the file "
|
||||
"may not contain portfolio data"
|
||||
)
|
||||
|
||||
now = utcnow()
|
||||
session.add(CsvFormatTemplate(
|
||||
fingerprint=fp,
|
||||
headers=headers,
|
||||
sample_row=first_data_row,
|
||||
mapping=mapping,
|
||||
preamble_rows=preamble_rows,
|
||||
delimiter=delimiter,
|
||||
broker_label=mapping.get("broker_label"),
|
||||
first_seen_at=now,
|
||||
last_used_at=now,
|
||||
use_count=1,
|
||||
llm_model=llm_log.model,
|
||||
llm_cost_usd=llm_log.cost_usd,
|
||||
))
|
||||
await session.commit()
|
||||
return pie
|
||||
|
|
|
|||
|
|
@ -4,6 +4,27 @@ from __future__ import annotations
|
|||
import pytest
|
||||
|
||||
|
||||
def _build_session_factory(tmp_path):
|
||||
"""Spin up a fresh in-memory schema and return (engine, factory).
|
||||
Matches the pattern used in tests/test_referral_conversion.py."""
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from app import db as db_mod
|
||||
from app.db import Base
|
||||
import app.models # noqa: F401 — registers models on Base.metadata
|
||||
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/csv.db")
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
db_mod._engine = engine
|
||||
db_mod._session_factory = factory
|
||||
|
||||
async def _setup():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
return engine, factory, _setup
|
||||
|
||||
|
||||
def test_csv_format_template_model_columns():
|
||||
"""Model exposes every column the spec requires, with correct types."""
|
||||
from sqlalchemy import inspect
|
||||
|
|
@ -290,3 +311,148 @@ async def test_extract_mapping_via_llm_provider_failure_wraps():
|
|||
|
||||
with pytest.raises(LLMParseError, match="provider"):
|
||||
await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_with_llm_cache_miss_inserts_template(tmp_path):
|
||||
from unittest.mock import AsyncMock
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models import CsvFormatTemplate
|
||||
from app.services.llm_csv_parser import parse_with_llm
|
||||
from app.services.openrouter import LogResult
|
||||
|
||||
_, factory, setup = _build_session_factory(tmp_path)
|
||||
await setup()
|
||||
|
||||
raw = (
|
||||
b"Symbol,Quantity,Avg Price,Currency\n"
|
||||
b"AAPL,100,150.25,USD\n"
|
||||
b"MSFT,50,310.00,USD\n"
|
||||
)
|
||||
|
||||
import app.services.llm_csv_parser as mod
|
||||
mod.call_llm = AsyncMock(return_value=LogResult(
|
||||
content='{"ticker_col":"Symbol","qty_col":"Quantity",'
|
||||
'"cost_col":"Avg Price","currency_col":"Currency",'
|
||||
'"name_col":null,"broker_label":"Generic broker"}',
|
||||
model="deepseek/deepseek-v4-flash",
|
||||
prompt_tokens=120, completion_tokens=40, cost_usd=0.0002,
|
||||
))
|
||||
|
||||
async with factory() as session:
|
||||
pie = await parse_with_llm(raw, session)
|
||||
|
||||
assert len(pie.positions) == 2
|
||||
assert pie.positions[0].slice == "AAPL"
|
||||
|
||||
async with factory() as session:
|
||||
rows = (await session.execute(select(CsvFormatTemplate))).scalars().all()
|
||||
assert len(rows) == 1
|
||||
tmpl = rows[0]
|
||||
assert tmpl.headers == ["Symbol", "Quantity", "Avg Price", "Currency"]
|
||||
assert tmpl.sample_row == ["AAPL", "100", "150.25", "USD"]
|
||||
assert tmpl.mapping["ticker_col"] == "Symbol"
|
||||
assert tmpl.broker_label == "Generic broker"
|
||||
assert tmpl.use_count == 1
|
||||
assert tmpl.llm_cost_usd == pytest.approx(0.0002)
|
||||
# The crucial PII guarantee:
|
||||
assert not hasattr(tmpl, "user_id"), "sample row must not be linked to a user"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_with_llm_cache_hit_skips_llm(tmp_path):
|
||||
from unittest.mock import AsyncMock
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import utcnow
|
||||
from app.models import CsvFormatTemplate
|
||||
from app.services.llm_csv_parser import _fingerprint, parse_with_llm
|
||||
|
||||
_, factory, setup = _build_session_factory(tmp_path)
|
||||
await setup()
|
||||
|
||||
headers = ["Symbol", "Quantity", "Avg Price", "Currency"]
|
||||
fp = _fingerprint(headers)
|
||||
|
||||
# Pre-populate a cache hit row.
|
||||
async with factory() as session:
|
||||
session.add(CsvFormatTemplate(
|
||||
fingerprint=fp,
|
||||
headers=headers,
|
||||
sample_row=["AAPL", "100", "150.25", "USD"],
|
||||
mapping={
|
||||
"ticker_col": "Symbol", "qty_col": "Quantity",
|
||||
"cost_col": "Avg Price", "currency_col": "Currency",
|
||||
"name_col": None,
|
||||
},
|
||||
preamble_rows=0,
|
||||
delimiter=",",
|
||||
broker_label="Cached broker",
|
||||
first_seen_at=utcnow(),
|
||||
last_used_at=utcnow(),
|
||||
use_count=1,
|
||||
llm_model="seed",
|
||||
llm_cost_usd=0.0,
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
raw = (
|
||||
b"Symbol,Quantity,Avg Price,Currency\n"
|
||||
b"NVDA,40,425.50,USD\n"
|
||||
)
|
||||
|
||||
import app.services.llm_csv_parser as mod
|
||||
mod.call_llm = AsyncMock(side_effect=AssertionError("call_llm must NOT be called on cache hit"))
|
||||
|
||||
async with factory() as session:
|
||||
pie = await parse_with_llm(raw, session)
|
||||
|
||||
assert pie.positions[0].slice == "NVDA"
|
||||
|
||||
async with factory() as session:
|
||||
rows = (await session.execute(select(CsvFormatTemplate))).scalars().all()
|
||||
assert len(rows) == 1
|
||||
assert rows[0].use_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_with_llm_stale_mapping_raises_but_does_not_evict(tmp_path):
|
||||
from unittest.mock import AsyncMock
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import utcnow
|
||||
from app.models import CsvFormatTemplate
|
||||
from app.services.llm_csv_parser import LLMParseError, _fingerprint, parse_with_llm
|
||||
|
||||
_, factory, setup = _build_session_factory(tmp_path)
|
||||
await setup()
|
||||
|
||||
headers = ["Symbol", "Quantity"]
|
||||
fp = _fingerprint(headers)
|
||||
# Cached mapping says qty is in column "Symbol" — clearly wrong; will
|
||||
# never produce a parseable row.
|
||||
async with factory() as session:
|
||||
session.add(CsvFormatTemplate(
|
||||
fingerprint=fp, headers=headers,
|
||||
sample_row=["AAPL", "100"],
|
||||
mapping={"ticker_col": "Symbol", "qty_col": "Symbol"},
|
||||
preamble_rows=0, delimiter=",", broker_label=None,
|
||||
first_seen_at=utcnow(), last_used_at=utcnow(), use_count=1,
|
||||
llm_model="seed", llm_cost_usd=0.0,
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
raw = b"Symbol,Quantity\nAAPL,100\nMSFT,50\n"
|
||||
|
||||
import app.services.llm_csv_parser as mod
|
||||
mod.call_llm = AsyncMock(side_effect=AssertionError("must not be called"))
|
||||
|
||||
async with factory() as session:
|
||||
with pytest.raises(LLMParseError):
|
||||
await parse_with_llm(raw, session)
|
||||
|
||||
# Stale template must NOT have been auto-deleted (operator owns eviction).
|
||||
async with factory() as session:
|
||||
rows = (await session.execute(select(CsvFormatTemplate))).scalars().all()
|
||||
assert len(rows) == 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue