"""LLM transport layer — OpenRouter / DeepSeek API calls. Handles provider selection, retry + fallback machinery, and the monthly budget-cap helpers. Prompt engineering lives in ``app.services.llm_prompts``; this module only cares about *how* to reach the model, not *what to ask*. """ from __future__ import annotations import json from dataclasses import dataclass from datetime import datetime, timedelta, timezone import httpx from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential from app import branding from app.config import get_settings OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions" @dataclass class LogResult: content: str model: str prompt_tokens: int | None completion_tokens: int | None cost_usd: float | None def _provider_chain() -> list[str]: """Ordered list of providers to try: primary, then fallback (unless the fallback is unset, the same as primary, or has no API key).""" s = get_settings() primary = (s.LLM_PROVIDER or "deepseek").lower() fallback = (s.LLM_FALLBACK or "").lower() chain = [primary] if fallback and fallback != primary: chain.append(fallback) # Drop providers with no API key configured. return [p for p in chain if _provider_has_key(p)] def _provider_has_key(provider: str) -> bool: s = get_settings() if provider == "deepseek": return bool(s.DEEPSEEK_API_KEY) if provider == "openrouter": return bool(s.OPENROUTER_API_KEY) return False def _endpoint_for(provider: str) -> tuple[str, str, str, dict[str, str]]: """Resolve (url, api_key, default_model, extra_headers) for a specific provider. Raises if its API key isn't set.""" s = get_settings() if provider == "deepseek": if not s.DEEPSEEK_API_KEY: raise RuntimeError("DEEPSEEK_API_KEY not set") return s.DEEPSEEK_URL, s.DEEPSEEK_API_KEY, s.DEEPSEEK_MODEL, {} if provider == "openrouter": if not s.OPENROUTER_API_KEY: raise RuntimeError("OPENROUTER_API_KEY not set") return ( OPENROUTER_URL, s.OPENROUTER_API_KEY, s.OPENROUTER_MODEL, { # OpenRouter-specific attribution headers. Visible on the # OpenRouter dashboard — keep aligned with the live brand. "HTTP-Referer": branding.SITE_URL, "X-Title": branding.BRAND_NAME, # No-train opt-out. Tells OpenRouter (and any compatible # upstream) that this request must not be used to train # or improve models. The Privacy notice promises this; the # header is what makes the promise truthful. If a future # upstream ignores the header, fix the provider — not the # header — so the contract stays auditable. "X-OR-Allow-Training": "false", }, ) raise RuntimeError(f"Unknown LLM provider: {provider!r}") def llm_configured() -> bool: """At least one provider in the configured chain has an API key.""" return bool(_provider_chain()) def active_model() -> str: """Return the model name of the *first* provider in the configured chain (the one that would be tried first). Used to label AICall ledger rows when no actual call result is available yet.""" chain = _provider_chain() if not chain: return "unknown" s = get_settings() return s.DEEPSEEK_MODEL if chain[0] == "deepseek" else s.OPENROUTER_MODEL @retry( reraise=True, stop=stop_after_attempt(3), wait=wait_exponential(multiplier=2, min=2, max=30), retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.TransportError)), ) async def _call_provider( client: httpx.AsyncClient, provider: str, messages: list[dict], model: str | None, max_tokens: int, ) -> LogResult: """One provider call with tenacity retries on transport/HTTP errors. Lives inside the retry decorator so retries happen within a provider, not across the fallback chain.""" url, api_key, default_model, extra_headers = _endpoint_for(provider) used_model = model or default_model headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", **extra_headers, } r = await client.post( url, headers=headers, json={"model": used_model, "messages": messages, "max_tokens": max_tokens}, timeout=180, ) r.raise_for_status() data = r.json() msg = data["choices"][0]["message"] # Some providers return null content + populated `reasoning` for thinking # models, or null content when finish_reason=length cut off the response. content = msg.get("content") or msg.get("reasoning") if not content: finish = data["choices"][0].get("finish_reason") raise RuntimeError( f"LLM returned empty content (finish_reason={finish}, " f"provider={provider}, model={used_model}, max_tokens={max_tokens})" ) usage = data.get("usage") or {} return LogResult( content=content, # Record provider+model so admin can see which path produced this row. model=f"{provider}/{used_model}", prompt_tokens=usage.get("prompt_tokens"), completion_tokens=usage.get("completion_tokens"), cost_usd=usage.get("cost") or usage.get("total_cost"), ) async def call_llm( client: httpx.AsyncClient, messages: list[dict], model: str | None = None, max_tokens: int = 4000, ) -> LogResult: """Provider-aware chat completion with fallback. Tries primary (LLM_PROVIDER) first; if it raises after retries, falls through to LLM_FALLBACK. Raises only if every provider in the chain fails. The returned LogResult.model is prefixed with the provider that actually answered (e.g. ``deepseek/deepseek-v4-flash`` or ``openrouter/deepseek/deepseek-v4-flash``) — useful admin metadata even though we hide it from the user-facing UI.""" chain = _provider_chain() if not chain: raise RuntimeError("No LLM provider configured (no API key set)") last_exc: Exception | None = None for i, provider in enumerate(chain): try: result = await _call_provider( client, provider, messages, model, max_tokens, ) if i > 0: from app.logging import get_logger get_logger("llm").info( "llm.fallback_succeeded", provider=provider, attempt=i + 1, ) return result except Exception as e: last_exc = e if i + 1 < len(chain): from app.logging import get_logger get_logger("llm").warning( "llm.primary_failed_trying_fallback", provider=provider, error=str(e)[:200], ) continue # Re-raise the last exception so callers see the failure mode. assert last_exc is not None raise last_exc def month_window() -> tuple[datetime, datetime]: """[start, now] in UTC for the current calendar month.""" now = datetime.now(timezone.utc) start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) return start, now def month_start() -> datetime: return month_window()[0]