csv-parser: add _extract_mapping_via_llm with provider-failure wrapping
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
b99f46d2fc
commit
c77b3564f3
2 changed files with 148 additions and 0 deletions
|
|
@ -20,8 +20,12 @@ from __future__ import annotations
|
||||||
import csv
|
import csv
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
from app.services.csv_import import CSVImportError, ParsedPie, ParsedPosition
|
from app.services.csv_import import CSVImportError, ParsedPie, ParsedPosition
|
||||||
|
from app.services.openrouter import LogResult, call_llm
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Module-level constants
|
# Module-level constants
|
||||||
|
|
@ -31,11 +35,41 @@ from app.services.csv_import import CSVImportError, ParsedPie, ParsedPosition
|
||||||
# Real broker preambles are typically 1-10 lines.
|
# Real broker preambles are typically 1-10 lines.
|
||||||
_MAX_PREAMBLE_SCAN = 30
|
_MAX_PREAMBLE_SCAN = 30
|
||||||
|
|
||||||
|
# Number of sample rows to send to the LLM and max token budget for the reply.
|
||||||
|
_LLM_SAMPLES = 5
|
||||||
|
_LLM_MAX_TOKENS = 400
|
||||||
|
|
||||||
# Required and optional keys in the LLM-returned column mapping.
|
# Required and optional keys in the LLM-returned column mapping.
|
||||||
_REQUIRED_MAPPING_KEYS = ("ticker_col", "qty_col")
|
_REQUIRED_MAPPING_KEYS = ("ticker_col", "qty_col")
|
||||||
_OPTIONAL_MAPPING_KEYS = ("name_col", "cost_col", "currency_col")
|
_OPTIONAL_MAPPING_KEYS = ("name_col", "cost_col", "currency_col")
|
||||||
|
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = """\
|
||||||
|
You are an expert at recognising broker portfolio CSV formats.
|
||||||
|
|
||||||
|
You will be given the header row and 3-5 sample data rows from a CSV.
|
||||||
|
Identify which column contains each field. Return ONLY a single JSON
|
||||||
|
object, no prose, no markdown fences.
|
||||||
|
|
||||||
|
Schema (use the EXACT header string from the input; use null if no
|
||||||
|
column is a good match):
|
||||||
|
|
||||||
|
{
|
||||||
|
"ticker_col": "<header name or null>",
|
||||||
|
"qty_col": "<header name or null>",
|
||||||
|
"name_col": "<header name or null>",
|
||||||
|
"cost_col": "<header name or null>", // average price per share
|
||||||
|
"currency_col": "<header name or null>",
|
||||||
|
"broker_label": "<short identifier like 'IBKR Activity Statement' or null>"
|
||||||
|
}
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- ticker_col and qty_col are required. If either is missing, return all nulls.
|
||||||
|
- Use the EXACT header string as it appears in the input — do not paraphrase.
|
||||||
|
- Output JSON ONLY. No prose, no code fences.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class LLMParseError(CSVImportError):
|
class LLMParseError(CSVImportError):
|
||||||
"""Raised when the LLM call fails or returns an unusable mapping.
|
"""Raised when the LLM call fails or returns an unusable mapping.
|
||||||
|
|
||||||
|
|
@ -251,3 +285,47 @@ def _apply_mapping(
|
||||||
value=None,
|
value=None,
|
||||||
result=None,
|
result=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_user_prompt(headers: list[str], samples: list[list[str]]) -> str:
    """Render the user message sent to the LLM for column-mapping extraction.

    The message consists of a ``headers:`` line carrying the JSON-encoded
    header list, a ``samples:`` line, and then at most ``_LLM_SAMPLES``
    sample rows, one per line.

    Sample rows are serialised with the ``csv`` module rather than a bare
    ``",".join``: a field that itself contains a comma, a double quote or a
    line break is quoted, so it stays aligned with its header column instead
    of silently shifting every later column. Rows whose fields are plain
    non-empty text render exactly as a simple comma join would.

    Args:
        headers: Column names from the CSV header row.
        samples: Data rows; only the first ``_LLM_SAMPLES`` are used.

    Returns:
        The newline-joined prompt text.
    """
    lines = ["headers: " + json.dumps(headers), "samples:"]
    for row in samples[:_LLM_SAMPLES]:
        buf = io.StringIO()
        csv.writer(buf).writerow(row)
        # The writer terminates each row with "\r\n"; drop it because the
        # final join supplies the line breaks.
        lines.append(" " + buf.getvalue().rstrip("\r\n"))
    return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_mapping_via_llm(
    client: httpx.AsyncClient,
    headers: list[str],
    samples: list[list[str]],
) -> tuple[dict, LogResult]:
    """Ask the LLM once for a column mapping and decode its JSON reply.

    The system prompt plus a user prompt built from ``headers`` and
    ``samples`` are sent through ``call_llm`` with a reply budget of
    ``_LLM_MAX_TOKENS``. The reply is expected to be a bare JSON object;
    a surrounding code fence (with an optional ``json`` tag) is tolerated
    and removed before decoding.

    Returns:
        ``(mapping, result)``: the decoded JSON object and the value
        returned by ``call_llm``.

    Raises:
        LLMParseError: if ``call_llm`` raises, if the reply is not valid
            JSON, or if the decoded JSON is not an object.
    """
    system_msg = {"role": "system", "content": _SYSTEM_PROMPT}
    user_msg = {"role": "user", "content": _build_user_prompt(headers, samples)}

    try:
        result = await call_llm(
            client, [system_msg, user_msg], max_tokens=_LLM_MAX_TOKENS
        )
    except Exception as exc:
        # Any provider-side failure is surfaced as one domain error type.
        raise LLMParseError(f"LLM provider failed: {exc}") from exc

    text = (result.content or "").strip()
    if text.startswith("```"):
        # The model fenced its answer despite the instructions: peel off
        # the backticks, then an optional leading 'json' language tag.
        text = text.strip("`")
        unfenced = text.lstrip()
        if unfenced.lower().startswith("json"):
            text = unfenced[4:]
        text = text.strip()

    try:
        mapping = json.loads(text)
    except json.JSONDecodeError as exc:
        raise LLMParseError(f"LLM did not return valid JSON: {exc}") from exc

    if isinstance(mapping, dict):
        return mapping, result
    raise LLMParseError("LLM JSON was not an object")
|
||||||
|
|
|
||||||
|
|
@ -220,3 +220,73 @@ def test_apply_mapping_skips_blank_and_unparseable_rows():
|
||||||
|
|
||||||
pie = _apply_mapping(headers, data_rows, mapping)
|
pie = _apply_mapping(headers, data_rows, mapping)
|
||||||
assert [p.slice for p in pie.positions] == ["AAPL", "NVDA"]
|
assert [p.slice for p in pie.positions] == ["AAPL", "NVDA"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_extract_mapping_via_llm_parses_valid_json():
    """A well-formed JSON reply is decoded and returned with its LogResult."""
    from unittest.mock import AsyncMock, MagicMock, patch

    import app.services.llm_csv_parser as mod
    from app.services.llm_csv_parser import _extract_mapping_via_llm
    from app.services.openrouter import LogResult

    fake_result = LogResult(
        content='{"ticker_col": "Symbol", "qty_col": "Quantity", '
        '"cost_col": "Avg Price", "currency_col": "Currency", '
        '"name_col": null, "broker_label": "IBKR Activity Statement"}',
        model="deepseek/deepseek-v4-flash",
        prompt_tokens=100,
        completion_tokens=50,
        cost_usd=0.0001,
    )
    fake_client = MagicMock()
    fake_call_llm = AsyncMock(return_value=fake_result)

    headers = ["Symbol", "Quantity", "Avg Price", "Currency"]
    samples = [["AAPL", "100", "150.25", "USD"]]

    # patch.object restores the real call_llm on exit, so the stub cannot
    # leak into tests that run later in the same session.
    with patch.object(mod, "call_llm", fake_call_llm):
        mapping, log = await _extract_mapping_via_llm(fake_client, headers, samples)

    assert mapping["ticker_col"] == "Symbol"
    assert mapping["qty_col"] == "Quantity"
    assert mapping["broker_label"] == "IBKR Activity Statement"
    assert log.model == "deepseek/deepseek-v4-flash"
    fake_call_llm.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_extract_mapping_via_llm_malformed_json_raises():
    """A prose (non-JSON) reply is reported as an LLMParseError."""
    from unittest.mock import AsyncMock, MagicMock, patch

    import app.services.llm_csv_parser as mod
    from app.services.llm_csv_parser import LLMParseError, _extract_mapping_via_llm
    from app.services.openrouter import LogResult

    fake_result = LogResult(
        content="Sure thing — here is the mapping! ticker=Symbol",
        model="deepseek/deepseek-v4-flash",
        prompt_tokens=10,
        completion_tokens=20,
        cost_usd=0.00005,
    )
    fake_client = MagicMock()
    fake_call_llm = AsyncMock(return_value=fake_result)

    # Scope the stub to this test; assigning mod.call_llm directly would
    # leave the fake installed for every test that runs afterwards.
    with patch.object(mod, "call_llm", fake_call_llm):
        with pytest.raises(LLMParseError, match="JSON"):
            await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_extract_mapping_via_llm_provider_failure_wraps():
    """An exception raised by call_llm is re-raised as an LLMParseError."""
    from unittest.mock import AsyncMock, MagicMock, patch

    import app.services.llm_csv_parser as mod
    from app.services.llm_csv_parser import LLMParseError, _extract_mapping_via_llm

    fake_client = MagicMock()
    fake_call_llm = AsyncMock(side_effect=RuntimeError("provider down"))

    # Scope the stub to this test; assigning mod.call_llm directly would
    # leave the failing fake installed for every test that runs afterwards.
    with patch.object(mod, "call_llm", fake_call_llm):
        with pytest.raises(LLMParseError, match="provider"):
            await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]])
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue