"""Chat endpoint — POST /api/chat. Grounded on the latest strategic log, current market quotes, and thesis-filtered headlines. Ephemeral: the conversation lives in the client; this endpoint just records each call's cost in `ai_calls`. """ from __future__ import annotations from collections import defaultdict from datetime import timedelta import httpx from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field from sqlalchemy import desc, func, select from sqlalchemy.ext.asyncio import AsyncSession from app.auth import require_token, maybe_current_user, CurrentUser from app.config import get_settings from app.db import get_session, utcnow from app.jobs._market_context import REFERENCE_LINE from app.models import AICall, Headline, Quote, StrategicLog from app.routers.api import _md_to_html from app.services.llm_prompts import build_chat_system_prompt from app.services.openrouter import call_llm, month_start router = APIRouter(dependencies=[Depends(require_token)]) # --------------------------------------------------------------------------- # Pydantic models # --------------------------------------------------------------------------- class ChatMessage(BaseModel): role: str = Field(pattern="^(user|assistant)$") content: str class ChatRequest(BaseModel): messages: list[ChatMessage] # --------------------------------------------------------------------------- # Private helpers # --------------------------------------------------------------------------- THESIS_KEYWORDS_FALLBACK = [ "hormuz", "iran", "opec", "brent", "wti", "crude", "oil", "china", "taiwan", "yuan", "fed", "inflation", "cpi", "yield", "gold", "dollar", "yen", "saudi", "russia", "ukraine", "israel", "nato", "defence", "defense", ] async def _latest_quotes_by_group_chat(session: AsyncSession) -> dict[str, list[dict]]: sub = ( select(Quote.group_name, Quote.symbol, func.max(Quote.fetched_at).label("mx")) .group_by(Quote.group_name, Quote.symbol) .subquery() ) rows = (await session.execute( select(Quote).join( sub, (Quote.group_name == sub.c.group_name) & (Quote.symbol == sub.c.symbol) & (Quote.fetched_at == sub.c.mx), ).order_by(Quote.group_name, Quote.symbol) )).scalars().all() by_group: dict[str, list[dict]] = defaultdict(list) for q in rows: by_group[q.group_name].append({ "symbol": q.symbol, "label": q.label, "price": q.price, "currency": q.currency, "as_of": q.as_of, "changes": q.changes, }) return by_group async def _thesis_headlines_for_chat(session: AsyncSession, limit: int = 50) -> list[dict]: cutoff = utcnow() - timedelta(hours=24) rows = (await session.execute( select(Headline) .where(Headline.published_at >= cutoff) .order_by(desc(Headline.published_at)) .limit(300) )).scalars().all() out = [] for h in rows: if any(kw in h.title.lower() for kw in THESIS_KEYWORDS_FALLBACK): out.append({"source": h.source, "title": h.title}) if len(out) >= limit: break return out async def _month_spend(session: AsyncSession) -> float: total = (await session.execute( select(func.coalesce(func.sum(AICall.cost_usd), 0.0)) .where(AICall.called_at >= month_start()) )).scalar() return float(total or 0.0) # --------------------------------------------------------------------------- # Route # --------------------------------------------------------------------------- @router.post("/chat") async def chat( body: ChatRequest, session: AsyncSession = Depends(get_session), principal: CurrentUser | None = Depends(maybe_current_user), ): """Answer one user turn given the conversation so far. Grounded on the latest strategic log + market data + thesis-filtered headlines. Ephemeral — the conversation lives entirely in the client; the endpoint just records each call's cost in `ai_calls`.""" # Paid-only feature. Free users get the static log but not the # interactive chat (see /pricing). from app.services.access import is_paid_active if not is_paid_active(principal): raise HTTPException( status_code=402, detail={"code": "paid_required", "message": "Follow-up chat is a paid-tier feature."}, ) s = get_settings() if not s.OPENROUTER_API_KEY: raise HTTPException(status_code=503, detail="OPENROUTER_API_KEY not set") # Monthly cost cap — same one the log job respects. spent = await _month_spend(session) if spent >= s.OPENROUTER_MONTHLY_CAP_USD: raise HTTPException( status_code=429, detail=f"Monthly OpenRouter cap reached (${spent:.2f})", ) # Trim runaway conversations: keep last 20 turns. history = body.messages[-20:] if not history or history[-1].role != "user": raise HTTPException(status_code=400, detail="Last message must be user") # Gather grounding context. log_row = (await session.execute( select(StrategicLog).order_by(desc(StrategicLog.generated_at)).limit(1) )).scalar_one_or_none() quotes = await _latest_quotes_by_group_chat(session) headlines = await _thesis_headlines_for_chat(session) system_prompt = build_chat_system_prompt( s.CASSANDRA_TONE, s.CASSANDRA_ANALYSIS, log_content=log_row.content if log_row else None, log_generated_at=log_row.generated_at if log_row else None, quotes_by_group=quotes, headlines=headlines, reference_line=REFERENCE_LINE, ) msgs = [{"role": "system", "content": system_prompt}] for m in history: msgs.append({"role": m.role, "content": m.content}) try: async with httpx.AsyncClient(follow_redirects=True) as client: result = await call_llm(client, msgs) except Exception as e: session.add(AICall( model=s.OPENROUTER_MODEL, status="error", error=str(e)[:500], )) await session.commit() raise HTTPException(status_code=502, detail=f"OpenRouter error: {e}") session.add(AICall( model=result.model, prompt_tokens=result.prompt_tokens, completion_tokens=result.completion_tokens, cost_usd=result.cost_usd, status="ok", )) await session.commit() return { "role": "assistant", "content": result.content, "content_html": _md_to_html(result.content), "prompt_tokens": result.prompt_tokens, "completion_tokens": result.completion_tokens, }