"""Stripe billing endpoints — checkout, webhook, customer portal.

Stripe is the merchant-on-record for read.markets (after Polar/Paddle
both declined the financial-media category). We delegate payment UI to
Stripe-hosted Checkout and Customer Portal. To match incoming webhooks
back to the right user we keep `users.stripe_customer_id` /
`users.stripe_subscription_id`. Beyond those join keys, the webhook
handlers also maintain `users.tier` and `users.stripe_trial_end_at`.
Each distinct delivered event is recorded (payload truncated) as a
`StripeEvent` row, for idempotency and audit.

The Stripe SDK is sync; we wrap calls in `asyncio.to_thread` so the
event loop doesn't block while Stripe answers. For our request volume
this is more reliable than the SDK's nascent async surface.

Routes
- POST /api/stripe/checkout — logged-in user upgrades.
  Body: {cadence, currency?}; `currency` is honoured only for
  first-time buyers (no Stripe customer on file yet).
- POST /api/stripe/webhook  — Stripe → us, signature-verified.
- POST /api/stripe/portal   — logged-in user opens the customer portal.
"""

from __future__ import annotations

import asyncio
import json
from typing import Any, Literal, Optional

import stripe
# NOTE(review): `Body` does not appear to be used in this module —
# confirm against the rest of the file before removing it.
from fastapi import APIRouter, Body, Depends, HTTPException, Request
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app import branding
from app.auth import CurrentUser, require_auth
from app.config import get_settings
from app.db import get_session, utcnow
from app.logging import get_logger
from app.models import StripeEvent, User

# Module logger; call sites pass structured key/value fields
# (e.g. user_id=..., event_id=...).
log = get_logger("stripe_billing")
# All three Stripe routes register on this router. It is presumably
# mounted by the application elsewhere — not visible here.
router = APIRouter()

# Cap stored payload at 16 KiB so a hostile (or buggy) sender can't
# blow up a single row. Same pattern as polar_webhook.
_PAYLOAD_STORE_MAX = 16 * 1024


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _require_configured() -> None:
    """Raise 503 unless a Stripe API key is configured."""
    s = get_settings()
    if not s.STRIPE_API_KEY:
        raise HTTPException(status_code=503, detail="stripe not configured")


def _price_for(cadence: str) -> str:
    """Map a billing cadence to its configured Stripe Price ID.

    Raises:
        HTTPException(503): the Price ID for that cadence is not configured.
        HTTPException(400): cadence is neither "monthly" nor "annual"
            (unreachable from `create_checkout`, whose body model already
            restricts cadence to those two literals).
    """
    s = get_settings()
    if cadence == "monthly":
        if not s.STRIPE_PRICE_MONTHLY:
            raise HTTPException(status_code=503, detail="STRIPE_PRICE_MONTHLY not set")
        return s.STRIPE_PRICE_MONTHLY
    if cadence == "annual":
        if not s.STRIPE_PRICE_ANNUAL:
            raise HTTPException(status_code=503, detail="STRIPE_PRICE_ANNUAL not set")
        return s.STRIPE_PRICE_ANNUAL
    raise HTTPException(status_code=400, detail="cadence must be 'monthly' or 'annual'")


# Rough country → currency mapping. Covers the markets we have a stated
# rate for; everything else falls back to GBP (the home currency) and
# Stripe handles the FX at checkout. Configure the per-currency
# unit_amount on each Price's `currency_options` in the Stripe Dashboard
# — we just signal which option to use here.
_COUNTRY_CURRENCY: dict[str, str] = {
    "US": "usd",
    "CA": "usd",
    "GB": "gbp",
    "IM": "gbp",
    "JE": "gbp",
    "GG": "gbp",
    **dict.fromkeys(
        (
            "DE", "FR", "IT", "ES", "PT", "NL", "BE", "IE", "AT", "FI",
            "GR", "LU", "MT", "CY", "EE", "LV", "LT", "SI", "SK", "HR",
        ),
        "eur",
    ),
}

# Accept-Language locale → currency, used when CF-IPCountry is absent.
# Ambiguous locales (e.g. plain "fr" without region) get EUR because
# that's the majority outcome.
_LOCALE_CURRENCY: dict[str, str] = {
    "en-gb": "gbp",
    "en": "gbp",
    "en-us": "usd",
    "en-ca": "usd",
    "fr": "eur",
    "de": "eur",
    "it": "eur",
    "es": "eur",
    "pt": "eur",
    "nl": "eur",
}


def _sniff_currency(request: Request) -> str:
    """Best-effort currency detection for new-customer checkouts.

    Order: explicit Cloudflare country header, then the first
    Accept-Language entry (exact match, then language-only). GBP as the
    final fallback.

    Only consulted when the user has no Stripe customer record yet —
    Stripe locks currency at customer creation, so an existing customer's
    currency wins regardless of the request locale.
    """
    cc = (request.headers.get("cf-ipcountry") or "").upper()
    if cc in _COUNTRY_CURRENCY:
        return _COUNTRY_CURRENCY[cc]

    al = (request.headers.get("accept-language") or "").lower()
    # Only the first (highest-listed) language tag; strip any ";q=..." weight.
    first = al.split(",", 1)[0].split(";", 1)[0].strip()
    if first in _LOCALE_CURRENCY:
        return _LOCALE_CURRENCY[first]
    short = first.split("-", 1)[0]
    if short in _LOCALE_CURRENCY:
        return _LOCALE_CURRENCY[short]
    return "gbp"


def _stripe_client() -> stripe.StripeClient:
    """Per-call client so we read the secret at request time (lets us
    rotate the key by editing .env + reloading without rebuilding any
    cached client)."""
    return stripe.StripeClient(get_settings().STRIPE_API_KEY)


def _stripe_http_error(e: stripe.StripeError) -> HTTPException:
    """Build the 502 returned to the browser when a Stripe API call fails.

    Prefers Stripe's end-user-facing `user_message` and falls back to the
    exception text. Shared by the checkout and portal endpoints so both
    report failures identically.
    """
    # NOTE(review): str(e) can include Stripe request IDs and config
    # details (e.g. unknown price/customer IDs); consider a generic
    # fallback if those should not reach end users.
    return HTTPException(
        status_code=502,
        detail=f"stripe error: {e.user_message or str(e)}",
    )


# ---------------------------------------------------------------------------
# POST /api/stripe/checkout
# ---------------------------------------------------------------------------


class CheckoutRequest(BaseModel):
    cadence: Literal["monthly", "annual"]
    # Optional override; when omitted we sniff from request headers.
    # Honoured only for first-time checkouts (Stripe locks currency
    # to the customer at creation).
    currency: Optional[Literal["gbp", "usd", "eur"]] = None


class CheckoutResponse(BaseModel):
    url: str


@router.post("/api/stripe/checkout", response_model=CheckoutResponse)
async def create_checkout(
    body: CheckoutRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
    cu: CurrentUser = Depends(require_auth),
) -> CheckoutResponse:
    """Create a Stripe-hosted Checkout session for a subscription.

    Returns the Checkout URL for the browser to redirect to. This endpoint
    does not change the user's tier; the webhook does that once Stripe
    reports the outcome.

    Raises:
        HTTPException(503): Stripe, or the Price for this cadence, is not
            configured.
        HTTPException(400): the caller is the admin bearer token.
        HTTPException(404): the authenticated user's row no longer exists.
        HTTPException(502): Stripe rejected the request or returned no URL.
    """
    _require_configured()
    if cu.user is None:
        # Admin bearer token has no User row — they shouldn't be buying.
        raise HTTPException(status_code=400, detail="admin token cannot purchase")

    user = await session.get(User, cu.user.id)
    if user is None:
        raise HTTPException(status_code=404, detail="user_not_found")

    # NOTE(review): nothing here stops a user who already has an active
    # subscription from starting a second one — confirm the frontend routes
    # paid users to the portal instead.
    price_id = _price_for(body.cadence)
    client = _stripe_client()

    create_kwargs: dict[str, Any] = {
        "mode": "subscription",
        "line_items": [{"price": price_id, "quantity": 1}],
        "client_reference_id": str(user.id),
        "success_url": f"{branding.SITE_URL}/settings?upgraded=1",
        "cancel_url": f"{branding.SITE_URL}/pricing",
        # Lets us paste in a referral coupon at checkout once the
        # referral redemption flow ships.
        "allow_promotion_codes": True,
    }

    if user.stripe_customer_id:
        # Reuse the customer we already minted (avoids duplicate Stripe
        # customers on repeat checkouts). Their currency is already locked,
        # so no currency is passed.
        create_kwargs["customer"] = user.stripe_customer_id
    else:
        # First-time buyer: let Stripe create the customer from the email,
        # and pass the detected/requested currency. Stripe picks the
        # matching `currency_options` rate configured on the Price, then
        # locks that currency to the new customer record.
        create_kwargs["customer_email"] = user.email
        create_kwargs["currency"] = body.currency or _sniff_currency(request)

    # Per-cadence cooling-off treatment:
    #
    # - Annual gets a 14-day free trial. No money moves during the
    #   trial, so the Consumer Contracts Regulations 14-day refund
    #   question is moot (nothing paid = nothing to refund). Card is
    #   still required at checkout so Stripe can charge on day 15.
    #
    # - Monthly bills immediately (a 14-day trial on a £7/month plan
    #   would give away ~50% of cycle one). The Reg-36 waiver lives
    #   on our own /pricing page as a required tick-box (see
    #   pricing.html); we deliberately do NOT use Stripe's
    #   consent_collection.terms_of_service here because that's an
    #   account-wide setting and we want per-product control (and
    #   per-product Terms URLs) as we grow.
    if body.cadence == "annual":
        create_kwargs["subscription_data"] = {"trial_period_days": 14}

    try:
        sess = await asyncio.to_thread(
            client.checkout.sessions.create,
            params=create_kwargs,
        )
    except stripe.StripeError as e:
        log.error("stripe.checkout.create_failed", user_id=user.id, error=str(e))
        raise _stripe_http_error(e) from e

    if not sess.url:
        raise HTTPException(status_code=502, detail="stripe returned no checkout URL")

    log.info("stripe.checkout.created", user_id=user.id, session_id=sess.id, cadence=body.cadence)
    return CheckoutResponse(url=sess.url)


# ---------------------------------------------------------------------------
# POST /api/stripe/portal
# ---------------------------------------------------------------------------


class PortalResponse(BaseModel):
    url: str


@router.post("/api/stripe/portal", response_model=PortalResponse)
async def create_portal_session(
    session: AsyncSession = Depends(get_session),
    cu: CurrentUser = Depends(require_auth),
) -> PortalResponse:
    """Open a Stripe Customer Portal session for the logged-in user.

    Returns the portal URL; Stripe sends the user back to /settings.

    Raises:
        HTTPException(503): Stripe is not configured.
        HTTPException(400): the caller is the admin bearer token.
        HTTPException(404): the user has no Stripe customer on file yet.
        HTTPException(502): Stripe rejected the request or returned no URL.
    """
    _require_configured()
    if cu.user is None:
        raise HTTPException(status_code=400, detail="admin token has no portal")

    user = await session.get(User, cu.user.id)
    if user is None or not user.stripe_customer_id:
        raise HTTPException(
            status_code=404,
            detail="no_stripe_customer — start a subscription first",
        )

    client = _stripe_client()
    try:
        portal = await asyncio.to_thread(
            client.billing_portal.sessions.create,
            params={
                "customer": user.stripe_customer_id,
                "return_url": f"{branding.SITE_URL}/settings",
            },
        )
    except stripe.StripeError as e:
        log.error("stripe.portal.create_failed", user_id=user.id, error=str(e))
        raise _stripe_http_error(e) from e

    # Same guard as checkout: a missing URL would otherwise fail response
    # validation and surface as an opaque 500.
    if not portal.url:
        raise HTTPException(status_code=502, detail="stripe returned no portal URL")

    return PortalResponse(url=portal.url)
# ---------------------------------------------------------------------------
# POST /api/stripe/webhook
# ---------------------------------------------------------------------------


async def _find_user(
    session: AsyncSession,
    *,
    client_ref: str | None = None,
    customer_id: str | None = None,
) -> User | None:
    """Find the User row this event belongs to.

    `client_reference_id` is the most reliable join key — we set it to
    `str(user.id)` at checkout creation. After the first event we also
    know `stripe_customer_id`, which subsequent subscription / invoice
    events arrive carrying. Returns None when neither key matches.
    """
    if client_ref:
        try:
            uid = int(client_ref)
        except ValueError:
            uid = None
        if uid is not None:
            u = await session.get(User, uid)
            if u is not None:
                return u
    if customer_id:
        row = (await session.execute(
            select(User).where(User.stripe_customer_id == customer_id)
        )).scalar_one_or_none()
        return row
    return None


async def _grant_paid(
    session: AsyncSession,
    user: User,
    *,
    customer_id: str | None,
    subscription_id: str | None,
    trial_end: int | None = None,
    status: str | None = None,
) -> None:
    """Mark `user` as paid and record the Stripe IDs the event carried.

    `trial_end` is a Unix timestamp (seconds). It is stored only when
    `status == "trialing"`; `status == "active"` clears it. Any other
    status (including None) leaves the stored trial end untouched.
    """
    # Capture "first paid transition" before mutating — drives the
    # referral-conversion call below. Skipping the convert lookup on
    # every renewal event saves a DB roundtrip per webhook.
    first_paid_transition = user.tier != "paid"

    user.tier = "paid"
    if customer_id and user.stripe_customer_id != customer_id:
        user.stripe_customer_id = customer_id
    if subscription_id and user.stripe_subscription_id != subscription_id:
        user.stripe_subscription_id = subscription_id

    # Track trial_end so the settings page can show "N days remaining".
    # Only populated when Stripe reports the sub as trialing — once the
    # status flips to active (paid for real), we clear the trial marker.
    if status == "trialing" and trial_end:
        from datetime import datetime, timezone
        user.stripe_trial_end_at = datetime.fromtimestamp(trial_end, tz=timezone.utc)
    elif status == "active":
        user.stripe_trial_end_at = None

    # Apply referral credit on the FIRST paid transition only.
    # convert_referral is itself idempotent (no-op on missing or
    # already-converted rows), so this guard is purely a perf hint.
    if first_paid_transition:
        from app.services.referral_service import convert_referral
        await convert_referral(session, user)


async def _revoke_paid(user: User) -> None:
    """Drop `user` to the free tier and forget the ended subscription."""
    user.tier = "free"
    user.stripe_subscription_id = None
    user.stripe_trial_end_at = None
    # Keep stripe_customer_id so a re-subscription matches this row.


async def _handle_checkout_completed(
    session: AsyncSession,
    event_type: str,
    obj: dict[str, Any],
) -> None:
    """checkout.session.completed — grant paid and link the Stripe IDs."""
    user = await _find_user(
        session,
        client_ref=obj.get("client_reference_id"),
        customer_id=obj.get("customer"),
    )
    if user is None:
        log.warning("stripe.user_not_found", event_type=event_type)
        return
    # checkout.session.completed doesn't carry trial_end on the session
    # object itself — the subscription.created event that fires right
    # after will carry it. We grant paid here without trial info and
    # let the subscription event fill in trial_end_at moments later.
    await _grant_paid(
        session,
        user,
        customer_id=obj.get("customer"),
        subscription_id=obj.get("subscription"),
    )


async def _handle_subscription_event(
    session: AsyncSession,
    event_type: str,
    obj: dict[str, Any],
) -> None:
    """customer.subscription.created / .updated — flip to paid when the
    Stripe-side status is trialing/active; any other status leaves the
    tier unchanged until an explicit subscription.deleted arrives."""
    user = await _find_user(session, customer_id=obj.get("customer"))
    if user is None:
        log.warning("stripe.user_not_found", event_type=event_type,
                    customer_id=obj.get("customer"))
        return
    status = obj.get("status")
    # Stripe statuses: trialing, active, past_due, canceled, unpaid,
    # incomplete, incomplete_expired, paused. Treat trialing/active as
    # paid; everything else holds tier the same until we get an explicit
    # subscription.deleted (which fires after the final state lands).
    if status in ("trialing", "active"):
        await _grant_paid(
            session,
            user,
            customer_id=obj.get("customer"),
            subscription_id=obj.get("id"),
            trial_end=obj.get("trial_end"),
            status=status,
        )


async def _handle_subscription_deleted(
    session: AsyncSession,
    event_type: str,
    obj: dict[str, Any],
) -> None:
    """customer.subscription.deleted — drop the user to free.

    Revokes only when the deleted subscription is the one on file, or when
    no subscription is on file. A customer can hold more than one
    subscription (nothing in the checkout endpoint stops a paying user from
    starting another). The end of a superseded subscription must not strip
    access granted by the current one.
    """
    customer_id = obj.get("customer")
    user = await _find_user(session, customer_id=customer_id)
    if user is None:
        log.warning("stripe.user_not_found", event_type=event_type,
                    customer_id=customer_id)
        return

    deleted_id = obj.get("id")
    current_id = user.stripe_subscription_id
    if current_id and deleted_id and deleted_id != current_id:
        log.info(
            "stripe.subscription_deleted_not_current",
            user_id=user.id,
            deleted_subscription_id=deleted_id,
            current_subscription_id=current_id,
        )
        return

    await _revoke_paid(user)


async def _handle_audit_only(
    session: AsyncSession,
    event_type: str,
    obj: dict[str, Any],
) -> None:
    """invoice.paid / invoice.payment_failed / charge.refunded — we record
    these in stripe_events for the audit log but the tier doesn't move
    until subscription.deleted fires."""
    return None


# Event type → handler. Types not listed here are still recorded in
# stripe_events, marked processed, and acknowledged as "ignored".
_HANDLERS = {
    "checkout.session.completed": _handle_checkout_completed,
    "customer.subscription.created": _handle_subscription_event,
    "customer.subscription.updated": _handle_subscription_event,
    "customer.subscription.deleted": _handle_subscription_deleted,
    "invoice.paid": _handle_audit_only,
    "invoice.payment_failed": _handle_audit_only,
    "charge.refunded": _handle_audit_only,
}


@router.post("/api/stripe/webhook")
async def stripe_webhook(
    request: Request,
    session: AsyncSession = Depends(get_session),
) -> dict[str, str]:
    """Receive a Stripe event, verify its signature, record it once, dispatch.

    Returns {"status": ...} with one of "duplicate", "ignored",
    "handler_error" or "ok". All four are 2xx, so Stripe stops retrying.

    Raises:
        HTTPException(503): no webhook secret is configured.
        HTTPException(400): missing signature header, invalid payload, or
            an event without an id.
        HTTPException(401): the signature does not verify.
    """
    s = get_settings()
    if not s.STRIPE_WEBHOOK_SECRET:
        raise HTTPException(status_code=503, detail="stripe webhook not configured")

    sig = request.headers.get("stripe-signature", "")
    if not sig:
        raise HTTPException(status_code=400, detail="missing stripe-signature header")

    body = await request.body()

    # construct_event handles HMAC verification + timestamp tolerance.
    # We then re-parse the body as plain JSON for handler dispatch —
    # the Stripe SDK's StripeObject doesn't expose dict.get(), and
    # round-tripping through json gives us simple, typed-dict access.
    try:
        stripe.Webhook.construct_event(
            payload=body,
            sig_header=sig,
            secret=s.STRIPE_WEBHOOK_SECRET,
        )
    except stripe.SignatureVerificationError as e:
        raise HTTPException(status_code=401, detail="bad signature") from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail="invalid payload") from e

    envelope = json.loads(body)
    event_id = envelope.get("id") or ""
    event_type = envelope.get("type") or "unknown"
    obj = (envelope.get("data") or {}).get("object") or {}
    if not event_id:
        raise HTTPException(status_code=400, detail="event missing id")

    # Idempotency: insert audit row first. UNIQUE on event_id makes a
    # replay of the same Stripe event id a no-op (Stripe retries on
    # non-2xx, so always 2xx after first successful processing).
    audit = StripeEvent(
        event_id=event_id,
        event_type=event_type,
        received_at=utcnow(),
        payload=body.decode("utf-8", errors="replace")[:_PAYLOAD_STORE_MAX],
    )
    session.add(audit)
    try:
        await session.flush()
    except IntegrityError:
        await session.rollback()
        log.info("stripe.duplicate_delivery", event_id=event_id, type=event_type)
        return {"status": "duplicate"}

    handler = _HANDLERS.get(event_type)
    if handler is None:
        audit.processed_at = utcnow()
        await session.commit()
        log.info("stripe.event_unhandled", type=event_type, id=event_id)
        return {"status": "ignored"}

    try:
        await handler(session, event_type, obj)
    except Exception as e:
        # NOTE(review): this commit also persists any partial mutations the
        # handler made before raising; and if the handler failed with a DB
        # error, the session may need a rollback/savepoint before this
        # commit can succeed — TODO confirm the intended behaviour.
        audit.error = str(e)[:1024]
        await session.commit()
        log.exception("stripe.handler_error", type=event_type, id=event_id)
        # Ack 200 — we don't want Stripe retrying a handler that broke
        # the same way on every delivery. An operator triages from the
        # `error` column.
        return {"status": "handler_error"}

    audit.processed_at = utcnow()
    await session.commit()
    log.info("stripe.processed", type=event_type, id=event_id)
    return {"status": "ok"}