analyse: thread user.lang into the system prompt
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
e4982cdc04
commit
d318039ad5
3 changed files with 89 additions and 4 deletions
|
|
@ -36,7 +36,7 @@ from fastapi.responses import JSONResponse
|
||||||
from sqlalchemy import and_, func, select
|
from sqlalchemy import and_, func, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.auth import require_auth
|
from app.auth import CurrentUser, require_auth
|
||||||
from app.config import get_settings
|
from app.config import get_settings
|
||||||
from app.db import get_session, utcnow
|
from app.db import get_session, utcnow
|
||||||
from app.logging import get_logger
|
from app.logging import get_logger
|
||||||
|
|
@ -341,10 +341,11 @@ async def parse_portfolio(
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.post("/analyze", dependencies=[Depends(require_paid)])
|
@router.post("/analyze")
|
||||||
async def analyze_portfolio(
|
async def analyze_portfolio(
|
||||||
request: Request,
|
request: Request,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
|
principal: CurrentUser = Depends(require_paid),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Generate AI commentary for the supplied pie. The pie is held in
|
"""Generate AI commentary for the supplied pie. The pie is held in
|
||||||
memory only for the duration of the LLM call; nothing about holdings
|
memory only for the duration of the LLM call; nothing about holdings
|
||||||
|
|
@ -364,6 +365,11 @@ async def analyze_portfolio(
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(status_code=400, detail="malformed JSON body")
|
raise HTTPException(status_code=400, detail="malformed JSON body")
|
||||||
|
|
||||||
|
user_lang = (
|
||||||
|
principal.user.lang if (principal.user and principal.user.lang) else "en"
|
||||||
|
)
|
||||||
|
payload["lang"] = user_lang
|
||||||
|
|
||||||
try:
|
try:
|
||||||
req = portfolio_analysis.parse_request(payload)
|
req = portfolio_analysis.parse_request(payload)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ from app.config import get_settings
|
||||||
from app.db import utcnow
|
from app.db import utcnow
|
||||||
from app.logging import get_logger
|
from app.logging import get_logger
|
||||||
from app.models import AICall
|
from app.models import AICall
|
||||||
|
from app.services.i18n import LANGUAGES, respond_in_clause
|
||||||
from app.services.openrouter import (
|
from app.services.openrouter import (
|
||||||
LogResult,
|
LogResult,
|
||||||
active_model,
|
active_model,
|
||||||
|
|
@ -74,6 +75,7 @@ class AnalysisRequest:
|
||||||
anchor: str | None = None
|
anchor: str | None = None
|
||||||
tone: str = "INTERMEDIATE" # NOVICE | INTERMEDIATE | PRO
|
tone: str = "INTERMEDIATE" # NOVICE | INTERMEDIATE | PRO
|
||||||
analysis: str = "SPECULATIVE" # DRY | SPECULATIVE
|
analysis: str = "SPECULATIVE" # DRY | SPECULATIVE
|
||||||
|
lang: str = "en"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -163,10 +165,13 @@ def parse_request(payload: dict) -> AnalysisRequest:
|
||||||
anchor = _sanitise_text(payload.get("anchor") or "", 32) or None
|
anchor = _sanitise_text(payload.get("anchor") or "", 32) or None
|
||||||
tone = _sanitise_text(payload.get("tone", "INTERMEDIATE"), 16) or "INTERMEDIATE"
|
tone = _sanitise_text(payload.get("tone", "INTERMEDIATE"), 16) or "INTERMEDIATE"
|
||||||
analysis = _sanitise_text(payload.get("analysis", "SPECULATIVE"), 16) or "SPECULATIVE"
|
analysis = _sanitise_text(payload.get("analysis", "SPECULATIVE"), 16) or "SPECULATIVE"
|
||||||
|
lang = (payload.get("lang") or "en").strip().lower()
|
||||||
|
if lang not in LANGUAGES:
|
||||||
|
lang = "en"
|
||||||
|
|
||||||
return AnalysisRequest(
|
return AnalysisRequest(
|
||||||
positions=positions, prices=prices, base_currency=base_currency,
|
positions=positions, prices=prices, base_currency=base_currency,
|
||||||
anchor=anchor, tone=tone, analysis=analysis,
|
anchor=anchor, tone=tone, analysis=analysis, lang=lang,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -276,7 +281,7 @@ def build_prompt(req: AnalysisRequest) -> tuple[str, str]:
|
||||||
head = enriched[:MAX_POSITIONS_INLINED]
|
head = enriched[:MAX_POSITIONS_INLINED]
|
||||||
tail_count = max(0, len(enriched) - MAX_POSITIONS_INLINED)
|
tail_count = max(0, len(enriched) - MAX_POSITIONS_INLINED)
|
||||||
|
|
||||||
system = build_system_prompt(req.tone, req.analysis) + "\n\n" + _SYSTEM_OVERRIDES
|
system = build_system_prompt(req.tone, req.analysis) + "\n\n" + _SYSTEM_OVERRIDES + respond_in_clause(req.lang)
|
||||||
|
|
||||||
user_parts = [
|
user_parts = [
|
||||||
f"# Portfolio commentary request — {utcnow().strftime('%Y-%m-%d')}",
|
f"# Portfolio commentary request — {utcnow().strftime('%Y-%m-%d')}",
|
||||||
|
|
|
||||||
|
|
@ -176,3 +176,77 @@ async def test_log_translation_fanout_per_language_failure_isolated(tmp_path, mo
|
||||||
async with factory() as session:
|
async with factory() as session:
|
||||||
rows = (await session.execute(select(StrategicLogTranslation))).scalars().all()
|
rows = (await session.execute(select(StrategicLogTranslation))).scalars().all()
|
||||||
assert rows == []
|
assert rows == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyse_threads_lang_into_system_prompt(tmp_path, monkeypatch):
|
||||||
|
"""When lang='it', the system prompt sent to call_llm contains
|
||||||
|
'Respond in Italian.' — the LLM does the rest."""
|
||||||
|
from app.services import portfolio_analysis as pa
|
||||||
|
from app.services.openrouter import LogResult
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
async def _fake_call_llm(client, messages, **kw):
|
||||||
|
captured["messages"] = messages
|
||||||
|
return LogResult(
|
||||||
|
content="Analisi del portafoglio in italiano.",
|
||||||
|
model="m", prompt_tokens=400, completion_tokens=100, cost_usd=0.0001,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(pa, "call_llm", _fake_call_llm)
|
||||||
|
|
||||||
|
_, factory, setup = _build_session_factory(tmp_path)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0,
|
||||||
|
"currency": "USD", "name": "Apple Inc"}],
|
||||||
|
"prices": {"AAPL": {"p": 172.4, "c": "USD"}},
|
||||||
|
"fx": {"USD": 1.0},
|
||||||
|
"base_currency": "USD",
|
||||||
|
"tone": "INTERMEDIATE",
|
||||||
|
"analysis": "NORMAL",
|
||||||
|
"lang": "it",
|
||||||
|
}
|
||||||
|
req = pa.parse_request(payload)
|
||||||
|
assert req.lang == "it"
|
||||||
|
|
||||||
|
async with factory() as session:
|
||||||
|
await pa.analyse(session, req)
|
||||||
|
system = next(m["content"] for m in captured["messages"] if m["role"] == "system")
|
||||||
|
assert "Respond in Italian" in system
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyse_no_clause_when_lang_is_en(tmp_path, monkeypatch):
|
||||||
|
from app.services import portfolio_analysis as pa
|
||||||
|
from app.services.openrouter import LogResult
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
async def _fake_call_llm(client, messages, **kw):
|
||||||
|
captured["messages"] = messages
|
||||||
|
return LogResult(
|
||||||
|
content="Portfolio analysis in English.",
|
||||||
|
model="m", prompt_tokens=400, completion_tokens=100, cost_usd=0.0001,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(pa, "call_llm", _fake_call_llm)
|
||||||
|
|
||||||
|
_, factory, setup = _build_session_factory(tmp_path)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"positions": [{"yahoo_ticker": "AAPL", "qty": 10, "avg_cost": 150.0,
|
||||||
|
"currency": "USD", "name": "Apple Inc"}],
|
||||||
|
"prices": {"AAPL": {"p": 172.4, "c": "USD"}},
|
||||||
|
"fx": {"USD": 1.0},
|
||||||
|
"base_currency": "USD",
|
||||||
|
"tone": "INTERMEDIATE",
|
||||||
|
"analysis": "NORMAL",
|
||||||
|
"lang": "en",
|
||||||
|
}
|
||||||
|
req = pa.parse_request(payload)
|
||||||
|
async with factory() as session:
|
||||||
|
await pa.analyse(session, req)
|
||||||
|
system = next(m["content"] for m in captured["messages"] if m["role"] == "system")
|
||||||
|
assert "Respond in" not in system
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue