read.markets/app/routers/stripe_billing.py
Giorgio Gilestro 83995e96c8 stripe: detect buyer currency at checkout (GBP/USD/EUR)
Pass `currency` to Stripe checkout for first-time buyers so Stripe
picks the matching `currency_options` rate configured on the Price
in the Dashboard (multi-currency Prices: one Price, per-currency
unit_amount). Operator configures the rates on existing Prices
prod_UaZ0xCpCboUGCN/price_*; this commit is the application-side
signal.

Currency precedence: explicit request body > Cloudflare cf-ipcountry
header > Accept-Language locale > GBP fallback. Only honoured when
the user has no stripe_customer_id yet — Stripe locks currency to
the customer record at first checkout, so existing customers keep
their original currency (they can switch via the portal).

Adds 4 tests: sniffed currency on new customer, body override beats
sniff, currency omitted for existing customer, and unit-tests for
the sniffing fallback chain.
2026-05-28 12:42:40 +02:00

490 lines
18 KiB
Python

"""Stripe billing endpoints — checkout, webhook, customer portal.
Stripe is the merchant-on-record for read.markets (after Polar/Paddle
both declined the financial-media category). We delegate payment UI to
Stripe-hosted Checkout and Customer Portal; the only state we keep on
our side is `users.stripe_customer_id` / `users.stripe_subscription_id`
so we can match incoming webhooks back to the right user.
The Stripe SDK is sync; we wrap calls in `asyncio.to_thread` so the
event loop doesn't block while Stripe answers. For our request volume
this is more reliable than the SDK's nascent async surface.
Routes
- POST /api/stripe/checkout — logged-in user upgrades. Body: {cadence, currency?} (currency is optional).
- POST /api/stripe/webhook — Stripe → us, signature-verified.
- POST /api/stripe/portal — logged-in user opens the customer portal.
"""
from __future__ import annotations
import asyncio
import json
from typing import Any, Literal, Optional
import stripe
from fastapi import APIRouter, Body, Depends, HTTPException, Request
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app import branding
from app.auth import CurrentUser, require_auth
from app.config import get_settings
from app.db import get_session, utcnow
from app.logging import get_logger
from app.models import StripeEvent, User
# Module logger; call sites below pass an event name such as
# "stripe.checkout.created" plus keyword fields.
log = get_logger("stripe_billing")
# Router that the checkout / portal / webhook routes in this module
# register themselves on.
router = APIRouter()
# Cap stored payload at 16 KiB so a hostile (or buggy) sender can't
# blow up a single row. Same pattern as polar_webhook.
# NOTE: the webhook applies this cap by slicing the *decoded* payload
# text, so the limit is counted in characters rather than raw bytes.
_PAYLOAD_STORE_MAX = 16 * 1024
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _require_configured() -> None:
    """Refuse the request unless a Stripe API key is configured.

    Raises:
        HTTPException: 503 when ``STRIPE_API_KEY`` is empty/unset.
    """
    if get_settings().STRIPE_API_KEY:
        return
    raise HTTPException(status_code=503, detail="stripe not configured")
def _price_for(cadence: str) -> str:
    """Return the configured Stripe Price id for a billing cadence.

    Args:
        cadence: ``"monthly"`` or ``"annual"``.

    Returns:
        The value of ``STRIPE_PRICE_MONTHLY`` or ``STRIPE_PRICE_ANNUAL``.

    Raises:
        HTTPException: 503 when the matching setting is empty, 400 when
            ``cadence`` is neither of the two supported values.
    """
    settings = get_settings()
    if cadence == "monthly":
        price_id = settings.STRIPE_PRICE_MONTHLY
        missing_detail = "STRIPE_PRICE_MONTHLY not set"
    elif cadence == "annual":
        price_id = settings.STRIPE_PRICE_ANNUAL
        missing_detail = "STRIPE_PRICE_ANNUAL not set"
    else:
        raise HTTPException(status_code=400, detail="cadence must be 'monthly' or 'annual'")

    if not price_id:
        raise HTTPException(status_code=503, detail=missing_detail)
    return price_id
# Approximate ISO country code → checkout currency table. It only lists
# the markets we have a stated rate for; any country not listed here is
# handled by the caller's GBP (home currency) fallback, with Stripe
# handling the FX at checkout. The actual per-currency unit_amount
# lives on each Price's `currency_options` in the Stripe Dashboard —
# this table merely signals which option to pick.
_COUNTRY_CURRENCY: dict[str, str] = {
    country: currency
    for currency, countries in (
        ("usd", ("US", "CA")),
        ("gbp", ("GB", "IM", "JE", "GG")),
        ("eur", (
            "DE", "FR", "IT", "ES", "PT", "NL", "BE", "IE", "AT", "FI",
            "GR", "LU", "MT", "CY", "EE", "LV", "LT", "SI", "SK", "HR",
        )),
    )
    for country in countries
}
# Lower-cased Accept-Language tag → currency; consulted only when the
# CF-IPCountry lookup yields nothing. Region-less tags are ambiguous:
# plain "fr" (and the other euro-area languages) map to EUR because
# that's the majority outcome, while plain "en" maps to GBP.
_LOCALE_CURRENCY: dict[str, str] = {
    **dict.fromkeys(("en-gb", "en"), "gbp"),
    **dict.fromkeys(("en-us", "en-ca"), "usd"),
    **dict.fromkeys(("fr", "de", "it", "es", "pt", "nl"), "eur"),
}
def _sniff_currency(request: Request) -> str:
    """Guess the checkout currency for a first-time buyer from headers.

    Lookup order:
      1. ``cf-ipcountry`` header, upper-cased, via ``_COUNTRY_CURRENCY``.
      2. The first entry of ``accept-language`` (lower-cased, any
         ``;q=`` suffix dropped): exact tag first, then just the
         language part before the first ``-``, via ``_LOCALE_CURRENCY``.
      3. ``"gbp"`` when nothing matched.

    Callers only use this for users without a Stripe customer record;
    Stripe locks currency at customer creation, so an existing
    customer's currency wins regardless of the request locale.

    Args:
        request: Incoming request whose headers are inspected.

    Returns:
        A lower-case currency code: ``"gbp"``, ``"usd"`` or ``"eur"``.
    """
    headers = request.headers

    country = (headers.get("cf-ipcountry") or "").upper()
    by_country = _COUNTRY_CURRENCY.get(country)
    if by_country is not None:
        return by_country

    accept = (headers.get("accept-language") or "").lower()
    # Only the first listed language range is considered; strip its
    # optional ";q=…" weight and surrounding whitespace.
    primary = accept.partition(",")[0].partition(";")[0].strip()
    language_only = primary.partition("-")[0]
    for tag in (primary, language_only):
        if tag in _LOCALE_CURRENCY:
            return _LOCALE_CURRENCY[tag]

    return "gbp"
def _stripe_client() -> stripe.StripeClient:
    """Build a fresh Stripe client from the current settings.

    A new client is created on every call so the secret is read at
    request time: rotating the key only needs an .env edit + reload,
    with no cached client to rebuild.
    """
    api_key = get_settings().STRIPE_API_KEY
    return stripe.StripeClient(api_key)
# ---------------------------------------------------------------------------
# POST /api/stripe/checkout
# ---------------------------------------------------------------------------
class CheckoutRequest(BaseModel):
    # Request body for POST /api/stripe/checkout.

    # Billing cadence to subscribe to; selects which Price is used.
    cadence: Literal["monthly", "annual"]

    # Optional explicit currency. When left out, the endpoint sniffs a
    # currency from the request headers instead. Either way it is only
    # honoured on a first-time checkout, because Stripe locks currency
    # to the customer at creation.
    currency: Optional[Literal["gbp", "usd", "eur"]] = None
class CheckoutResponse(BaseModel):
    # Response body for POST /api/stripe/checkout.

    # URL of the Stripe-hosted Checkout page for the created session.
    url: str
@router.post("/api/stripe/checkout", response_model=CheckoutResponse)
async def create_checkout(
    body: CheckoutRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
    cu: CurrentUser = Depends(require_auth),
) -> CheckoutResponse:
    # Create a Stripe-hosted Checkout session (subscription mode) for
    # the logged-in user and return its URL.
    #
    # Errors: 503 when Stripe / the Price is not configured, 400 for an
    # admin bearer token, 404 when the user row is gone, 502 when
    # Stripe rejects the call or returns no URL.
    _require_configured()

    # Admin bearer token has no User row — they shouldn't be buying.
    if cu.user is None:
        raise HTTPException(status_code=400, detail="admin token cannot purchase")

    user = await session.get(User, cu.user.id)
    if user is None:
        raise HTTPException(status_code=404, detail="user_not_found")

    price_id = _price_for(body.cadence)
    client = _stripe_client()
    known_customer = user.stripe_customer_id

    params: dict[str, Any] = {
        "mode": "subscription",
        "line_items": [{"price": price_id, "quantity": 1}],
        # Join key the webhook uses to map the session back to the user.
        "client_reference_id": str(user.id),
        "success_url": f"{branding.SITE_URL}/settings?upgraded=1",
        "cancel_url": f"{branding.SITE_URL}/pricing",
        # Lets us paste in a referral coupon at checkout once the
        # referral redemption flow ships.
        "allow_promotion_codes": True,
    }

    # Multi-currency: only first-time buyers (no stripe_customer_id
    # yet) get a `currency`, either requested or detected. Stripe then
    # picks the matching `currency_options` rate configured on the
    # Price in the Dashboard and locks that currency to the new
    # customer record. Existing customers keep their original currency.
    if not known_customer:
        params["currency"] = body.currency or _sniff_currency(request)

    # Per-cadence cooling-off treatment:
    #
    # - Annual: 14-day free trial. No money moves during the trial, so
    #   the Consumer Contracts Regulations 14-day refund question is
    #   moot (nothing paid = nothing to refund). Card is still required
    #   at checkout so Stripe can charge on day 15.
    #
    # - Monthly: bills immediately (a 14-day trial on a £7/month plan
    #   would give away ~50% of cycle one). The Reg-36 waiver lives on
    #   our own /pricing page as a required tick-box (see
    #   pricing.html); we deliberately do NOT use Stripe's
    #   consent_collection.terms_of_service here because that's an
    #   account-wide setting and we want per-product control (and
    #   per-product Terms URLs) as we grow.
    if body.cadence == "annual":
        params["subscription_data"] = {"trial_period_days": 14}

    # Reuse the customer we already minted for this user (avoids
    # duplicate Stripe customers on repeat checkouts); otherwise let
    # Stripe create one from the e-mail address.
    if known_customer:
        params["customer"] = known_customer
    else:
        params["customer_email"] = user.email

    try:
        # The Stripe SDK call is sync; run it off the event loop.
        checkout = await asyncio.to_thread(
            client.checkout.sessions.create,
            params=params,
        )
    except stripe.StripeError as exc:
        log.error("stripe.checkout.create_failed", user_id=user.id, error=str(exc))
        raise HTTPException(status_code=502, detail=f"stripe error: {exc.user_message or str(exc)}")

    checkout_url = checkout.url
    if not checkout_url:
        raise HTTPException(status_code=502, detail="stripe returned no checkout URL")

    log.info(
        "stripe.checkout.created",
        user_id=user.id,
        session_id=checkout.id,
        cadence=body.cadence,
    )
    return CheckoutResponse(url=checkout_url)
# ---------------------------------------------------------------------------
# POST /api/stripe/portal
# ---------------------------------------------------------------------------
class PortalResponse(BaseModel):
    # Response body for POST /api/stripe/portal.

    # URL of the Stripe customer-portal session created for the user.
    url: str
@router.post("/api/stripe/portal", response_model=PortalResponse)
async def create_portal_session(
    session: AsyncSession = Depends(get_session),
    cu: CurrentUser = Depends(require_auth),
) -> PortalResponse:
    # Open a Stripe customer-portal session for the logged-in user and
    # return its URL; the portal sends the user back to /settings.
    #
    # Errors: 503 when Stripe is not configured, 400 for an admin
    # bearer token, 404 when the user has no Stripe customer yet, 502
    # when Stripe rejects the call.
    _require_configured()
    if cu.user is None:
        raise HTTPException(status_code=400, detail="admin token has no portal")

    user = await session.get(User, cu.user.id)
    has_customer = user is not None and bool(user.stripe_customer_id)
    if not has_customer:
        raise HTTPException(
            status_code=404,
            detail="no_stripe_customer — start a subscription first",
        )

    client = _stripe_client()
    portal_params = {
        "customer": user.stripe_customer_id,
        "return_url": f"{branding.SITE_URL}/settings",
    }
    try:
        # Sync SDK call, executed in a worker thread.
        portal = await asyncio.to_thread(
            client.billing_portal.sessions.create,
            params=portal_params,
        )
    except stripe.StripeError as exc:
        log.error("stripe.portal.create_failed", user_id=user.id, error=str(exc))
        raise HTTPException(status_code=502, detail=f"stripe error: {exc.user_message or str(exc)}")

    return PortalResponse(url=portal.url)
# ---------------------------------------------------------------------------
# POST /api/stripe/webhook
# ---------------------------------------------------------------------------
async def _find_user(
    session: AsyncSession,
    *,
    client_ref: str | None = None,
    customer_id: str | None = None,
) -> User | None:
    """Resolve the User row a Stripe event belongs to.

    `client_reference_id` is tried first because it is the most
    reliable join key — checkout creation sets it to `str(user.id)`.
    `stripe_customer_id` is the fallback; it is what subscription /
    invoice events carry once the first event has stored it on the
    user.

    Args:
        session: Database session used for the lookups.
        client_ref: Value of the event's `client_reference_id`, if any.
            A value that is not an integer string is ignored.
        customer_id: Stripe customer id carried by the event, if any.

    Returns:
        The matching User, or None when neither key finds one.
    """
    if client_ref:
        try:
            user_id = int(client_ref)
        except ValueError:
            # Not one of our ids; fall through to the customer lookup.
            pass
        else:
            by_reference = await session.get(User, user_id)
            if by_reference is not None:
                return by_reference

    if not customer_id:
        return None

    result = await session.execute(
        select(User).where(User.stripe_customer_id == customer_id)
    )
    return result.scalar_one_or_none()
async def _grant_paid(
    session: AsyncSession,
    user: User,
    *,
    customer_id: str | None,
    subscription_id: str | None,
    trial_end: int | None = None,
    status: str | None = None,
) -> None:
    """Mark `user` as paid and sync the Stripe ids / trial marker.

    Args:
        session: Database session, passed on to the referral service.
        user: Row to update in place (nothing is committed here).
        customer_id: Stripe customer id to store when it differs.
        subscription_id: Stripe subscription id to store when it differs.
        trial_end: Unix timestamp of the trial end, as reported by Stripe.
        status: Stripe subscription status; only "trialing" and
            "active" affect the trial marker, anything else (including
            None) leaves it untouched.
    """
    from datetime import datetime, timezone

    # Remember whether this is the first move onto the paid tier
    # *before* mutating — it gates the referral conversion at the end
    # and spares renewals a DB roundtrip per webhook.
    already_paid = user.tier == "paid"
    user.tier = "paid"

    # Only write the ids when they are present and actually changed.
    if customer_id and user.stripe_customer_id != customer_id:
        user.stripe_customer_id = customer_id
    if subscription_id and user.stripe_subscription_id != subscription_id:
        user.stripe_subscription_id = subscription_id

    # The trial marker lets the settings page show "N days remaining".
    # It is set only while Stripe reports the sub as trialing and is
    # cleared once the status flips to active (paid for real).
    if status == "active":
        user.stripe_trial_end_at = None
    elif status == "trialing" and trial_end:
        user.stripe_trial_end_at = datetime.fromtimestamp(trial_end, tz=timezone.utc)

    # Referral credit applies on the FIRST paid transition only.
    # convert_referral is itself idempotent (no-op on missing or
    # already-converted rows), so this guard is purely a perf hint.
    if not already_paid:
        from app.services.referral_service import convert_referral

        await convert_referral(session, user)
async def _revoke_paid(user: User) -> None:
    """Drop `user` back to the free tier (in place, nothing committed).

    Clears the subscription id and the trial marker. The
    stripe_customer_id is deliberately left untouched so that a later
    re-subscription matches this same row.
    """
    user.tier = "free"
    user.stripe_subscription_id = user.stripe_trial_end_at = None
async def _handle_checkout_completed(
    session: AsyncSession, event_type: str, obj: dict[str, Any],
) -> None:
    """checkout.session.completed — grant paid to the buying user.

    The user is looked up by the session's `client_reference_id`,
    falling back to its `customer`. An unknown user is logged and the
    event is otherwise ignored.

    Args:
        session: Database session.
        event_type: Stripe event type, used for logging.
        obj: The event's `data.object` (the Checkout Session) as a dict.
    """
    customer_id = obj.get("customer")
    user = await _find_user(
        session,
        client_ref=obj.get("client_reference_id"),
        customer_id=customer_id,
    )
    if user is None:
        log.warning("stripe.user_not_found", event_type=event_type)
        return

    # The session object itself carries no trial_end, so paid is
    # granted without trial info here; the subscription.created event
    # that accompanies the checkout carries it and fills in
    # trial_end_at when it is processed.
    await _grant_paid(
        session,
        user,
        customer_id=customer_id,
        subscription_id=obj.get("subscription"),
    )
async def _handle_subscription_event(
    session: AsyncSession, event_type: str, obj: dict[str, Any],
) -> None:
    """customer.subscription.created / .updated — sync the paid tier.

    Grants paid (and records trial info) when the Stripe-side status is
    "trialing" or "active". Every other status leaves the tier as it
    is; the downgrade happens on the explicit subscription.deleted
    event instead.

    NOTE(review): the user is looked up by `customer` only, so this
    relies on stripe_customer_id already being stored on the user. If
    Stripe delivers this event before checkout.session.completed the
    user is not found and the event is only logged — confirm that is
    acceptable.

    Args:
        session: Database session.
        event_type: Stripe event type, used for logging.
        obj: The event's `data.object` (the Subscription) as a dict.
    """
    customer_id = obj.get("customer")
    user = await _find_user(session, customer_id=customer_id)
    if user is None:
        log.warning(
            "stripe.user_not_found",
            event_type=event_type,
            customer_id=obj.get("customer"),
        )
        return

    # Stripe statuses: trialing, active, past_due, canceled, unpaid,
    # incomplete, incomplete_expired, paused. Only the first two count
    # as paid; the rest hold the tier until subscription.deleted.
    status = obj.get("status")
    if status not in ("trialing", "active"):
        return

    await _grant_paid(
        session,
        user,
        customer_id=obj.get("customer"),
        subscription_id=obj.get("id"),
        trial_end=obj.get("trial_end"),
        status=status,
    )
async def _handle_subscription_deleted(
    session: AsyncSession, event_type: str, obj: dict[str, Any],
) -> None:
    """customer.subscription.deleted — drop the user back to free.

    The user is looked up by the subscription's `customer`. The
    downgrade is skipped when the deleted subscription is provably NOT
    the one currently stored on the user: checkout does not prevent a
    customer from holding more than one subscription (e.g. an old one
    that runs out after a newer one was bought), and the end of the old
    one must not revoke access granted by the current one.

    Args:
        session: Database session.
        event_type: Stripe event type, used for logging.
        obj: The event's `data.object` (the Subscription) as a dict.
    """
    customer_id = obj.get("customer")
    user = await _find_user(session, customer_id=customer_id)
    if user is None:
        log.warning("stripe.user_not_found", event_type=event_type,
                    customer_id=customer_id)
        return

    deleted_sub_id = obj.get("id")
    current_sub_id = user.stripe_subscription_id
    # Only skip when BOTH ids are known and differ. If either is
    # missing we cannot tell the subscriptions apart, so we keep the
    # previous behaviour and revoke.
    if deleted_sub_id and current_sub_id and deleted_sub_id != current_sub_id:
        log.info(
            "stripe.subscription_deleted_ignored",
            event_type=event_type,
            customer_id=customer_id,
            deleted_subscription_id=deleted_sub_id,
            current_subscription_id=current_sub_id,
        )
        return

    await _revoke_paid(user)
async def _handle_audit_only(
    session: AsyncSession, event_type: str, obj: dict[str, Any],
) -> None:
    """invoice.paid / invoice.payment_failed / charge.refunded — no-op.

    These events are only recorded in stripe_events for the audit log
    (the webhook endpoint writes that row); the tier doesn't move until
    subscription.deleted fires. All arguments are accepted solely to
    match the common handler signature and are ignored.
    """
# Stripe event type → coroutine handling it. Every handler takes
# (session, event_type, obj). Event types missing from this table are
# still stored for audit by the webhook endpoint but otherwise ignored.
_HANDLERS = {
    "checkout.session.completed": _handle_checkout_completed,
    **dict.fromkeys(
        ("customer.subscription.created", "customer.subscription.updated"),
        _handle_subscription_event,
    ),
    "customer.subscription.deleted": _handle_subscription_deleted,
    **dict.fromkeys(
        ("invoice.paid", "invoice.payment_failed", "charge.refunded"),
        _handle_audit_only,
    ),
}
@router.post("/api/stripe/webhook")
async def stripe_webhook(
    request: Request,
    session: AsyncSession = Depends(get_session),
) -> dict[str, str]:
    # Stripe → us. Verifies the signature, records the event in
    # stripe_events (which doubles as the idempotency guard), then
    # dispatches to the handler registered for the event type.
    #
    # Returns {"status": ...} with one of: "ok", "ignored" (no handler
    # for the type), "duplicate" (event id already stored) or
    # "handler_error" (handler raised; details in the audit row).
    #
    # Errors: 503 when no webhook secret is configured, 400 for a
    # missing signature header / invalid payload / missing event id,
    # 401 for a bad signature.
    webhook_secret = get_settings().STRIPE_WEBHOOK_SECRET
    if not webhook_secret:
        raise HTTPException(status_code=503, detail="stripe webhook not configured")

    signature = request.headers.get("stripe-signature", "")
    if not signature:
        raise HTTPException(status_code=400, detail="missing stripe-signature header")

    raw_body = await request.body()

    # construct_event is used purely for HMAC verification + timestamp
    # tolerance; its return value is discarded. Handlers work on plain
    # dicts instead (the Stripe SDK's StripeObject doesn't expose
    # dict.get()), so the body is parsed again with json below.
    try:
        stripe.Webhook.construct_event(
            payload=raw_body,
            sig_header=signature,
            secret=webhook_secret,
        )
    except stripe.SignatureVerificationError:
        raise HTTPException(status_code=401, detail="bad signature")
    except ValueError:
        raise HTTPException(status_code=400, detail="invalid payload")

    envelope = json.loads(raw_body)
    event_id = envelope.get("id") or ""
    event_type = envelope.get("type") or "unknown"
    obj = (envelope.get("data") or {}).get("object") or {}
    if not event_id:
        raise HTTPException(status_code=400, detail="event missing id")

    # Idempotency: the audit row goes in first. UNIQUE on event_id
    # turns a replay of the same Stripe event id into a no-op (Stripe
    # retries on non-2xx, so always 2xx after first successful
    # processing).
    stored_payload = raw_body.decode("utf-8", errors="replace")[:_PAYLOAD_STORE_MAX]
    audit = StripeEvent(
        event_id=event_id,
        event_type=event_type,
        received_at=utcnow(),
        payload=stored_payload,
    )
    session.add(audit)
    try:
        await session.flush()
    except IntegrityError:
        await session.rollback()
        log.info("stripe.duplicate_delivery", event_id=event_id, type=event_type)
        return {"status": "duplicate"}

    handler = _HANDLERS.get(event_type)
    if handler is None:
        # Nothing to do for this type; keep the audit row and ack.
        audit.processed_at = utcnow()
        await session.commit()
        log.info("stripe.event_unhandled", type=event_type, id=event_id)
        return {"status": "ignored"}

    try:
        await handler(session, event_type, obj)
    except Exception as exc:
        audit.error = str(exc)[:1024]
        await session.commit()
        log.exception("stripe.handler_error", type=event_type, id=event_id)
        # Ack 200 — we don't want Stripe retrying a handler that broke
        # the same way on every delivery. An operator triages from the
        # `error` column.
        return {"status": "handler_error"}

    audit.processed_at = utcnow()
    await session.commit()
    log.info("stripe.processed", type=event_type, id=event_id)
    return {"status": "ok"}