jobs: extract shared market-context helpers from ai_log_job
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ce4b19dbb8
commit
82e529b6b2
3 changed files with 104 additions and 84 deletions
86
app/jobs/_market_context.py
Normal file
86
app/jobs/_market_context.py
Normal file
|
|
@ -0,0 +1,86 @@
|
||||||
|
"""Shared market-context helpers consumed by LLM-driven jobs.
|
||||||
|
|
||||||
|
Both ai_log_job and email_digest_job pull "the latest tape" the same
|
||||||
|
way — most-recent quote per (group, symbol), last N hours of headlines
|
||||||
|
bucketed by category, and the running month's LLM spend. Moved here so
|
||||||
|
neither job depends on the other's internals.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from sqlalchemy import desc, func, select
|
||||||
|
|
||||||
|
from app.db import utcnow
|
||||||
|
from app.models import AICall, Headline, Quote
|
||||||
|
from app.services.openrouter import month_start
|
||||||
|
|
||||||
|
|
||||||
|
REFERENCE_LINE = (
|
||||||
|
"S&P 7,501 (ATH) · VIX 18.0 · US 10y 4.45% · HY OAS 279bps · "
|
||||||
|
"Brent $109/bbl · Gold $4,651/oz · CPI 3.8% YoY"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def latest_quotes_by_group(session) -> dict[str, list[dict]]:
|
||||||
|
"""Latest quote per (group, symbol). Skips error rows where price is null."""
|
||||||
|
sub = (
|
||||||
|
select(
|
||||||
|
Quote.group_name,
|
||||||
|
Quote.symbol,
|
||||||
|
func.max(Quote.fetched_at).label("mx"),
|
||||||
|
)
|
||||||
|
.group_by(Quote.group_name, Quote.symbol)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
stmt = (
|
||||||
|
select(Quote)
|
||||||
|
.join(
|
||||||
|
sub,
|
||||||
|
(Quote.group_name == sub.c.group_name)
|
||||||
|
& (Quote.symbol == sub.c.symbol)
|
||||||
|
& (Quote.fetched_at == sub.c.mx),
|
||||||
|
)
|
||||||
|
.order_by(Quote.group_name, Quote.symbol)
|
||||||
|
)
|
||||||
|
rows = (await session.execute(stmt)).scalars().all()
|
||||||
|
by_group: dict[str, list[dict]] = defaultdict(list)
|
||||||
|
for q in rows:
|
||||||
|
by_group[q.group_name].append(dict(
|
||||||
|
symbol=q.symbol, source=q.source, label=q.label,
|
||||||
|
note="", price=q.price, currency=q.currency,
|
||||||
|
as_of=q.as_of, changes=q.changes,
|
||||||
|
))
|
||||||
|
return by_group
|
||||||
|
|
||||||
|
|
||||||
|
async def recent_headlines_by_bucket(session, hours: float = 24) -> dict[str, list[dict]]:
|
||||||
|
"""Last N hours of headlines, bucketed by category. Hard cap per
|
||||||
|
bucket to keep the prompt under ~40KB."""
|
||||||
|
cutoff = utcnow() - timedelta(hours=hours)
|
||||||
|
stmt = (
|
||||||
|
select(Headline)
|
||||||
|
.where(Headline.published_at >= cutoff)
|
||||||
|
.order_by(desc(Headline.published_at))
|
||||||
|
.limit(400)
|
||||||
|
)
|
||||||
|
rows = (await session.execute(stmt)).scalars().all()
|
||||||
|
by_bucket: dict[str, list[dict]] = defaultdict(list)
|
||||||
|
for h in rows:
|
||||||
|
if len(by_bucket[h.category]) >= 40:
|
||||||
|
continue
|
||||||
|
by_bucket[h.category].append(dict(
|
||||||
|
when=h.published_at.isoformat(),
|
||||||
|
source=h.source, title=h.title,
|
||||||
|
))
|
||||||
|
return by_bucket
|
||||||
|
|
||||||
|
|
||||||
|
async def month_spend(session) -> float:
|
||||||
|
start = month_start()
|
||||||
|
total = (await session.execute(
|
||||||
|
select(func.coalesce(func.sum(AICall.cost_usd), 0.0))
|
||||||
|
.where(AICall.called_at >= start)
|
||||||
|
)).scalar()
|
||||||
|
return float(total or 0.0)
|
||||||
|
|
@ -4,8 +4,6 @@ and a row in the cost ledger."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from sqlalchemy import desc, func, select
|
from sqlalchemy import desc, func, select
|
||||||
|
|
@ -13,7 +11,13 @@ from sqlalchemy import desc, func, select
|
||||||
from app.config import get_settings
|
from app.config import get_settings
|
||||||
from app.db import utcnow
|
from app.db import utcnow
|
||||||
from app.jobs._helpers import job_lifecycle, log
|
from app.jobs._helpers import job_lifecycle, log
|
||||||
from app.models import AICall, Headline, JobRun, Quote, StrategicLog
|
from app.jobs._market_context import (
|
||||||
|
REFERENCE_LINE,
|
||||||
|
latest_quotes_by_group,
|
||||||
|
month_spend,
|
||||||
|
recent_headlines_by_bucket,
|
||||||
|
)
|
||||||
|
from app.models import AICall, JobRun, StrategicLog
|
||||||
from app.services.cadence import DEFAULT_POLICY
|
from app.services.cadence import DEFAULT_POLICY
|
||||||
from app.services.openrouter import (
|
from app.services.openrouter import (
|
||||||
PROMPT_VERSION,
|
PROMPT_VERSION,
|
||||||
|
|
@ -22,79 +26,9 @@ from app.services.openrouter import (
|
||||||
build_user_prompt,
|
build_user_prompt,
|
||||||
call_llm,
|
call_llm,
|
||||||
llm_configured,
|
llm_configured,
|
||||||
month_start,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
REFERENCE_LINE = (
|
|
||||||
"S&P 7,501 (ATH) · VIX 18.0 · US 10y 4.45% · HY OAS 279bps · "
|
|
||||||
"Brent $109/bbl · Gold $4,651/oz · CPI 3.8% YoY"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _latest_quotes_by_group(session) -> dict[str, list[dict]]:
|
|
||||||
"""Latest quote per (group, symbol). Skips error rows where price is null."""
|
|
||||||
sub = (
|
|
||||||
select(
|
|
||||||
Quote.group_name,
|
|
||||||
Quote.symbol,
|
|
||||||
func.max(Quote.fetched_at).label("mx"),
|
|
||||||
)
|
|
||||||
.group_by(Quote.group_name, Quote.symbol)
|
|
||||||
.subquery()
|
|
||||||
)
|
|
||||||
stmt = (
|
|
||||||
select(Quote)
|
|
||||||
.join(
|
|
||||||
sub,
|
|
||||||
(Quote.group_name == sub.c.group_name)
|
|
||||||
& (Quote.symbol == sub.c.symbol)
|
|
||||||
& (Quote.fetched_at == sub.c.mx),
|
|
||||||
)
|
|
||||||
.order_by(Quote.group_name, Quote.symbol)
|
|
||||||
)
|
|
||||||
rows = (await session.execute(stmt)).scalars().all()
|
|
||||||
by_group: dict[str, list[dict]] = defaultdict(list)
|
|
||||||
for q in rows:
|
|
||||||
by_group[q.group_name].append(dict(
|
|
||||||
symbol=q.symbol, source=q.source, label=q.label,
|
|
||||||
note="", price=q.price, currency=q.currency,
|
|
||||||
as_of=q.as_of, changes=q.changes,
|
|
||||||
))
|
|
||||||
return by_group
|
|
||||||
|
|
||||||
|
|
||||||
async def _recent_headlines_by_bucket(session, hours: float = 24) -> dict[str, list[dict]]:
|
|
||||||
"""Last N hours of headlines, bucketed by category. Hard cap per bucket
|
|
||||||
to keep the prompt under ~40KB."""
|
|
||||||
cutoff = utcnow() - timedelta(hours=hours)
|
|
||||||
stmt = (
|
|
||||||
select(Headline)
|
|
||||||
.where(Headline.published_at >= cutoff)
|
|
||||||
.order_by(desc(Headline.published_at))
|
|
||||||
.limit(400)
|
|
||||||
)
|
|
||||||
rows = (await session.execute(stmt)).scalars().all()
|
|
||||||
by_bucket: dict[str, list[dict]] = defaultdict(list)
|
|
||||||
for h in rows:
|
|
||||||
if len(by_bucket[h.category]) >= 40:
|
|
||||||
continue
|
|
||||||
by_bucket[h.category].append(dict(
|
|
||||||
when=h.published_at.isoformat(),
|
|
||||||
source=h.source, title=h.title,
|
|
||||||
))
|
|
||||||
return by_bucket
|
|
||||||
|
|
||||||
|
|
||||||
async def _month_spend(session) -> float:
|
|
||||||
start = month_start()
|
|
||||||
total = (await session.execute(
|
|
||||||
select(func.coalesce(func.sum(AICall.cost_usd), 0.0))
|
|
||||||
.where(AICall.called_at >= start)
|
|
||||||
)).scalar()
|
|
||||||
return float(total or 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
async def run() -> None:
|
async def run() -> None:
|
||||||
async with job_lifecycle("ai_log_job") as (session, jr):
|
async with job_lifecycle("ai_log_job") as (session, jr):
|
||||||
if jr.status == "skipped":
|
if jr.status == "skipped":
|
||||||
|
|
@ -119,7 +53,7 @@ async def run() -> None:
|
||||||
jr.error = reason
|
jr.error = reason
|
||||||
return
|
return
|
||||||
|
|
||||||
spent = await _month_spend(session)
|
spent = await month_spend(session)
|
||||||
if spent >= s.OPENROUTER_MONTHLY_CAP_USD:
|
if spent >= s.OPENROUTER_MONTHLY_CAP_USD:
|
||||||
log.warning("ai_log.cap_reached", spent=spent,
|
log.warning("ai_log.cap_reached", spent=spent,
|
||||||
cap=s.OPENROUTER_MONTHLY_CAP_USD)
|
cap=s.OPENROUTER_MONTHLY_CAP_USD)
|
||||||
|
|
@ -127,8 +61,8 @@ async def run() -> None:
|
||||||
jr.error = f"monthly cost cap reached (${spent:.2f})"
|
jr.error = f"monthly cost cap reached (${spent:.2f})"
|
||||||
return
|
return
|
||||||
|
|
||||||
quotes = await _latest_quotes_by_group(session)
|
quotes = await latest_quotes_by_group(session)
|
||||||
news = await _recent_headlines_by_bucket(session)
|
news = await recent_headlines_by_bucket(session)
|
||||||
if not quotes and not news:
|
if not quotes and not news:
|
||||||
log.warning("ai_log.no_data_yet")
|
log.warning("ai_log.no_data_yet")
|
||||||
jr.status = "skipped"
|
jr.status = "skipped"
|
||||||
|
|
@ -169,7 +103,7 @@ async def run() -> None:
|
||||||
for tone, analysis in variants:
|
for tone, analysis in variants:
|
||||||
# Re-check cost cap between variants so a runaway run is
|
# Re-check cost cap between variants so a runaway run is
|
||||||
# bounded.
|
# bounded.
|
||||||
spent = await _month_spend(session)
|
spent = await month_spend(session)
|
||||||
if spent >= s.OPENROUTER_MONTHLY_CAP_USD:
|
if spent >= s.OPENROUTER_MONTHLY_CAP_USD:
|
||||||
log.warning("ai_log.cap_reached_midrun",
|
log.warning("ai_log.cap_reached_midrun",
|
||||||
spent=spent, completed=written)
|
spent=spent, completed=written)
|
||||||
|
|
|
||||||
|
|
@ -20,11 +20,11 @@ from app import branding
|
||||||
from app.config import get_settings
|
from app.config import get_settings
|
||||||
from app.db import utcnow
|
from app.db import utcnow
|
||||||
from app.jobs._helpers import job_lifecycle, log
|
from app.jobs._helpers import job_lifecycle, log
|
||||||
from app.jobs.ai_log_job import (
|
from app.jobs._market_context import (
|
||||||
REFERENCE_LINE,
|
REFERENCE_LINE,
|
||||||
_latest_quotes_by_group,
|
latest_quotes_by_group,
|
||||||
_recent_headlines_by_bucket,
|
month_spend,
|
||||||
_month_spend,
|
recent_headlines_by_bucket,
|
||||||
)
|
)
|
||||||
from app.models import EmailSend, User
|
from app.models import EmailSend, User
|
||||||
from app.routers.email import sign_unsubscribe_token
|
from app.routers.email import sign_unsubscribe_token
|
||||||
|
|
@ -172,7 +172,7 @@ async def run() -> None:
|
||||||
jr.status = "skipped"
|
jr.status = "skipped"
|
||||||
return
|
return
|
||||||
|
|
||||||
spent = await _month_spend(session)
|
spent = await month_spend(session)
|
||||||
if spent >= s.OPENROUTER_MONTHLY_CAP_USD:
|
if spent >= s.OPENROUTER_MONTHLY_CAP_USD:
|
||||||
log.warning("digest.cap_reached", spent=spent,
|
log.warning("digest.cap_reached", spent=spent,
|
||||||
cap=s.OPENROUTER_MONTHLY_CAP_USD)
|
cap=s.OPENROUTER_MONTHLY_CAP_USD)
|
||||||
|
|
@ -180,8 +180,8 @@ async def run() -> None:
|
||||||
jr.error = f"monthly cost cap reached (${spent:.2f})"
|
jr.error = f"monthly cost cap reached (${spent:.2f})"
|
||||||
return
|
return
|
||||||
|
|
||||||
quotes = await _latest_quotes_by_group(session)
|
quotes = await latest_quotes_by_group(session)
|
||||||
news = await _recent_headlines_by_bucket(
|
news = await recent_headlines_by_bucket(
|
||||||
session, hours=(168 if kind == "weekly" else 24),
|
session, hours=(168 if kind == "weekly" else 24),
|
||||||
)
|
)
|
||||||
ctx = dict(
|
ctx = dict(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue