diff --git a/app/main.py b/app/main.py index fe987f5..7f1729f 100644 --- a/app/main.py +++ b/app/main.py @@ -19,7 +19,9 @@ from app.db import get_session_factory from app.logging import configure_logging, get_logger from app.routers import api as api_router from app.routers import auth as auth_router +from app.routers import chat as chat_router from app.routers import email as email_router +from app.routers import ops as ops_router from app.routers import pages as pages_router from app.routers import polar_webhook as polar_webhook_router from app.routers import public as public_router @@ -89,6 +91,8 @@ app.mount( app.include_router(auth_router.router, tags=["auth"]) app.include_router(email_router.router, tags=["email"]) app.include_router(api_router.router, prefix="/api", tags=["api"]) +app.include_router(chat_router.router, prefix="/api", tags=["chat"]) +app.include_router(ops_router.router, prefix="/api", tags=["ops"]) app.include_router(universe_router.router, prefix="/api", tags=["universe"]) app.include_router(ticker_validate_router.router, prefix="/api", tags=["ticker-validate"]) app.include_router(sync_router.router, tags=["portfolio-sync"]) diff --git a/app/routers/api.py b/app/routers/api.py index 893d08f..5075654 100644 --- a/app/routers/api.py +++ b/app/routers/api.py @@ -10,45 +10,29 @@ import re from datetime import date, datetime, timedelta, timezone from typing import Literal -from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, Request, UploadFile -from fastapi.responses import HTMLResponse, JSONResponse +from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi.responses import JSONResponse from sqlalchemy import desc, func, select from sqlalchemy.ext.asyncio import AsyncSession -from collections import defaultdict - -import httpx -from pydantic import BaseModel, Field +from pydantic import BaseModel from app.auth import require_token, maybe_current_user, CurrentUser from app.services.i18n import ACTIVE_LANGUAGES from app.config import get_settings from app.db import get_session, utcnow -from app.jobs._market_context import REFERENCE_LINE -from app.services.llm_prompts import ( - PROMPT_VERSION, - build_chat_system_prompt, -) -from app.services.openrouter import ( - call_llm, - month_start, -) from app.templates_env import templates from app.models import ( - AICall, Headline, IndicatorSummary, IndicatorSummaryTranslation, - JobRun, Quote, StrategicLog, StrategicLogTranslation, User, ) from app.schemas import ( - HealthOut, HeadlineOut, - JobStatus, QuoteOut, StrategicLogOut, ) @@ -56,11 +40,6 @@ from app.schemas import ( router = APIRouter(dependencies=[Depends(require_token)]) -JOB_NAMES = ("market_job", "news_job", "ai_log_job", "rollup_job", - "indicator_summary_job", "universe_flush_job", - "email_digest_job") -JOB_STALE_HOURS = 2.0 # job is "warn" if its last success was >2h ago - # Per-group expected freshness — bonds and intraday tape want daily data, # macro/economy/valuation are monthly/quarterly by nature. Older than this # many days from today → row gets a "stale" badge. @@ -565,10 +544,6 @@ async def log_days( return templates.TemplateResponse(request, "partials/calendar.html", payload) - -# --- Health / ops footer ----------------------------------------------------- - - # --- Aggregate summary + market status (dashboard header) ------------------- @@ -621,300 +596,6 @@ async def aggregate_summary( } -# Market → headline index mapping for the sticky bottom bar. Symbols must -# be present in config/default.toml so market_job populates `quotes`. -_MARKET_INDEX = { - "NYSE": ("^GSPC", "S&P 500"), - "LSE": ("^FTSE", "FTSE 100"), - # XETRA → Euro Stoxx 50 rather than ^GDAXI: Yahoo's DAX ticker is - # patchy via the chart endpoint, and ^STOXX50E is already tracked in - # config/default.toml's equity group. - "XETRA": ("^STOXX50E", "STOXX 50"), - "JPX": ("^N225", "Nikkei 225"), - "HKEX": ("^HSI", "Hang Seng"), - "SSE": ("000300.SS", "CSI 300"), -} - - -def _fmt_price(p: float | None) -> str: - if p is None: - return "—" - if abs(p) >= 1000: - return f"{p:,.0f}" - if abs(p) >= 100: - return f"{p:,.1f}" - return f"{p:,.2f}" - - -@router.get("/markets-bar", response_class=HTMLResponse, include_in_schema=False) -async def markets_bar( - request: Request, - session: AsyncSession = Depends(get_session), - as_: str | None = Query(default=None, alias="as"), -): - """The sticky bottom-bar payload: per-market open/close status with the - market's headline index price + 1d change. Refreshed by HTMX every 60s. - """ - from app.services.markets import all_statuses - - statuses = all_statuses() - # Latest quote per headline-index symbol in one query. - wanted_syms = [sym for sym, _ in _MARKET_INDEX.values()] - sub = ( - select(Quote.symbol, func.max(Quote.fetched_at).label("mx")) - .where(Quote.symbol.in_(wanted_syms)) - .group_by(Quote.symbol) - .subquery() - ) - rows = (await session.execute( - select(Quote).join( - sub, - (Quote.symbol == sub.c.symbol) & (Quote.fetched_at == sub.c.mx), - ) - )).scalars().all() - by_sym = {q.symbol: q for q in rows} - - markets: list[dict] = [] - for st in statuses: - sym, label = _MARKET_INDEX.get(st["code"], (None, None)) - q = by_sym.get(sym) if sym else None - idx = None - if q is not None and q.price is not None: - idx = { - "symbol": q.symbol, - "label": label, - "price_fmt": _fmt_price(q.price), - "change_1d_pct": (q.changes or {}).get("1d"), - } - markets.append({ - "code": st["code"], - "label": st["label"], - "open": st["open"], - "until_iso": st["until"].isoformat(), - "until_hhmm": st["until"].strftime("%H:%M"), - "index": idx, - }) - - return templates.TemplateResponse( - request, "partials/markets_bar.html", - {"markets": markets}, - ) - - -@router.get("/health", response_class=HTMLResponse, include_in_schema=False) -async def health_html( - request: Request, - session: AsyncSession = Depends(get_session), - as_: str | None = Query(default=None, alias="as"), -): - """Returns an HTML fragment by default (the ops footer); ?as=json returns the - structured object. The default is HTML because that's how the dashboard - consumes it; CLI/curl users will pass ?as=json.""" - try: - await session.execute(select(func.now())) - db_ok = True - except Exception: - db_ok = False - - now = utcnow() - jobs: list[dict] = [] - structured: list[JobStatus] = [] - for name in JOB_NAMES: - row = (await session.execute( - select(JobRun).where(JobRun.name == name) - .order_by(desc(JobRun.started_at)).limit(1) - )).scalar_one_or_none() - if row is None: - jobs.append({"name": name, "led": "idle", "age": "—", - "last_finished": None}) - structured.append(JobStatus(name=name)) - continue - if row.status == "success": - secs = _age_seconds(now, row.finished_at or row.started_at) or 0 - led = "ok" if secs < JOB_STALE_HOURS * 3600 else "warn" - elif row.status == "skipped": - led = "warn" - elif row.status == "running": - led = "warn" - else: - led = "err" - jobs.append({ - "name": name, "led": led, - "age": _fmt_age(now, row.finished_at or row.started_at), - "last_finished": row.finished_at, - }) - structured.append(JobStatus( - name=name, last_started=row.started_at, - last_finished=row.finished_at, status=row.status, - error=row.error, items_written=row.items_written, - )) - - if as_ == "json": - return JSONResponse( - HealthOut(db="ok" if db_ok else "down", jobs=structured).model_dump(mode="json") - ) - return templates.TemplateResponse( - request, "partials/ops_footer.html", - {"db_ok": db_ok, "jobs": jobs}, - ) - - -# --- Chat ------------------------------------------------------------------- - - -class ChatMessage(BaseModel): - role: str = Field(pattern="^(user|assistant)$") - content: str - - -class ChatRequest(BaseModel): - messages: list[ChatMessage] - - - -THESIS_KEYWORDS_FALLBACK = [ - "hormuz", "iran", "opec", "brent", "wti", "crude", "oil", - "china", "taiwan", "yuan", "fed", "inflation", "cpi", "yield", - "gold", "dollar", "yen", "saudi", "russia", "ukraine", "israel", - "nato", "defence", "defense", -] - - -async def _latest_quotes_by_group_chat(session: AsyncSession) -> dict[str, list[dict]]: - sub = ( - select(Quote.group_name, Quote.symbol, - func.max(Quote.fetched_at).label("mx")) - .group_by(Quote.group_name, Quote.symbol) - .subquery() - ) - rows = (await session.execute( - select(Quote).join( - sub, - (Quote.group_name == sub.c.group_name) - & (Quote.symbol == sub.c.symbol) - & (Quote.fetched_at == sub.c.mx), - ).order_by(Quote.group_name, Quote.symbol) - )).scalars().all() - by_group: dict[str, list[dict]] = defaultdict(list) - for q in rows: - by_group[q.group_name].append({ - "symbol": q.symbol, "label": q.label, - "price": q.price, "currency": q.currency, - "as_of": q.as_of, "changes": q.changes, - }) - return by_group - - -async def _thesis_headlines_for_chat(session: AsyncSession, limit: int = 50) -> list[dict]: - cutoff = utcnow() - timedelta(hours=24) - rows = (await session.execute( - select(Headline) - .where(Headline.published_at >= cutoff) - .order_by(desc(Headline.published_at)) - .limit(300) - )).scalars().all() - out = [] - for h in rows: - if any(kw in h.title.lower() for kw in THESIS_KEYWORDS_FALLBACK): - out.append({"source": h.source, "title": h.title}) - if len(out) >= limit: - break - return out - - -async def _month_spend(session: AsyncSession) -> float: - total = (await session.execute( - select(func.coalesce(func.sum(AICall.cost_usd), 0.0)) - .where(AICall.called_at >= month_start()) - )).scalar() - return float(total or 0.0) - - -@router.post("/chat") -async def chat( - body: ChatRequest, - session: AsyncSession = Depends(get_session), - principal: CurrentUser | None = Depends(maybe_current_user), -): - """Answer one user turn given the conversation so far. Grounded on the - latest strategic log + market data + thesis-filtered headlines. - Ephemeral — the conversation lives entirely in the client; the endpoint - just records each call's cost in `ai_calls`.""" - # Paid-only feature. Free users get the static log but not the - # interactive chat (see /pricing). - from app.services.access import is_paid_active - if not is_paid_active(principal): - raise HTTPException( - status_code=402, - detail={"code": "paid_required", - "message": "Follow-up chat is a paid-tier feature."}, - ) - - s = get_settings() - if not s.OPENROUTER_API_KEY: - raise HTTPException(status_code=503, detail="OPENROUTER_API_KEY not set") - - # Monthly cost cap — same one the log job respects. - spent = await _month_spend(session) - if spent >= s.OPENROUTER_MONTHLY_CAP_USD: - raise HTTPException( - status_code=429, - detail=f"Monthly OpenRouter cap reached (${spent:.2f})", - ) - - # Trim runaway conversations: keep last 20 turns. - history = body.messages[-20:] - if not history or history[-1].role != "user": - raise HTTPException(status_code=400, detail="Last message must be user") - - # Gather grounding context. - log_row = (await session.execute( - select(StrategicLog).order_by(desc(StrategicLog.generated_at)).limit(1) - )).scalar_one_or_none() - quotes = await _latest_quotes_by_group_chat(session) - headlines = await _thesis_headlines_for_chat(session) - - system_prompt = build_chat_system_prompt( - s.CASSANDRA_TONE, s.CASSANDRA_ANALYSIS, - log_content=log_row.content if log_row else None, - log_generated_at=log_row.generated_at if log_row else None, - quotes_by_group=quotes, - headlines=headlines, - reference_line=REFERENCE_LINE, - ) - - msgs = [{"role": "system", "content": system_prompt}] - for m in history: - msgs.append({"role": m.role, "content": m.content}) - - try: - async with httpx.AsyncClient(follow_redirects=True) as client: - result = await call_llm(client, msgs) - except Exception as e: - session.add(AICall( - model=s.OPENROUTER_MODEL, status="error", error=str(e)[:500], - )) - await session.commit() - raise HTTPException(status_code=502, detail=f"OpenRouter error: {e}") - - session.add(AICall( - model=result.model, - prompt_tokens=result.prompt_tokens, - completion_tokens=result.completion_tokens, - cost_usd=result.cost_usd, - status="ok", - )) - await session.commit() - - return { - "role": "assistant", - "content": result.content, - "content_html": _md_to_html(result.content), - "prompt_tokens": result.prompt_tokens, - "completion_tokens": result.completion_tokens, - } - - # --------------------------------------------------------------------------- # Settings — digest preferences # --------------------------------------------------------------------------- diff --git a/app/routers/chat.py b/app/routers/chat.py new file mode 100644 index 0000000..f4198ba --- /dev/null +++ b/app/routers/chat.py @@ -0,0 +1,193 @@ +"""Chat endpoint — POST /api/chat. + +Grounded on the latest strategic log, current market quotes, and +thesis-filtered headlines. Ephemeral: the conversation lives in the +client; this endpoint just records each call's cost in `ai_calls`. +""" +from __future__ import annotations + +from collections import defaultdict +from datetime import timedelta + +import httpx +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field +from sqlalchemy import desc, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth import require_token, maybe_current_user, CurrentUser +from app.config import get_settings +from app.db import get_session, utcnow +from app.jobs._market_context import REFERENCE_LINE +from app.models import AICall, Headline, Quote, StrategicLog +from app.routers.api import _md_to_html +from app.services.llm_prompts import build_chat_system_prompt +from app.services.openrouter import call_llm, month_start + +router = APIRouter(dependencies=[Depends(require_token)]) + + +# --------------------------------------------------------------------------- +# Pydantic models +# --------------------------------------------------------------------------- + + +class ChatMessage(BaseModel): + role: str = Field(pattern="^(user|assistant)$") + content: str + + +class ChatRequest(BaseModel): + messages: list[ChatMessage] + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + +THESIS_KEYWORDS_FALLBACK = [ + "hormuz", "iran", "opec", "brent", "wti", "crude", "oil", + "china", "taiwan", "yuan", "fed", "inflation", "cpi", "yield", + "gold", "dollar", "yen", "saudi", "russia", "ukraine", "israel", + "nato", "defence", "defense", +] + + +async def _latest_quotes_by_group_chat(session: AsyncSession) -> dict[str, list[dict]]: + sub = ( + select(Quote.group_name, Quote.symbol, + func.max(Quote.fetched_at).label("mx")) + .group_by(Quote.group_name, Quote.symbol) + .subquery() + ) + rows = (await session.execute( + select(Quote).join( + sub, + (Quote.group_name == sub.c.group_name) + & (Quote.symbol == sub.c.symbol) + & (Quote.fetched_at == sub.c.mx), + ).order_by(Quote.group_name, Quote.symbol) + )).scalars().all() + by_group: dict[str, list[dict]] = defaultdict(list) + for q in rows: + by_group[q.group_name].append({ + "symbol": q.symbol, "label": q.label, + "price": q.price, "currency": q.currency, + "as_of": q.as_of, "changes": q.changes, + }) + return by_group + + +async def _thesis_headlines_for_chat(session: AsyncSession, limit: int = 50) -> list[dict]: + cutoff = utcnow() - timedelta(hours=24) + rows = (await session.execute( + select(Headline) + .where(Headline.published_at >= cutoff) + .order_by(desc(Headline.published_at)) + .limit(300) + )).scalars().all() + out = [] + for h in rows: + if any(kw in h.title.lower() for kw in THESIS_KEYWORDS_FALLBACK): + out.append({"source": h.source, "title": h.title}) + if len(out) >= limit: + break + return out + + +async def _month_spend(session: AsyncSession) -> float: + total = (await session.execute( + select(func.coalesce(func.sum(AICall.cost_usd), 0.0)) + .where(AICall.called_at >= month_start()) + )).scalar() + return float(total or 0.0) + + +# --------------------------------------------------------------------------- +# Route +# --------------------------------------------------------------------------- + + +@router.post("/chat") +async def chat( + body: ChatRequest, + session: AsyncSession = Depends(get_session), + principal: CurrentUser | None = Depends(maybe_current_user), +): + """Answer one user turn given the conversation so far. Grounded on the + latest strategic log + market data + thesis-filtered headlines. + Ephemeral — the conversation lives entirely in the client; the endpoint + just records each call's cost in `ai_calls`.""" + # Paid-only feature. Free users get the static log but not the + # interactive chat (see /pricing). + from app.services.access import is_paid_active + if not is_paid_active(principal): + raise HTTPException( + status_code=402, + detail={"code": "paid_required", + "message": "Follow-up chat is a paid-tier feature."}, + ) + + s = get_settings() + if not s.OPENROUTER_API_KEY: + raise HTTPException(status_code=503, detail="OPENROUTER_API_KEY not set") + + # Monthly cost cap — same one the log job respects. + spent = await _month_spend(session) + if spent >= s.OPENROUTER_MONTHLY_CAP_USD: + raise HTTPException( + status_code=429, + detail=f"Monthly OpenRouter cap reached (${spent:.2f})", + ) + + # Trim runaway conversations: keep last 20 turns. + history = body.messages[-20:] + if not history or history[-1].role != "user": + raise HTTPException(status_code=400, detail="Last message must be user") + + # Gather grounding context. + log_row = (await session.execute( + select(StrategicLog).order_by(desc(StrategicLog.generated_at)).limit(1) + )).scalar_one_or_none() + quotes = await _latest_quotes_by_group_chat(session) + headlines = await _thesis_headlines_for_chat(session) + + system_prompt = build_chat_system_prompt( + s.CASSANDRA_TONE, s.CASSANDRA_ANALYSIS, + log_content=log_row.content if log_row else None, + log_generated_at=log_row.generated_at if log_row else None, + quotes_by_group=quotes, + headlines=headlines, + reference_line=REFERENCE_LINE, + ) + + msgs = [{"role": "system", "content": system_prompt}] + for m in history: + msgs.append({"role": m.role, "content": m.content}) + + try: + async with httpx.AsyncClient(follow_redirects=True) as client: + result = await call_llm(client, msgs) + except Exception as e: + session.add(AICall( + model=s.OPENROUTER_MODEL, status="error", error=str(e)[:500], + )) + await session.commit() + raise HTTPException(status_code=502, detail=f"OpenRouter error: {e}") + + session.add(AICall( + model=result.model, + prompt_tokens=result.prompt_tokens, + completion_tokens=result.completion_tokens, + cost_usd=result.cost_usd, + status="ok", + )) + await session.commit() + + return { + "role": "assistant", + "content": result.content, + "content_html": _md_to_html(result.content), + "prompt_tokens": result.prompt_tokens, + "completion_tokens": result.completion_tokens, + } diff --git a/app/routers/ops.py b/app/routers/ops.py new file mode 100644 index 0000000..289f803 --- /dev/null +++ b/app/routers/ops.py @@ -0,0 +1,162 @@ +"""HTML-only ops endpoints — /api/markets-bar and /api/health. + +These are HTMX partials consumed by the dashboard. They return HTML by +default (not JSON) and are not included in the OpenAPI schema. +""" +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query, Request +from fastapi.responses import HTMLResponse, JSONResponse +from sqlalchemy import desc, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth import require_token +from app.db import get_session, utcnow +from app.models import JobRun, Quote +from app.routers.api import _age_seconds, _fmt_age +from app.schemas import HealthOut, JobStatus +from app.templates_env import templates + +router = APIRouter(dependencies=[Depends(require_token)]) + +JOB_NAMES = ("market_job", "news_job", "ai_log_job", "rollup_job", + "indicator_summary_job", "universe_flush_job", + "email_digest_job") +JOB_STALE_HOURS = 2.0 # job is "warn" if its last success was >2h ago + +# Market → headline index mapping for the sticky bottom bar. Symbols must +# be present in config/default.toml so market_job populates `quotes`. +_MARKET_INDEX = { + "NYSE": ("^GSPC", "S&P 500"), + "LSE": ("^FTSE", "FTSE 100"), + # XETRA → Euro Stoxx 50 rather than ^GDAXI: Yahoo's DAX ticker is + # patchy via the chart endpoint, and ^STOXX50E is already tracked in + # config/default.toml's equity group. + "XETRA": ("^STOXX50E", "STOXX 50"), + "JPX": ("^N225", "Nikkei 225"), + "HKEX": ("^HSI", "Hang Seng"), + "SSE": ("000300.SS", "CSI 300"), +} + + +def _fmt_price(p: float | None) -> str: + if p is None: + return "—" + if abs(p) >= 1000: + return f"{p:,.0f}" + if abs(p) >= 100: + return f"{p:,.1f}" + return f"{p:,.2f}" + + +@router.get("/markets-bar", response_class=HTMLResponse, include_in_schema=False) +async def markets_bar( + request: Request, + session: AsyncSession = Depends(get_session), + as_: str | None = Query(default=None, alias="as"), +): + """The sticky bottom-bar payload: per-market open/close status with the + market's headline index price + 1d change. Refreshed by HTMX every 60s. + """ + from app.services.markets import all_statuses + + statuses = all_statuses() + # Latest quote per headline-index symbol in one query. + wanted_syms = [sym for sym, _ in _MARKET_INDEX.values()] + sub = ( + select(Quote.symbol, func.max(Quote.fetched_at).label("mx")) + .where(Quote.symbol.in_(wanted_syms)) + .group_by(Quote.symbol) + .subquery() + ) + rows = (await session.execute( + select(Quote).join( + sub, + (Quote.symbol == sub.c.symbol) & (Quote.fetched_at == sub.c.mx), + ) + )).scalars().all() + by_sym = {q.symbol: q for q in rows} + + markets: list[dict] = [] + for st in statuses: + sym, label = _MARKET_INDEX.get(st["code"], (None, None)) + q = by_sym.get(sym) if sym else None + idx = None + if q is not None and q.price is not None: + idx = { + "symbol": q.symbol, + "label": label, + "price_fmt": _fmt_price(q.price), + "change_1d_pct": (q.changes or {}).get("1d"), + } + markets.append({ + "code": st["code"], + "label": st["label"], + "open": st["open"], + "until_iso": st["until"].isoformat(), + "until_hhmm": st["until"].strftime("%H:%M"), + "index": idx, + }) + + return templates.TemplateResponse( + request, "partials/markets_bar.html", + {"markets": markets}, + ) + + +@router.get("/health", response_class=HTMLResponse, include_in_schema=False) +async def health_html( + request: Request, + session: AsyncSession = Depends(get_session), + as_: str | None = Query(default=None, alias="as"), +): + """Returns an HTML fragment by default (the ops footer); ?as=json returns the + structured object. The default is HTML because that's how the dashboard + consumes it; CLI/curl users will pass ?as=json.""" + try: + await session.execute(select(func.now())) + db_ok = True + except Exception: + db_ok = False + + now = utcnow() + jobs: list[dict] = [] + structured: list[JobStatus] = [] + for name in JOB_NAMES: + row = (await session.execute( + select(JobRun).where(JobRun.name == name) + .order_by(desc(JobRun.started_at)).limit(1) + )).scalar_one_or_none() + if row is None: + jobs.append({"name": name, "led": "idle", "age": "—", + "last_finished": None}) + structured.append(JobStatus(name=name)) + continue + if row.status == "success": + secs = _age_seconds(now, row.finished_at or row.started_at) or 0 + led = "ok" if secs < JOB_STALE_HOURS * 3600 else "warn" + elif row.status == "skipped": + led = "warn" + elif row.status == "running": + led = "warn" + else: + led = "err" + jobs.append({ + "name": name, "led": led, + "age": _fmt_age(now, row.finished_at or row.started_at), + "last_finished": row.finished_at, + }) + structured.append(JobStatus( + name=name, last_started=row.started_at, + last_finished=row.finished_at, status=row.status, + error=row.error, items_written=row.items_written, + )) + + if as_ == "json": + return JSONResponse( + HealthOut(db="ok" if db_ok else "down", jobs=structured).model_dump(mode="json") + ) + return templates.TemplateResponse( + request, "partials/ops_footer.html", + {"db_ok": db_ok, "jobs": jobs}, + ) diff --git a/tests/test_chat_and_log_gates.py b/tests/test_chat_and_log_gates.py index bff5997..e050cab 100644 --- a/tests/test_chat_and_log_gates.py +++ b/tests/test_chat_and_log_gates.py @@ -23,6 +23,7 @@ def _build_app(tmp_path): from app.db import Base from app.models import StrategicLog, User from app.routers import api as api_router + from app.routers import chat as chat_router engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/gates.db") factory = async_sessionmaker(engine, expire_on_commit=False) @@ -56,6 +57,7 @@ def _build_app(tmp_path): app = FastAPI() app.include_router(api_router.router, prefix="/api") + app.include_router(chat_router.router, prefix="/api") client = TestClient(app) return client, sign_session(1), sign_session(2)