diff --git a/alembic/versions/0019_stripe.py b/alembic/versions/0019_stripe.py new file mode 100644 index 0000000..3ea4018 --- /dev/null +++ b/alembic/versions/0019_stripe.py @@ -0,0 +1,56 @@ +"""stripe integration: users.stripe_customer_id / stripe_subscription_id, +stripe_events table. + +Revision ID: 0019 +Revises: 0018 +Create Date: 2026-05-26 +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + + +revision: str = "0019" +down_revision: Union[str, None] = "0018" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "users", + sa.Column("stripe_customer_id", sa.String(length=64), nullable=True), + ) + op.add_column( + "users", + sa.Column("stripe_subscription_id", sa.String(length=64), nullable=True), + ) + op.create_unique_constraint( + "uq_users_stripe_customer", "users", ["stripe_customer_id"], + ) + + op.create_table( + "stripe_events", + sa.Column("id", sa.BigInteger(), autoincrement=True, primary_key=True), + sa.Column("event_id", sa.String(length=128), nullable=False), + sa.Column("event_type", sa.String(length=64), nullable=False), + sa.Column("received_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("processed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("error", sa.Text(), nullable=True), + sa.Column("payload", sa.Text(), nullable=False), + sa.UniqueConstraint("event_id", name="uq_stripe_events_event_id"), + ) + op.create_index( + "ix_stripe_events_type_received", + "stripe_events", + ["event_type", "received_at"], + ) + + +def downgrade() -> None: + op.drop_index("ix_stripe_events_type_received", table_name="stripe_events") + op.drop_table("stripe_events") + op.drop_constraint("uq_users_stripe_customer", "users", type_="unique") + op.drop_column("users", "stripe_subscription_id") + op.drop_column("users", "stripe_customer_id") diff --git a/app/config.py b/app/config.py index 4ff67d6..0aabb1d 100644 --- a/app/config.py +++ b/app/config.py @@ -99,6 +99,16 @@ class Settings(BaseSettings): POLAR_WEBHOOK_SECRET: str = "" POLAR_API_KEY: str = "" + # Stripe (merchant-on-record for read.markets after Polar/Paddle + # both declined the financial-media category). Test-mode keys are + # `sk_test_*` / `whsec_*`; live-mode keys are `sk_live_*` — swap at + # GA cutover. Empty values make the corresponding endpoints 503 so + # a misconfig is loud rather than silently accepting unsigned events. + STRIPE_API_KEY: str = "" + STRIPE_WEBHOOK_SECRET: str = "" + STRIPE_PRICE_MONTHLY: str = "" # price_xxx for £7/month subscription + STRIPE_PRICE_ANNUAL: str = "" # price_xxx for £70/year subscription + # Config file locations (overridable for tests) BASELINE_TOML: Path = Field(default_factory=lambda: CONFIG_DIR / "default.toml") PORTFOLIO_TOML: Path = Field(default_factory=lambda: CONFIG_DIR / "portfolio.toml") diff --git a/app/main.py b/app/main.py index d9c491e..9499a34 100644 --- a/app/main.py +++ b/app/main.py @@ -23,6 +23,7 @@ from app.routers import email as email_router from app.routers import pages as pages_router from app.routers import polar_webhook as polar_webhook_router from app.routers import public as public_router +from app.routers import stripe_billing as stripe_router from app.routers import sync as sync_router from app.routers import universe as universe_router from app.services.feeds_bootstrap import bootstrap_feeds @@ -93,6 +94,9 @@ app.include_router(sync_router.router, tags=["portfolio-sync"]) # `/api/polar/webhook` is set on the route itself so the URL Polar # stores remains stable even if api_router's prefix ever moves. app.include_router(polar_webhook_router.router, tags=["polar-webhook"]) +# Stripe billing (checkout, portal, webhook). Auth lives per-route: +# checkout + portal require_auth, webhook is signature-gated. +app.include_router(stripe_router.router, tags=["stripe-billing"]) # Public router (no auth dep) before pages_router so the marketing/legal # paths can never collide with future authenticated routes. app.include_router(public_router.router) diff --git a/app/models.py b/app/models.py index 643b8d8..2140f6b 100644 --- a/app/models.py +++ b/app/models.py @@ -194,11 +194,18 @@ class User(Base): # we cancel against from /settings. polar_customer_id: Mapped[str | None] = mapped_column(String(64), nullable=True) polar_subscription_id: Mapped[str | None] = mapped_column(String(64), nullable=True) + # Stripe (merchant-on-record for read.markets). Populated on the + # first checkout.session.completed event via client_reference_id; + # used thereafter to match incoming subscription/invoice events + # back to this row. + stripe_customer_id: Mapped[str | None] = mapped_column(String(64), nullable=True) + stripe_subscription_id: Mapped[str | None] = mapped_column(String(64), nullable=True) __table_args__ = ( UniqueConstraint("email", name="uq_users_email"), UniqueConstraint("referral_code", name="uq_users_referral_code"), UniqueConstraint("polar_customer_id", name="uq_users_polar_customer"), + UniqueConstraint("stripe_customer_id", name="uq_users_stripe_customer"), ) @@ -385,3 +392,29 @@ class PolarEvent(Base): UniqueConstraint("event_id", name="uq_polar_events_event_id"), Index("ix_polar_events_type_received", "event_type", "received_at"), ) + + +class StripeEvent(Base): + """Audit + idempotency table for inbound Stripe webhook deliveries. + + Same shape and purpose as PolarEvent — Stripe's `event.id` plays the + same role as Standard Webhooks' `webhook-id`. We keep the tables + distinct (rather than a single 'webhook_events' table) so an + operator can look at the audit trail per processor without filtering + on a `source` column.""" + __tablename__ = "stripe_events" + + id: Mapped[int] = mapped_column(_PK, primary_key=True, autoincrement=True) + event_id: Mapped[str] = mapped_column(String(128), nullable=False) + event_type: Mapped[str] = mapped_column(String(64), nullable=False) + received_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=utcnow, nullable=False, + ) + processed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + error: Mapped[str | None] = mapped_column(Text) + payload: Mapped[str] = mapped_column(Text, nullable=False) + + __table_args__ = ( + UniqueConstraint("event_id", name="uq_stripe_events_event_id"), + Index("ix_stripe_events_type_received", "event_type", "received_at"), + ) diff --git a/app/routers/public.py b/app/routers/public.py index a040ccd..33bd245 100644 --- a/app/routers/public.py +++ b/app/routers/public.py @@ -15,6 +15,7 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import HTMLResponse from app.auth import CurrentUser, maybe_current_user +from app.services.access import is_paid_active from app.templates_env import templates @@ -33,7 +34,9 @@ async def pricing_page( request: Request, cu: CurrentUser | None = Depends(maybe_current_user), ): - return templates.TemplateResponse(request, "pricing.html", _ctx(request, cu)) + ctx = _ctx(request, cu) + ctx["paid"] = is_paid_active(cu) + return templates.TemplateResponse(request, "pricing.html", ctx) @router.get("/about", response_class=HTMLResponse) diff --git a/app/routers/stripe_billing.py b/app/routers/stripe_billing.py new file mode 100644 index 0000000..e896f74 --- /dev/null +++ b/app/routers/stripe_billing.py @@ -0,0 +1,383 @@ +"""Stripe billing endpoints — checkout, webhook, customer portal. + +Stripe is the merchant-on-record for read.markets (after Polar/Paddle +both declined the financial-media category). We delegate payment UI to +Stripe-hosted Checkout and Customer Portal; the only state we keep on +our side is `users.stripe_customer_id` / `users.stripe_subscription_id` +so we can match incoming webhooks back to the right user. + +The Stripe SDK is sync; we wrap calls in `asyncio.to_thread` so the +event loop doesn't block while Stripe answers. For our request volume +this is more reliable than the SDK's nascent async surface. + +Routes +- POST /api/stripe/checkout — logged-in user upgrades. Body: {cadence}. +- POST /api/stripe/webhook — Stripe → us, signature-verified. +- POST /api/stripe/portal — logged-in user opens the customer portal. +""" +from __future__ import annotations + +import asyncio +import json +from typing import Any, Literal + +import stripe +from fastapi import APIRouter, Body, Depends, HTTPException, Request +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app import branding +from app.auth import CurrentUser, require_auth +from app.config import get_settings +from app.db import get_session, utcnow +from app.logging import get_logger +from app.models import StripeEvent, User + + +log = get_logger("stripe_billing") +router = APIRouter() + + +# Cap stored payload at 16 KiB so a hostile (or buggy) sender can't +# blow up a single row. Same pattern as polar_webhook. +_PAYLOAD_STORE_MAX = 16 * 1024 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _require_configured() -> None: + s = get_settings() + if not s.STRIPE_API_KEY: + raise HTTPException(status_code=503, detail="stripe not configured") + + +def _price_for(cadence: str) -> str: + s = get_settings() + if cadence == "monthly": + if not s.STRIPE_PRICE_MONTHLY: + raise HTTPException(status_code=503, detail="STRIPE_PRICE_MONTHLY not set") + return s.STRIPE_PRICE_MONTHLY + if cadence == "annual": + if not s.STRIPE_PRICE_ANNUAL: + raise HTTPException(status_code=503, detail="STRIPE_PRICE_ANNUAL not set") + return s.STRIPE_PRICE_ANNUAL + raise HTTPException(status_code=400, detail="cadence must be 'monthly' or 'annual'") + + +def _stripe_client() -> stripe.StripeClient: + """Per-call client so we read the secret at request time (lets us + rotate the key by editing .env + reloading without rebuilding any + cached client).""" + return stripe.StripeClient(get_settings().STRIPE_API_KEY) + + +# --------------------------------------------------------------------------- +# POST /api/stripe/checkout +# --------------------------------------------------------------------------- + + +class CheckoutRequest(BaseModel): + cadence: Literal["monthly", "annual"] + + +class CheckoutResponse(BaseModel): + url: str + + +@router.post("/api/stripe/checkout", response_model=CheckoutResponse) +async def create_checkout( + body: CheckoutRequest, + session: AsyncSession = Depends(get_session), + cu: CurrentUser = Depends(require_auth), +) -> CheckoutResponse: + _require_configured() + if cu.user is None: + # Admin bearer token has no User row — they shouldn't be buying. + raise HTTPException(status_code=400, detail="admin token cannot purchase") + + user = await session.get(User, cu.user.id) + if user is None: + raise HTTPException(status_code=404, detail="user_not_found") + + price_id = _price_for(body.cadence) + client = _stripe_client() + + # Pass `customer` if we already minted one for this user (avoids + # creating duplicate Stripe customers on repeat checkouts); + # otherwise let Stripe create it via `customer_email`. + create_kwargs: dict[str, Any] = { + "mode": "subscription", + "line_items": [{"price": price_id, "quantity": 1}], + "client_reference_id": str(user.id), + "success_url": f"{branding.SITE_URL}/settings?upgraded=1", + "cancel_url": f"{branding.SITE_URL}/pricing", + # Lets us paste in a referral coupon at checkout once the + # referral redemption flow ships. + "allow_promotion_codes": True, + } + if user.stripe_customer_id: + create_kwargs["customer"] = user.stripe_customer_id + else: + create_kwargs["customer_email"] = user.email + + try: + sess = await asyncio.to_thread( + client.checkout.sessions.create, params=create_kwargs, + ) + except stripe.StripeError as e: + log.error("stripe.checkout.create_failed", user_id=user.id, error=str(e)) + raise HTTPException(status_code=502, detail=f"stripe error: {e.user_message or str(e)}") + + if not sess.url: + raise HTTPException(status_code=502, detail="stripe returned no checkout URL") + log.info("stripe.checkout.created", user_id=user.id, session_id=sess.id, + cadence=body.cadence) + return CheckoutResponse(url=sess.url) + + +# --------------------------------------------------------------------------- +# POST /api/stripe/portal +# --------------------------------------------------------------------------- + + +class PortalResponse(BaseModel): + url: str + + +@router.post("/api/stripe/portal", response_model=PortalResponse) +async def create_portal_session( + session: AsyncSession = Depends(get_session), + cu: CurrentUser = Depends(require_auth), +) -> PortalResponse: + _require_configured() + if cu.user is None: + raise HTTPException(status_code=400, detail="admin token has no portal") + + user = await session.get(User, cu.user.id) + if user is None or not user.stripe_customer_id: + raise HTTPException( + status_code=404, + detail="no_stripe_customer — start a subscription first", + ) + + client = _stripe_client() + try: + portal = await asyncio.to_thread( + client.billing_portal.sessions.create, + params={ + "customer": user.stripe_customer_id, + "return_url": f"{branding.SITE_URL}/settings", + }, + ) + except stripe.StripeError as e: + log.error("stripe.portal.create_failed", user_id=user.id, error=str(e)) + raise HTTPException(status_code=502, detail=f"stripe error: {e.user_message or str(e)}") + + return PortalResponse(url=portal.url) + + +# --------------------------------------------------------------------------- +# POST /api/stripe/webhook +# --------------------------------------------------------------------------- + + +async def _find_user( + session: AsyncSession, + *, + client_ref: str | None = None, + customer_id: str | None = None, +) -> User | None: + """Find the User row this event belongs to. + + `client_reference_id` is the most reliable join key — we set it + to `str(user.id)` at checkout creation. After the first event we + also know `stripe_customer_id`, which subsequent subscription / + invoice events arrive carrying.""" + if client_ref: + try: + uid = int(client_ref) + except ValueError: + uid = None + if uid is not None: + u = await session.get(User, uid) + if u is not None: + return u + if customer_id: + row = (await session.execute( + select(User).where(User.stripe_customer_id == customer_id) + )).scalar_one_or_none() + return row + return None + + +async def _grant_paid( + user: User, + *, + customer_id: str | None, + subscription_id: str | None, +) -> None: + user.tier = "paid" + if customer_id and user.stripe_customer_id != customer_id: + user.stripe_customer_id = customer_id + if subscription_id and user.stripe_subscription_id != subscription_id: + user.stripe_subscription_id = subscription_id + + +async def _revoke_paid(user: User) -> None: + user.tier = "free" + user.stripe_subscription_id = None + # Keep stripe_customer_id so a re-subscription matches this row. + + +async def _handle_checkout_completed( + session: AsyncSession, event_type: str, obj: dict[str, Any], +) -> None: + user = await _find_user( + session, + client_ref=obj.get("client_reference_id"), + customer_id=obj.get("customer"), + ) + if user is None: + log.warning("stripe.user_not_found", event=event_type) + return + await _grant_paid( + user, + customer_id=obj.get("customer"), + subscription_id=obj.get("subscription"), + ) + + +async def _handle_subscription_event( + session: AsyncSession, event_type: str, obj: dict[str, Any], +) -> None: + """customer.subscription.created / .updated — flip to paid if the + Stripe-side status says the subscription is active/trialing; drop + to free if it's an end-state.""" + user = await _find_user(session, customer_id=obj.get("customer")) + if user is None: + log.warning("stripe.user_not_found", event=event_type, + customer_id=obj.get("customer")) + return + status = obj.get("status") + # Stripe statuses: trialing, active, past_due, canceled, unpaid, + # incomplete, incomplete_expired, paused. Treat trialing/active as + # paid; everything else holds tier the same until we get an explicit + # subscription.deleted (which fires after the final state lands). + if status in ("trialing", "active"): + await _grant_paid( + user, + customer_id=obj.get("customer"), + subscription_id=obj.get("id"), + ) + + +async def _handle_subscription_deleted( + session: AsyncSession, event_type: str, obj: dict[str, Any], +) -> None: + user = await _find_user(session, customer_id=obj.get("customer")) + if user is None: + log.warning("stripe.user_not_found", event=event_type, + customer_id=obj.get("customer")) + return + await _revoke_paid(user) + + +async def _handle_audit_only( + session: AsyncSession, event_type: str, obj: dict[str, Any], +) -> None: + """invoice.paid / invoice.payment_failed / charge.refunded — we + record these in stripe_events for the audit log but the tier doesn't + move until subscription.deleted fires.""" + return None + + +_HANDLERS = { + "checkout.session.completed": _handle_checkout_completed, + "customer.subscription.created": _handle_subscription_event, + "customer.subscription.updated": _handle_subscription_event, + "customer.subscription.deleted": _handle_subscription_deleted, + "invoice.paid": _handle_audit_only, + "invoice.payment_failed": _handle_audit_only, + "charge.refunded": _handle_audit_only, +} + + +@router.post("/api/stripe/webhook") +async def stripe_webhook( + request: Request, + session: AsyncSession = Depends(get_session), +) -> dict[str, str]: + s = get_settings() + if not s.STRIPE_WEBHOOK_SECRET: + raise HTTPException(status_code=503, detail="stripe webhook not configured") + + sig = request.headers.get("stripe-signature", "") + if not sig: + raise HTTPException(status_code=400, detail="missing stripe-signature header") + + body = await request.body() + # construct_event handles HMAC verification + timestamp tolerance. + # We then re-parse the body as plain JSON for handler dispatch — + # the Stripe SDK's StripeObject doesn't expose dict.get(), and + # round-tripping through json gives us simple, typed-dict access. + try: + stripe.Webhook.construct_event( + payload=body, sig_header=sig, secret=s.STRIPE_WEBHOOK_SECRET, + ) + except stripe.SignatureVerificationError: + raise HTTPException(status_code=401, detail="bad signature") + except ValueError: + raise HTTPException(status_code=400, detail="invalid payload") + + envelope = json.loads(body) + event_id = envelope.get("id") or "" + event_type = envelope.get("type") or "unknown" + obj = (envelope.get("data") or {}).get("object") or {} + + if not event_id: + raise HTTPException(status_code=400, detail="event missing id") + + # Idempotency: insert audit row first. UNIQUE on event_id makes a + # replay of the same Stripe event id a no-op (Stripe retries on + # non-2xx, so always 2xx after first successful processing). + audit = StripeEvent( + event_id=event_id, + event_type=event_type, + received_at=utcnow(), + payload=body.decode("utf-8", errors="replace")[:_PAYLOAD_STORE_MAX], + ) + session.add(audit) + try: + await session.flush() + except IntegrityError: + await session.rollback() + log.info("stripe.duplicate_delivery", event_id=event_id, type=event_type) + return {"status": "duplicate"} + + handler = _HANDLERS.get(event_type) + if handler is None: + audit.processed_at = utcnow() + await session.commit() + log.info("stripe.event_unhandled", type=event_type, id=event_id) + return {"status": "ignored"} + + try: + await handler(session, event_type, obj) + except Exception as e: + audit.error = str(e)[:1024] + await session.commit() + log.exception("stripe.handler_error", type=event_type, id=event_id) + # Ack 200 — we don't want Stripe retrying a handler that broke + # the same way on every delivery. An operator triages from the + # `error` column. + return {"status": "handler_error"} + + audit.processed_at = utcnow() + await session.commit() + log.info("stripe.processed", type=event_type, id=event_id) + return {"status": "ok"} diff --git a/app/templates/pricing.html b/app/templates/pricing.html index a26e106..6bb4411 100644 --- a/app/templates/pricing.html +++ b/app/templates/pricing.html @@ -50,9 +50,8 @@