jobs: extract shared market-context helpers from ai_log_job

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Giorgio Gilestro 2026-05-25 23:18:27 +02:00
parent ce4b19dbb8
commit 82e529b6b2
3 changed files with 104 additions and 84 deletions

View file

@ -0,0 +1,86 @@
"""Shared market-context helpers consumed by LLM-driven jobs.
Both ai_log_job and email_digest_job pull "the latest tape" the same
way most-recent quote per (group, symbol), last N hours of headlines
bucketed by category, and the running month's LLM spend. Moved here so
neither job depends on the other's internals.
"""
from __future__ import annotations
from collections import defaultdict
from datetime import timedelta
from sqlalchemy import desc, func, select
from app.db import utcnow
from app.models import AICall, Headline, Quote
from app.services.openrouter import month_start
REFERENCE_LINE = (
"S&P 7,501 (ATH) · VIX 18.0 · US 10y 4.45% · HY OAS 279bps · "
"Brent $109/bbl · Gold $4,651/oz · CPI 3.8% YoY"
)
async def latest_quotes_by_group(session) -> dict[str, list[dict]]:
"""Latest quote per (group, symbol). Skips error rows where price is null."""
sub = (
select(
Quote.group_name,
Quote.symbol,
func.max(Quote.fetched_at).label("mx"),
)
.group_by(Quote.group_name, Quote.symbol)
.subquery()
)
stmt = (
select(Quote)
.join(
sub,
(Quote.group_name == sub.c.group_name)
& (Quote.symbol == sub.c.symbol)
& (Quote.fetched_at == sub.c.mx),
)
.order_by(Quote.group_name, Quote.symbol)
)
rows = (await session.execute(stmt)).scalars().all()
by_group: dict[str, list[dict]] = defaultdict(list)
for q in rows:
by_group[q.group_name].append(dict(
symbol=q.symbol, source=q.source, label=q.label,
note="", price=q.price, currency=q.currency,
as_of=q.as_of, changes=q.changes,
))
return by_group
async def recent_headlines_by_bucket(session, hours: float = 24) -> dict[str, list[dict]]:
"""Last N hours of headlines, bucketed by category. Hard cap per
bucket to keep the prompt under ~40KB."""
cutoff = utcnow() - timedelta(hours=hours)
stmt = (
select(Headline)
.where(Headline.published_at >= cutoff)
.order_by(desc(Headline.published_at))
.limit(400)
)
rows = (await session.execute(stmt)).scalars().all()
by_bucket: dict[str, list[dict]] = defaultdict(list)
for h in rows:
if len(by_bucket[h.category]) >= 40:
continue
by_bucket[h.category].append(dict(
when=h.published_at.isoformat(),
source=h.source, title=h.title,
))
return by_bucket
async def month_spend(session) -> float:
start = month_start()
total = (await session.execute(
select(func.coalesce(func.sum(AICall.cost_usd), 0.0))
.where(AICall.called_at >= start)
)).scalar()
return float(total or 0.0)

View file

@ -4,8 +4,6 @@ and a row in the cost ledger."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict
from datetime import timedelta
import httpx import httpx
from sqlalchemy import desc, func, select from sqlalchemy import desc, func, select
@ -13,7 +11,13 @@ from sqlalchemy import desc, func, select
from app.config import get_settings from app.config import get_settings
from app.db import utcnow from app.db import utcnow
from app.jobs._helpers import job_lifecycle, log from app.jobs._helpers import job_lifecycle, log
from app.models import AICall, Headline, JobRun, Quote, StrategicLog from app.jobs._market_context import (
REFERENCE_LINE,
latest_quotes_by_group,
month_spend,
recent_headlines_by_bucket,
)
from app.models import AICall, JobRun, StrategicLog
from app.services.cadence import DEFAULT_POLICY from app.services.cadence import DEFAULT_POLICY
from app.services.openrouter import ( from app.services.openrouter import (
PROMPT_VERSION, PROMPT_VERSION,
@ -22,79 +26,9 @@ from app.services.openrouter import (
build_user_prompt, build_user_prompt,
call_llm, call_llm,
llm_configured, llm_configured,
month_start,
) )
REFERENCE_LINE = (
"S&P 7,501 (ATH) · VIX 18.0 · US 10y 4.45% · HY OAS 279bps · "
"Brent $109/bbl · Gold $4,651/oz · CPI 3.8% YoY"
)
async def _latest_quotes_by_group(session) -> dict[str, list[dict]]:
"""Latest quote per (group, symbol). Skips error rows where price is null."""
sub = (
select(
Quote.group_name,
Quote.symbol,
func.max(Quote.fetched_at).label("mx"),
)
.group_by(Quote.group_name, Quote.symbol)
.subquery()
)
stmt = (
select(Quote)
.join(
sub,
(Quote.group_name == sub.c.group_name)
& (Quote.symbol == sub.c.symbol)
& (Quote.fetched_at == sub.c.mx),
)
.order_by(Quote.group_name, Quote.symbol)
)
rows = (await session.execute(stmt)).scalars().all()
by_group: dict[str, list[dict]] = defaultdict(list)
for q in rows:
by_group[q.group_name].append(dict(
symbol=q.symbol, source=q.source, label=q.label,
note="", price=q.price, currency=q.currency,
as_of=q.as_of, changes=q.changes,
))
return by_group
async def _recent_headlines_by_bucket(session, hours: float = 24) -> dict[str, list[dict]]:
"""Last N hours of headlines, bucketed by category. Hard cap per bucket
to keep the prompt under ~40KB."""
cutoff = utcnow() - timedelta(hours=hours)
stmt = (
select(Headline)
.where(Headline.published_at >= cutoff)
.order_by(desc(Headline.published_at))
.limit(400)
)
rows = (await session.execute(stmt)).scalars().all()
by_bucket: dict[str, list[dict]] = defaultdict(list)
for h in rows:
if len(by_bucket[h.category]) >= 40:
continue
by_bucket[h.category].append(dict(
when=h.published_at.isoformat(),
source=h.source, title=h.title,
))
return by_bucket
async def _month_spend(session) -> float:
start = month_start()
total = (await session.execute(
select(func.coalesce(func.sum(AICall.cost_usd), 0.0))
.where(AICall.called_at >= start)
)).scalar()
return float(total or 0.0)
async def run() -> None: async def run() -> None:
async with job_lifecycle("ai_log_job") as (session, jr): async with job_lifecycle("ai_log_job") as (session, jr):
if jr.status == "skipped": if jr.status == "skipped":
@ -119,7 +53,7 @@ async def run() -> None:
jr.error = reason jr.error = reason
return return
spent = await _month_spend(session) spent = await month_spend(session)
if spent >= s.OPENROUTER_MONTHLY_CAP_USD: if spent >= s.OPENROUTER_MONTHLY_CAP_USD:
log.warning("ai_log.cap_reached", spent=spent, log.warning("ai_log.cap_reached", spent=spent,
cap=s.OPENROUTER_MONTHLY_CAP_USD) cap=s.OPENROUTER_MONTHLY_CAP_USD)
@ -127,8 +61,8 @@ async def run() -> None:
jr.error = f"monthly cost cap reached (${spent:.2f})" jr.error = f"monthly cost cap reached (${spent:.2f})"
return return
quotes = await _latest_quotes_by_group(session) quotes = await latest_quotes_by_group(session)
news = await _recent_headlines_by_bucket(session) news = await recent_headlines_by_bucket(session)
if not quotes and not news: if not quotes and not news:
log.warning("ai_log.no_data_yet") log.warning("ai_log.no_data_yet")
jr.status = "skipped" jr.status = "skipped"
@ -169,7 +103,7 @@ async def run() -> None:
for tone, analysis in variants: for tone, analysis in variants:
# Re-check cost cap between variants so a runaway run is # Re-check cost cap between variants so a runaway run is
# bounded. # bounded.
spent = await _month_spend(session) spent = await month_spend(session)
if spent >= s.OPENROUTER_MONTHLY_CAP_USD: if spent >= s.OPENROUTER_MONTHLY_CAP_USD:
log.warning("ai_log.cap_reached_midrun", log.warning("ai_log.cap_reached_midrun",
spent=spent, completed=written) spent=spent, completed=written)

View file

@ -20,11 +20,11 @@ from app import branding
from app.config import get_settings from app.config import get_settings
from app.db import utcnow from app.db import utcnow
from app.jobs._helpers import job_lifecycle, log from app.jobs._helpers import job_lifecycle, log
from app.jobs.ai_log_job import ( from app.jobs._market_context import (
REFERENCE_LINE, REFERENCE_LINE,
_latest_quotes_by_group, latest_quotes_by_group,
_recent_headlines_by_bucket, month_spend,
_month_spend, recent_headlines_by_bucket,
) )
from app.models import EmailSend, User from app.models import EmailSend, User
from app.routers.email import sign_unsubscribe_token from app.routers.email import sign_unsubscribe_token
@ -172,7 +172,7 @@ async def run() -> None:
jr.status = "skipped" jr.status = "skipped"
return return
spent = await _month_spend(session) spent = await month_spend(session)
if spent >= s.OPENROUTER_MONTHLY_CAP_USD: if spent >= s.OPENROUTER_MONTHLY_CAP_USD:
log.warning("digest.cap_reached", spent=spent, log.warning("digest.cap_reached", spent=spent,
cap=s.OPENROUTER_MONTHLY_CAP_USD) cap=s.OPENROUTER_MONTHLY_CAP_USD)
@ -180,8 +180,8 @@ async def run() -> None:
jr.error = f"monthly cost cap reached (${spent:.2f})" jr.error = f"monthly cost cap reached (${spent:.2f})"
return return
quotes = await _latest_quotes_by_group(session) quotes = await latest_quotes_by_group(session)
news = await _recent_headlines_by_bucket( news = await recent_headlines_by_bucket(
session, hours=(168 if kind == "weekly" else 24), session, hours=(168 if kind == "weekly" else 24),
) )
ctx = dict( ctx = dict(