From c77b3564f3bf4d3860ec29f9a19b89612b6c1983 Mon Sep 17 00:00:00 2001
From: Giorgio Gilestro
Date: Wed, 27 May 2026 12:21:19 +0200
Subject: [PATCH] csv-parser: add _extract_mapping_via_llm with provider-failure wrapping

Co-Authored-By: Claude Opus 4.7
---
 app/services/llm_csv_parser.py | 87 ++++++++++++++++++++++++++++++++++
 tests/test_llm_csv_parser.py   | 82 ++++++++++++++++++++++++++++++++
 2 files changed, 169 insertions(+)

diff --git a/app/services/llm_csv_parser.py b/app/services/llm_csv_parser.py
index 44cfad1..875d863 100644
--- a/app/services/llm_csv_parser.py
+++ b/app/services/llm_csv_parser.py
@@ -20,8 +20,12 @@ from __future__ import annotations
 import csv
 import hashlib
 import io
+import json
+
+import httpx

 from app.services.csv_import import CSVImportError, ParsedPie, ParsedPosition
+from app.services.openrouter import LogResult, call_llm

 # ---------------------------------------------------------------------------
 # Module-level constants
@@ -31,11 +35,41 @@ from app.services.csv_import import CSVImportError, ParsedPie, ParsedPosition
 # Real broker preambles are typically 1-10 lines.
 _MAX_PREAMBLE_SCAN = 30

+# Number of sample rows to send to the LLM and max token budget for the reply.
+_LLM_SAMPLES = 5
+_LLM_MAX_TOKENS = 400
+
 # Required and optional keys in the LLM-returned column mapping.
 _REQUIRED_MAPPING_KEYS = ("ticker_col", "qty_col")
 _OPTIONAL_MAPPING_KEYS = ("name_col", "cost_col", "currency_col")


+_SYSTEM_PROMPT = """\
+You are an expert at recognising broker portfolio CSV formats.
+
+You will be given the header row and 3-5 sample data rows from a CSV.
+Identify which column contains each field. Return ONLY a single JSON
+object, no prose, no markdown fences.
+
+Schema (use the EXACT header string from the input; use null if no
+column is a good match; cost_col is the average price per share):
+
+{
+  "ticker_col": "<header>",
+  "qty_col": "<header>",
+  "name_col": "<header>",
+  "cost_col": "<header>",
+  "currency_col": "<header>",
+  "broker_label": "<label>"
+}
+
+Rules:
+- ticker_col and qty_col are required. If either is missing, return all nulls.
+- Use the EXACT header string as it appears in the input — do not paraphrase.
+- Output JSON ONLY. No prose, no code fences.
+"""
+
+
 class LLMParseError(CSVImportError):
     """Raised when the LLM call fails or returns an unusable mapping.

@@ -251,3 +285,56 @@ def _apply_mapping(
         value=None,
         result=None,
     )
+
+
+def _build_user_prompt(headers: list[str], samples: list[list[str]]) -> str:
+    """Render the header row and up to ``_LLM_SAMPLES`` sample rows.
+
+    Rows are JSON-encoded, like the headers, so cells that contain commas,
+    quotes or newlines stay unambiguous for the model.
+    """
+    lines = ["headers: " + json.dumps(headers), "samples:"]
+    lines.extend("  " + json.dumps(row) for row in samples[:_LLM_SAMPLES])
+    return "\n".join(lines)
+
+
+async def _extract_mapping_via_llm(
+    client: httpx.AsyncClient,
+    headers: list[str],
+    samples: list[list[str]],
+) -> tuple[dict, LogResult]:
+    """Single LLM call returning ``(mapping_dict, LogResult)``.
+
+    The LLM is asked for a strict JSON object (no markdown). We attempt
+    to parse the returned content; ``LLMParseError`` wraps any failure
+    in a way callers can surface to the user.
+
+    Raises:
+        LLMParseError: if the provider call raises an ``Exception``, the
+            reply is not valid JSON, or the decoded JSON is not an object.
+    """
+    messages = [
+        {"role": "system", "content": _SYSTEM_PROMPT},
+        {"role": "user", "content": _build_user_prompt(headers, samples)},
+    ]
+    try:
+        result = await call_llm(client, messages, max_tokens=_LLM_MAX_TOKENS)
+    except Exception as e:
+        # Boundary: whatever the provider raised becomes one user-facing type.
+        raise LLMParseError(f"LLM provider failed: {e}") from e
+
+    content = (result.content or "").strip()
+    # Strip code fences if the model added them despite instructions.
+    if content.startswith("```"):
+        content = content.strip("`")
+        # Drop optional 'json' language tag.
+        if content.lstrip().lower().startswith("json"):
+            content = content.lstrip()[4:]
+        content = content.strip()
+    try:
+        mapping = json.loads(content)
+    except json.JSONDecodeError as e:
+        raise LLMParseError(f"LLM did not return valid JSON: {e}") from e
+    if not isinstance(mapping, dict):
+        raise LLMParseError("LLM JSON was not an object")
+    return mapping, result
diff --git a/tests/test_llm_csv_parser.py b/tests/test_llm_csv_parser.py
index c8a55cc..42eee89 100644
--- a/tests/test_llm_csv_parser.py
+++ b/tests/test_llm_csv_parser.py
@@ -220,3 +220,85 @@ def test_apply_mapping_skips_blank_and_unparseable_rows():
     pie = _apply_mapping(headers, data_rows, mapping)

     assert [p.slice for p in pie.positions] == ["AAPL", "NVDA"]
+
+
+@pytest.mark.asyncio
+async def test_extract_mapping_via_llm_parses_valid_json(monkeypatch):
+    from unittest.mock import AsyncMock, MagicMock
+    from app.services.llm_csv_parser import _extract_mapping_via_llm
+    from app.services.openrouter import LogResult
+
+    fake_result = LogResult(
+        content='{"ticker_col": "Symbol", "qty_col": "Quantity", '
+        '"cost_col": "Avg Price", "currency_col": "Currency", '
+        '"name_col": null, "broker_label": "IBKR Activity Statement"}',
+        model="deepseek/deepseek-v4-flash",
+        prompt_tokens=100,
+        completion_tokens=50,
+        cost_usd=0.0001,
+    )
+    fake_client = MagicMock()
+    fake_call_llm = AsyncMock(return_value=fake_result)
+
+    import app.services.llm_csv_parser as mod
+    monkeypatch.setattr(mod, "call_llm", fake_call_llm)  # restored on teardown
+
+    headers = ["Symbol", "Quantity", "Avg Price", "Currency"]
+    samples = [["AAPL", "100", "150.25", "USD"]]
+    mapping, log = await _extract_mapping_via_llm(fake_client, headers, samples)
+
+    assert mapping["ticker_col"] == "Symbol"
+    assert mapping["qty_col"] == "Quantity"
+    assert mapping["broker_label"] == "IBKR Activity Statement"
+    assert log.model == "deepseek/deepseek-v4-flash"
+    fake_call_llm.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_extract_mapping_via_llm_malformed_json_raises(monkeypatch):
+    from unittest.mock import AsyncMock, MagicMock
+    from app.services.llm_csv_parser import LLMParseError, _extract_mapping_via_llm
+    from app.services.openrouter import LogResult
+
+    fake_result = LogResult(
+        content="Sure thing — here is the mapping! ticker=Symbol",
+        model="deepseek/deepseek-v4-flash",
+        prompt_tokens=10,
+        completion_tokens=20,
+        cost_usd=0.00005,
+    )
+    fake_client = MagicMock()
+    fake_call_llm = AsyncMock(return_value=fake_result)
+
+    import app.services.llm_csv_parser as mod
+    monkeypatch.setattr(mod, "call_llm", fake_call_llm)  # restored on teardown
+
+    with pytest.raises(LLMParseError, match="JSON"):
+        await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]])
+
+
+@pytest.mark.asyncio
+async def test_extract_mapping_via_llm_provider_failure_wraps(monkeypatch):
+    from unittest.mock import AsyncMock, MagicMock
+    from app.services.llm_csv_parser import LLMParseError, _extract_mapping_via_llm
+
+    fake_client = MagicMock()
+    fake_call_llm = AsyncMock(side_effect=RuntimeError("provider down"))
+
+    import app.services.llm_csv_parser as mod
+    monkeypatch.setattr(mod, "call_llm", fake_call_llm)  # restored on teardown
+
+    with pytest.raises(LLMParseError, match="provider"):
+        await _extract_mapping_via_llm(fake_client, ["Symbol"], [["AAPL"]])
+
+
+def test_build_user_prompt_json_encodes_sample_rows():
+    from app.services.llm_csv_parser import _build_user_prompt
+
+    prompt = _build_user_prompt(["Name", "Qty"], [["Apple, Inc.", "10"]])
+
+    assert prompt.splitlines() == [
+        'headers: ["Name", "Qty"]',
+        "samples:",
+        '  ["Apple, Inc.", "10"]',
+    ]