From 410afe007897af61af9f3f84b935d7663d8281a6 Mon Sep 17 00:00:00 2001
From: Giorgio Gilestro
Date: Tue, 26 May 2026 18:45:13 +0200
Subject: [PATCH] stripe: wire checkout, customer portal, and webhook for
read.markets
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Stripe is the merchant-on-record for read.markets after Polar/Paddle
both declined the financial-media category. This commit lands the
full subscription flow: an "Upgrade" button on /pricing now opens a
real Stripe-hosted Checkout, completes the subscription, and the
webhook flips user.tier to "paid" idempotently.
Endpoints
- POST /api/stripe/checkout (require_auth) — creates a hosted
Checkout Session in subscription mode, passes user.id as
client_reference_id + email as customer_email, returns the URL
for the page-side JS to redirect to. Reuses an existing
stripe_customer_id to avoid duplicate Stripe customers on repeat
checkouts. allow_promotion_codes=True so the referral-credit
redemption can attach a coupon at checkout once that flow ships.
- POST /api/stripe/portal (require_auth) — mints a Stripe Customer
Portal session. Used by /settings; returns 404 until the user has
a stripe_customer_id (i.e. completed at least one checkout).
- POST /api/stripe/webhook — signature-verified via
stripe.Webhook.construct_event. Idempotent via UNIQUE on
stripe_events.event_id. Event dispatch:
checkout.session.completed → grant paid, store IDs
customer.subscription.created → grant paid (active/trialing)
customer.subscription.updated → grant paid (active/trialing)
customer.subscription.deleted → drop to free, clear sub id
invoice.paid / failed → audit only
charge.refunded → audit only
Stripe-SDK objects don't expose dict.get(); we use the SDK for
signature verification then re-parse the JSON body for handler
dispatch — cleaner than reaching into StripeObject internals.
Schema (migration 0019)
- users.stripe_customer_id, users.stripe_subscription_id (nullable
String(64), UNIQUE on customer_id).
- stripe_events table mirroring polar_events: event_id (unique),
event_type, received_at, processed_at, error, raw payload
(truncated to 16 KiB).
Settings (.env)
- STRIPE_API_KEY (rk_test_… for dev, rk_live_… for GA)
- STRIPE_WEBHOOK_SECRET (whsec_… from the dashboard endpoint)
- STRIPE_PRICE_MONTHLY (price_xxx for £7/month)
- STRIPE_PRICE_ANNUAL (price_xxx for £70/year)
Pricing page
- Free tier CTA unchanged.
- Paid CTA branches three ways: paid → "Manage subscription" to
/settings; logged-in free → two buttons (£7/mo, £70/yr) that POST
to /api/stripe/checkout and redirect; anonymous → /login?next=/pricing.
- Inline JS intercepts the button click, calls the checkout
endpoint, redirects on success, surfaces errors via alert(). No
Stripe.js dep — we use the hosted-checkout URL directly.
Polar handler stays in place for berengar.io / flyroom.net which
still ship through Polar. polar_* and stripe_* columns coexist
independently on the User row.
Tests
- 9 in tests/test_stripe_billing.py covering: bad signature → 401,
missing signature → 400, checkout.session.completed flips tier +
stores IDs, subscription.updated active grants paid,
subscription.deleted drops to free with customer id preserved,
replayed event id is no-op (one row in stripe_events),
unknown event acked 200, checkout endpoint mocks the SDK and
returns the hosted URL, checkout requires login.
- Full suite: 221 passed, 5 skipped.
Co-Authored-By: Claude Opus 4.7
---
alembic/versions/0019_stripe.py | 56 +++++
app/config.py | 10 +
app/main.py | 4 +
app/models.py | 33 +++
app/routers/public.py | 5 +-
app/routers/stripe_billing.py | 383 ++++++++++++++++++++++++++++++++
app/templates/pricing.html | 51 ++++-
pyproject.toml | 1 +
tests/test_stripe_billing.py | 322 +++++++++++++++++++++++++++
9 files changed, 858 insertions(+), 7 deletions(-)
create mode 100644 alembic/versions/0019_stripe.py
create mode 100644 app/routers/stripe_billing.py
create mode 100644 tests/test_stripe_billing.py
diff --git a/alembic/versions/0019_stripe.py b/alembic/versions/0019_stripe.py
new file mode 100644
index 0000000..3ea4018
--- /dev/null
+++ b/alembic/versions/0019_stripe.py
@@ -0,0 +1,56 @@
+"""stripe integration: users.stripe_customer_id / stripe_subscription_id,
+stripe_events table.
+
+Revision ID: 0019
+Revises: 0018
+Create Date: 2026-05-26
+"""
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+
+revision: str = "0019"
+down_revision: Union[str, None] = "0018"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ op.add_column(
+ "users",
+ sa.Column("stripe_customer_id", sa.String(length=64), nullable=True),
+ )
+ op.add_column(
+ "users",
+ sa.Column("stripe_subscription_id", sa.String(length=64), nullable=True),
+ )
+ op.create_unique_constraint(
+ "uq_users_stripe_customer", "users", ["stripe_customer_id"],
+ )
+
+ op.create_table(
+ "stripe_events",
+ sa.Column("id", sa.BigInteger(), autoincrement=True, primary_key=True),
+ sa.Column("event_id", sa.String(length=128), nullable=False),
+ sa.Column("event_type", sa.String(length=64), nullable=False),
+ sa.Column("received_at", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("processed_at", sa.DateTime(timezone=True), nullable=True),
+ sa.Column("error", sa.Text(), nullable=True),
+ sa.Column("payload", sa.Text(), nullable=False),
+ sa.UniqueConstraint("event_id", name="uq_stripe_events_event_id"),
+ )
+ op.create_index(
+ "ix_stripe_events_type_received",
+ "stripe_events",
+ ["event_type", "received_at"],
+ )
+
+
+def downgrade() -> None:
+ op.drop_index("ix_stripe_events_type_received", table_name="stripe_events")
+ op.drop_table("stripe_events")
+ op.drop_constraint("uq_users_stripe_customer", "users", type_="unique")
+ op.drop_column("users", "stripe_subscription_id")
+ op.drop_column("users", "stripe_customer_id")
diff --git a/app/config.py b/app/config.py
index 4ff67d6..0aabb1d 100644
--- a/app/config.py
+++ b/app/config.py
@@ -99,6 +99,16 @@ class Settings(BaseSettings):
POLAR_WEBHOOK_SECRET: str = ""
POLAR_API_KEY: str = ""
+ # Stripe (merchant-on-record for read.markets after Polar/Paddle
+ # both declined the financial-media category). Test-mode keys are
+ # `sk_test_*` / `whsec_*`; live-mode keys are `sk_live_*` — swap at
+ # GA cutover. Empty values make the corresponding endpoints 503 so
+ # a misconfig is loud rather than silently accepting unsigned events.
+ STRIPE_API_KEY: str = ""
+ STRIPE_WEBHOOK_SECRET: str = ""
+ STRIPE_PRICE_MONTHLY: str = "" # price_xxx for £7/month subscription
+ STRIPE_PRICE_ANNUAL: str = "" # price_xxx for £70/year subscription
+
# Config file locations (overridable for tests)
BASELINE_TOML: Path = Field(default_factory=lambda: CONFIG_DIR / "default.toml")
PORTFOLIO_TOML: Path = Field(default_factory=lambda: CONFIG_DIR / "portfolio.toml")
diff --git a/app/main.py b/app/main.py
index d9c491e..9499a34 100644
--- a/app/main.py
+++ b/app/main.py
@@ -23,6 +23,7 @@ from app.routers import email as email_router
from app.routers import pages as pages_router
from app.routers import polar_webhook as polar_webhook_router
from app.routers import public as public_router
+from app.routers import stripe_billing as stripe_router
from app.routers import sync as sync_router
from app.routers import universe as universe_router
from app.services.feeds_bootstrap import bootstrap_feeds
@@ -93,6 +94,9 @@ app.include_router(sync_router.router, tags=["portfolio-sync"])
# `/api/polar/webhook` is set on the route itself so the URL Polar
# stores remains stable even if api_router's prefix ever moves.
app.include_router(polar_webhook_router.router, tags=["polar-webhook"])
+# Stripe billing (checkout, portal, webhook). Auth lives per-route:
+# checkout + portal require_auth, webhook is signature-gated.
+app.include_router(stripe_router.router, tags=["stripe-billing"])
# Public router (no auth dep) before pages_router so the marketing/legal
# paths can never collide with future authenticated routes.
app.include_router(public_router.router)
diff --git a/app/models.py b/app/models.py
index 643b8d8..2140f6b 100644
--- a/app/models.py
+++ b/app/models.py
@@ -194,11 +194,18 @@ class User(Base):
# we cancel against from /settings.
polar_customer_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
polar_subscription_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
+ # Stripe (merchant-on-record for read.markets). Populated on the
+ # first checkout.session.completed event via client_reference_id;
+ # used thereafter to match incoming subscription/invoice events
+ # back to this row.
+ stripe_customer_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
+ stripe_subscription_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
__table_args__ = (
UniqueConstraint("email", name="uq_users_email"),
UniqueConstraint("referral_code", name="uq_users_referral_code"),
UniqueConstraint("polar_customer_id", name="uq_users_polar_customer"),
+ UniqueConstraint("stripe_customer_id", name="uq_users_stripe_customer"),
)
@@ -385,3 +392,29 @@ class PolarEvent(Base):
UniqueConstraint("event_id", name="uq_polar_events_event_id"),
Index("ix_polar_events_type_received", "event_type", "received_at"),
)
+
+
+class StripeEvent(Base):
+ """Audit + idempotency table for inbound Stripe webhook deliveries.
+
+ Same shape and purpose as PolarEvent — Stripe's `event.id` plays the
+ same role as Standard Webhooks' `webhook-id`. We keep the tables
+ distinct (rather than a single 'webhook_events' table) so an
+ operator can look at the audit trail per processor without filtering
+ on a `source` column."""
+ __tablename__ = "stripe_events"
+
+ id: Mapped[int] = mapped_column(_PK, primary_key=True, autoincrement=True)
+ event_id: Mapped[str] = mapped_column(String(128), nullable=False)
+ event_type: Mapped[str] = mapped_column(String(64), nullable=False)
+ received_at: Mapped[datetime] = mapped_column(
+ DateTime(timezone=True), default=utcnow, nullable=False,
+ )
+ processed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
+ error: Mapped[str | None] = mapped_column(Text)
+ payload: Mapped[str] = mapped_column(Text, nullable=False)
+
+ __table_args__ = (
+ UniqueConstraint("event_id", name="uq_stripe_events_event_id"),
+ Index("ix_stripe_events_type_received", "event_type", "received_at"),
+ )
diff --git a/app/routers/public.py b/app/routers/public.py
index a040ccd..33bd245 100644
--- a/app/routers/public.py
+++ b/app/routers/public.py
@@ -15,6 +15,7 @@ from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse
from app.auth import CurrentUser, maybe_current_user
+from app.services.access import is_paid_active
from app.templates_env import templates
@@ -33,7 +34,9 @@ async def pricing_page(
request: Request,
cu: CurrentUser | None = Depends(maybe_current_user),
):
- return templates.TemplateResponse(request, "pricing.html", _ctx(request, cu))
+ ctx = _ctx(request, cu)
+ ctx["paid"] = is_paid_active(cu)
+ return templates.TemplateResponse(request, "pricing.html", ctx)
@router.get("/about", response_class=HTMLResponse)
diff --git a/app/routers/stripe_billing.py b/app/routers/stripe_billing.py
new file mode 100644
index 0000000..e896f74
--- /dev/null
+++ b/app/routers/stripe_billing.py
@@ -0,0 +1,383 @@
+"""Stripe billing endpoints — checkout, webhook, customer portal.
+
+Stripe is the merchant-on-record for read.markets (after Polar/Paddle
+both declined the financial-media category). We delegate payment UI to
+Stripe-hosted Checkout and Customer Portal; the only state we keep on
+our side is `users.stripe_customer_id` / `users.stripe_subscription_id`
+so we can match incoming webhooks back to the right user.
+
+The Stripe SDK is sync; we wrap calls in `asyncio.to_thread` so the
+event loop doesn't block while Stripe answers. For our request volume
+this is more reliable than the SDK's nascent async surface.
+
+Routes
+- POST /api/stripe/checkout — logged-in user upgrades. Body: {cadence}.
+- POST /api/stripe/webhook — Stripe → us, signature-verified.
+- POST /api/stripe/portal — logged-in user opens the customer portal.
+"""
+from __future__ import annotations
+
+import asyncio
+import json
+from typing import Any, Literal
+
+import stripe
+from fastapi import APIRouter, Body, Depends, HTTPException, Request
+from pydantic import BaseModel
+from sqlalchemy import select
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app import branding
+from app.auth import CurrentUser, require_auth
+from app.config import get_settings
+from app.db import get_session, utcnow
+from app.logging import get_logger
+from app.models import StripeEvent, User
+
+
+log = get_logger("stripe_billing")
+router = APIRouter()
+
+
+# Cap stored payload at 16 KiB so a hostile (or buggy) sender can't
+# blow up a single row. Same pattern as polar_webhook.
+_PAYLOAD_STORE_MAX = 16 * 1024
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _require_configured() -> None:
+ s = get_settings()
+ if not s.STRIPE_API_KEY:
+ raise HTTPException(status_code=503, detail="stripe not configured")
+
+
+def _price_for(cadence: str) -> str:
+ s = get_settings()
+ if cadence == "monthly":
+ if not s.STRIPE_PRICE_MONTHLY:
+ raise HTTPException(status_code=503, detail="STRIPE_PRICE_MONTHLY not set")
+ return s.STRIPE_PRICE_MONTHLY
+ if cadence == "annual":
+ if not s.STRIPE_PRICE_ANNUAL:
+ raise HTTPException(status_code=503, detail="STRIPE_PRICE_ANNUAL not set")
+ return s.STRIPE_PRICE_ANNUAL
+ raise HTTPException(status_code=400, detail="cadence must be 'monthly' or 'annual'")
+
+
+def _stripe_client() -> stripe.StripeClient:
+ """Per-call client so we read the secret at request time (lets us
+ rotate the key by editing .env + reloading without rebuilding any
+ cached client)."""
+ return stripe.StripeClient(get_settings().STRIPE_API_KEY)
+
+
+# ---------------------------------------------------------------------------
+# POST /api/stripe/checkout
+# ---------------------------------------------------------------------------
+
+
+class CheckoutRequest(BaseModel):
+ cadence: Literal["monthly", "annual"]
+
+
+class CheckoutResponse(BaseModel):
+ url: str
+
+
+@router.post("/api/stripe/checkout", response_model=CheckoutResponse)
+async def create_checkout(
+ body: CheckoutRequest,
+ session: AsyncSession = Depends(get_session),
+ cu: CurrentUser = Depends(require_auth),
+) -> CheckoutResponse:
+ _require_configured()
+ if cu.user is None:
+ # Admin bearer token has no User row — they shouldn't be buying.
+ raise HTTPException(status_code=400, detail="admin token cannot purchase")
+
+ user = await session.get(User, cu.user.id)
+ if user is None:
+ raise HTTPException(status_code=404, detail="user_not_found")
+
+ price_id = _price_for(body.cadence)
+ client = _stripe_client()
+
+ # Pass `customer` if we already minted one for this user (avoids
+ # creating duplicate Stripe customers on repeat checkouts);
+ # otherwise let Stripe create it via `customer_email`.
+ create_kwargs: dict[str, Any] = {
+ "mode": "subscription",
+ "line_items": [{"price": price_id, "quantity": 1}],
+ "client_reference_id": str(user.id),
+ "success_url": f"{branding.SITE_URL}/settings?upgraded=1",
+ "cancel_url": f"{branding.SITE_URL}/pricing",
+ # Lets us paste in a referral coupon at checkout once the
+ # referral redemption flow ships.
+ "allow_promotion_codes": True,
+ }
+ if user.stripe_customer_id:
+ create_kwargs["customer"] = user.stripe_customer_id
+ else:
+ create_kwargs["customer_email"] = user.email
+
+ try:
+ sess = await asyncio.to_thread(
+ client.checkout.sessions.create, params=create_kwargs,
+ )
+ except stripe.StripeError as e:
+ log.error("stripe.checkout.create_failed", user_id=user.id, error=str(e))
+ raise HTTPException(status_code=502, detail=f"stripe error: {e.user_message or str(e)}")
+
+ if not sess.url:
+ raise HTTPException(status_code=502, detail="stripe returned no checkout URL")
+ log.info("stripe.checkout.created", user_id=user.id, session_id=sess.id,
+ cadence=body.cadence)
+ return CheckoutResponse(url=sess.url)
+
+
+# ---------------------------------------------------------------------------
+# POST /api/stripe/portal
+# ---------------------------------------------------------------------------
+
+
+class PortalResponse(BaseModel):
+ url: str
+
+
+@router.post("/api/stripe/portal", response_model=PortalResponse)
+async def create_portal_session(
+ session: AsyncSession = Depends(get_session),
+ cu: CurrentUser = Depends(require_auth),
+) -> PortalResponse:
+ _require_configured()
+ if cu.user is None:
+ raise HTTPException(status_code=400, detail="admin token has no portal")
+
+ user = await session.get(User, cu.user.id)
+ if user is None or not user.stripe_customer_id:
+ raise HTTPException(
+ status_code=404,
+ detail="no_stripe_customer — start a subscription first",
+ )
+
+ client = _stripe_client()
+ try:
+ portal = await asyncio.to_thread(
+ client.billing_portal.sessions.create,
+ params={
+ "customer": user.stripe_customer_id,
+ "return_url": f"{branding.SITE_URL}/settings",
+ },
+ )
+ except stripe.StripeError as e:
+ log.error("stripe.portal.create_failed", user_id=user.id, error=str(e))
+ raise HTTPException(status_code=502, detail=f"stripe error: {e.user_message or str(e)}")
+
+ return PortalResponse(url=portal.url)
+
+
+# ---------------------------------------------------------------------------
+# POST /api/stripe/webhook
+# ---------------------------------------------------------------------------
+
+
+async def _find_user(
+ session: AsyncSession,
+ *,
+ client_ref: str | None = None,
+ customer_id: str | None = None,
+) -> User | None:
+ """Find the User row this event belongs to.
+
+ `client_reference_id` is the most reliable join key — we set it
+ to `str(user.id)` at checkout creation. After the first event we
+ also know `stripe_customer_id`, which subsequent subscription /
+ invoice events arrive carrying."""
+ if client_ref:
+ try:
+ uid = int(client_ref)
+ except ValueError:
+ uid = None
+ if uid is not None:
+ u = await session.get(User, uid)
+ if u is not None:
+ return u
+ if customer_id:
+ row = (await session.execute(
+ select(User).where(User.stripe_customer_id == customer_id)
+ )).scalar_one_or_none()
+ return row
+ return None
+
+
+async def _grant_paid(
+ user: User,
+ *,
+ customer_id: str | None,
+ subscription_id: str | None,
+) -> None:
+ user.tier = "paid"
+ if customer_id and user.stripe_customer_id != customer_id:
+ user.stripe_customer_id = customer_id
+ if subscription_id and user.stripe_subscription_id != subscription_id:
+ user.stripe_subscription_id = subscription_id
+
+
+async def _revoke_paid(user: User) -> None:
+ user.tier = "free"
+ user.stripe_subscription_id = None
+ # Keep stripe_customer_id so a re-subscription matches this row.
+
+
+async def _handle_checkout_completed(
+ session: AsyncSession, event_type: str, obj: dict[str, Any],
+) -> None:
+ user = await _find_user(
+ session,
+ client_ref=obj.get("client_reference_id"),
+ customer_id=obj.get("customer"),
+ )
+ if user is None:
+ log.warning("stripe.user_not_found", event=event_type)
+ return
+ await _grant_paid(
+ user,
+ customer_id=obj.get("customer"),
+ subscription_id=obj.get("subscription"),
+ )
+
+
+async def _handle_subscription_event(
+ session: AsyncSession, event_type: str, obj: dict[str, Any],
+) -> None:
+ """customer.subscription.created / .updated — flip to paid if the
+ Stripe-side status says the subscription is active/trialing; drop
+ to free if it's an end-state."""
+ user = await _find_user(session, customer_id=obj.get("customer"))
+ if user is None:
+ log.warning("stripe.user_not_found", event=event_type,
+ customer_id=obj.get("customer"))
+ return
+ status = obj.get("status")
+ # Stripe statuses: trialing, active, past_due, canceled, unpaid,
+ # incomplete, incomplete_expired, paused. Treat trialing/active as
+ # paid; everything else holds tier the same until we get an explicit
+ # subscription.deleted (which fires after the final state lands).
+ if status in ("trialing", "active"):
+ await _grant_paid(
+ user,
+ customer_id=obj.get("customer"),
+ subscription_id=obj.get("id"),
+ )
+
+
+async def _handle_subscription_deleted(
+ session: AsyncSession, event_type: str, obj: dict[str, Any],
+) -> None:
+ user = await _find_user(session, customer_id=obj.get("customer"))
+ if user is None:
+ log.warning("stripe.user_not_found", event=event_type,
+ customer_id=obj.get("customer"))
+ return
+ await _revoke_paid(user)
+
+
+async def _handle_audit_only(
+ session: AsyncSession, event_type: str, obj: dict[str, Any],
+) -> None:
+ """invoice.paid / invoice.payment_failed / charge.refunded — we
+ record these in stripe_events for the audit log but the tier doesn't
+ move until subscription.deleted fires."""
+ return None
+
+
+_HANDLERS = {
+ "checkout.session.completed": _handle_checkout_completed,
+ "customer.subscription.created": _handle_subscription_event,
+ "customer.subscription.updated": _handle_subscription_event,
+ "customer.subscription.deleted": _handle_subscription_deleted,
+ "invoice.paid": _handle_audit_only,
+ "invoice.payment_failed": _handle_audit_only,
+ "charge.refunded": _handle_audit_only,
+}
+
+
+@router.post("/api/stripe/webhook")
+async def stripe_webhook(
+ request: Request,
+ session: AsyncSession = Depends(get_session),
+) -> dict[str, str]:
+ s = get_settings()
+ if not s.STRIPE_WEBHOOK_SECRET:
+ raise HTTPException(status_code=503, detail="stripe webhook not configured")
+
+ sig = request.headers.get("stripe-signature", "")
+ if not sig:
+ raise HTTPException(status_code=400, detail="missing stripe-signature header")
+
+ body = await request.body()
+ # construct_event handles HMAC verification + timestamp tolerance.
+ # We then re-parse the body as plain JSON for handler dispatch —
+ # the Stripe SDK's StripeObject doesn't expose dict.get(), and
+ # round-tripping through json gives us simple, typed-dict access.
+ try:
+ stripe.Webhook.construct_event(
+ payload=body, sig_header=sig, secret=s.STRIPE_WEBHOOK_SECRET,
+ )
+ except stripe.SignatureVerificationError:
+ raise HTTPException(status_code=401, detail="bad signature")
+ except ValueError:
+ raise HTTPException(status_code=400, detail="invalid payload")
+
+ envelope = json.loads(body)
+ event_id = envelope.get("id") or ""
+ event_type = envelope.get("type") or "unknown"
+ obj = (envelope.get("data") or {}).get("object") or {}
+
+ if not event_id:
+ raise HTTPException(status_code=400, detail="event missing id")
+
+ # Idempotency: insert audit row first. UNIQUE on event_id makes a
+ # replay of the same Stripe event id a no-op (Stripe retries on
+ # non-2xx, so always 2xx after first successful processing).
+ audit = StripeEvent(
+ event_id=event_id,
+ event_type=event_type,
+ received_at=utcnow(),
+ payload=body.decode("utf-8", errors="replace")[:_PAYLOAD_STORE_MAX],
+ )
+ session.add(audit)
+ try:
+ await session.flush()
+ except IntegrityError:
+ await session.rollback()
+ log.info("stripe.duplicate_delivery", event_id=event_id, type=event_type)
+ return {"status": "duplicate"}
+
+ handler = _HANDLERS.get(event_type)
+ if handler is None:
+ audit.processed_at = utcnow()
+ await session.commit()
+ log.info("stripe.event_unhandled", type=event_type, id=event_id)
+ return {"status": "ignored"}
+
+ try:
+ await handler(session, event_type, obj)
+ except Exception as e:
+ audit.error = str(e)[:1024]
+ await session.commit()
+ log.exception("stripe.handler_error", type=event_type, id=event_id)
+ # Ack 200 — we don't want Stripe retrying a handler that broke
+ # the same way on every delivery. An operator triages from the
+ # `error` column.
+ return {"status": "handler_error"}
+
+ audit.processed_at = utcnow()
+ await session.commit()
+ log.info("stripe.processed", type=event_type, id=event_id)
+ return {"status": "ok"}
diff --git a/app/templates/pricing.html b/app/templates/pricing.html
index a26e106..6bb4411 100644
--- a/app/templates/pricing.html
+++ b/app/templates/pricing.html
@@ -50,9 +50,8 @@
Full-day news feed, hourly strategic log, follow-up chat, and AI portfolio analysis.
£7 / month
- Or £70 / year — two months free. Prices
- in GBP, VAT where applicable. Checkout opens with the payments
- rollout.
+ Or £70 / year — two months free.
+ Prices in GBP, VAT where applicable.
Everything in Free, plus
@@ -72,16 +71,56 @@
or a personal recommendation under FSMA / FCA COBS.
+
+
Free vs Paid at a glance
diff --git a/pyproject.toml b/pyproject.toml
index cfa65d6..5a0b50f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
"email-validator>=2.2",
"aiosmtplib>=3.0",
"redis[hiredis]>=5.2",
+ "stripe>=11.0",
]
[project.optional-dependencies]
diff --git a/tests/test_stripe_billing.py b/tests/test_stripe_billing.py
new file mode 100644
index 0000000..61bf410
--- /dev/null
+++ b/tests/test_stripe_billing.py
@@ -0,0 +1,322 @@
+"""Stripe billing endpoints: signature verification, idempotency, tier
+flips, and checkout creation.
+
+Same integration-style scaffold as test_polar_webhook.py — real router
+over in-memory aiosqlite. Stripe SDK calls (sessions.create, portal
+sessions.create) are mocked so the suite never makes a real HTTP call.
+"""
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import hmac
+import json
+import time
+from types import SimpleNamespace
+from unittest.mock import patch
+
+import pytest
+
+
+_API_KEY = "sk_test_dummy_for_unit_tests"
+_WEBHOOK_SECRET = "whsec_dummy_test_secret_for_unit_tests"
+_PRICE_MONTHLY = "price_test_monthly_xxxxxxxxxxxxxxxxxxxx"
+_PRICE_ANNUAL = "price_test_annual_xxxxxxxxxxxxxxxxxxxxx"
+
+
+def _stripe_sig(body: bytes, secret: str, ts: int | None = None) -> str:
+ """Produce a Stripe-Signature header matching the bytes signed.
+ Format: `t=,v1=` over `.`."""
+ ts = ts if ts is not None else int(time.time())
+ signed = f"{ts}.{body.decode('utf-8')}"
+ mac = hmac.new(secret.encode("utf-8"), signed.encode("utf-8"),
+ hashlib.sha256).hexdigest()
+ return f"t={ts},v1={mac}"
+
+
+def _build_app(tmp_path):
+ from fastapi import FastAPI
+ from fastapi.testclient import TestClient
+ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
+
+ from app import db as db_mod
+ from app.auth import sign_session
+ from app.config import get_settings
+ from app.db import Base
+ from app.models import User
+ from app.routers import stripe_billing as stripe_router
+
+ s = get_settings()
+ s.STRIPE_API_KEY = _API_KEY # type: ignore[misc]
+ s.STRIPE_WEBHOOK_SECRET = _WEBHOOK_SECRET # type: ignore[misc]
+ s.STRIPE_PRICE_MONTHLY = _PRICE_MONTHLY # type: ignore[misc]
+ s.STRIPE_PRICE_ANNUAL = _PRICE_ANNUAL # type: ignore[misc]
+
+ engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/stripe.db")
+ factory = async_sessionmaker(engine, expire_on_commit=False)
+ db_mod._engine = engine
+ db_mod._session_factory = factory
+
+ async def _seed():
+ async with engine.begin() as conn:
+ await conn.run_sync(Base.metadata.create_all)
+ async with factory() as session:
+ session.add(User(id=1, email="buyer@x", tier="free"))
+ await session.commit()
+
+ asyncio.run(_seed())
+
+ app = FastAPI()
+ app.include_router(stripe_router.router)
+ return TestClient(app), factory, sign_session(1)
+
+
+def _post_webhook(client, *, body: dict, secret: str = _WEBHOOK_SECRET,
+ sig: str | None = None):
+ # Stripe's SDK requires a top-level `object: "event"` field to know
+ # this is a v1 webhook envelope — tests that omit it fail in
+ # construct_event before the signature check matters. We inject the
+ # default here so individual tests can stay terse.
+ body.setdefault("object", "event")
+ raw = json.dumps(body).encode("utf-8")
+ sig = sig if sig is not None else _stripe_sig(raw, secret)
+ return client.post(
+ "/api/stripe/webhook",
+ content=raw,
+ headers={"stripe-signature": sig, "content-type": "application/json"},
+ )
+
+
+# --- signature gate --------------------------------------------------------
+
+
+def test_webhook_rejects_bad_signature(tmp_path):
+ client, _, _ = _build_app(tmp_path)
+ raw = json.dumps({"id": "evt_x", "type": "invoice.paid",
+ "data": {"object": {}}}).encode("utf-8")
+ r = client.post(
+ "/api/stripe/webhook",
+ content=raw,
+ headers={
+ "stripe-signature": "t=0,v1=deadbeef",
+ "content-type": "application/json",
+ },
+ )
+ assert r.status_code == 401, r.text
+
+
+def test_webhook_rejects_missing_signature(tmp_path):
+ client, _, _ = _build_app(tmp_path)
+ r = client.post(
+ "/api/stripe/webhook",
+ content=b"{}",
+ headers={"content-type": "application/json"},
+ )
+ assert r.status_code == 400, r.text
+
+
+# --- happy paths -----------------------------------------------------------
+
+
+def test_checkout_session_completed_flips_tier_to_paid(tmp_path):
+ client, factory, _ = _build_app(tmp_path)
+ body = {
+ "id": "evt_checkout_1",
+ "type": "checkout.session.completed",
+ "data": {
+ "object": {
+ "client_reference_id": "1",
+ "customer": "cus_abc",
+ "subscription": "sub_xyz",
+ }
+ },
+ }
+ r = _post_webhook(client, body=body)
+ assert r.status_code == 200, r.text
+ assert r.json()["status"] == "ok"
+
+ async def _check():
+ from sqlalchemy import select
+ from app.models import User
+ async with factory() as session:
+ u = (await session.execute(
+ select(User).where(User.id == 1)
+ )).scalar_one()
+ return u.tier, u.stripe_customer_id, u.stripe_subscription_id
+
+ tier, cid, sid = asyncio.run(_check())
+ assert tier == "paid"
+ assert cid == "cus_abc"
+ assert sid == "sub_xyz"
+
+
+def test_subscription_deleted_drops_tier_to_free(tmp_path):
+ client, factory, _ = _build_app(tmp_path)
+ # First, activate.
+ _post_webhook(client, body={
+ "id": "evt_act",
+ "type": "checkout.session.completed",
+ "data": {"object": {
+ "client_reference_id": "1",
+ "customer": "cus_abc",
+ "subscription": "sub_xyz",
+ }},
+ })
+ # Then, delete the subscription.
+ r = _post_webhook(client, body={
+ "id": "evt_del",
+ "type": "customer.subscription.deleted",
+ "data": {"object": {
+ "id": "sub_xyz",
+ "customer": "cus_abc",
+ "status": "canceled",
+ }},
+ })
+ assert r.status_code == 200, r.text
+
+ async def _check():
+ from sqlalchemy import select
+ from app.models import User
+ async with factory() as session:
+ u = (await session.execute(
+ select(User).where(User.id == 1)
+ )).scalar_one()
+ return u.tier, u.stripe_customer_id, u.stripe_subscription_id
+
+ tier, cid, sid = asyncio.run(_check())
+ assert tier == "free"
+ # Customer linkage preserved so a future resub matches this row.
+ assert cid == "cus_abc"
+ assert sid is None
+
+
+def test_subscription_active_grants_paid(tmp_path):
+ """customer.subscription.updated with status=active should also
+ grant paid — covers the case where checkout.session.completed
+ arrives after subscription.created and we want either to work."""
+ client, factory, _ = _build_app(tmp_path)
+ # Seed the linkage first via checkout (so customer_id is known).
+ _post_webhook(client, body={
+ "id": "evt_ck",
+ "type": "checkout.session.completed",
+ "data": {"object": {
+ "client_reference_id": "1",
+ "customer": "cus_abc",
+ "subscription": "sub_xyz",
+ }},
+ })
+ # Drop to free manually so we can prove the updated event re-grants.
+ async def _reset():
+ from sqlalchemy import update
+ from app.models import User
+ async with factory() as session:
+ await session.execute(
+ update(User).where(User.id == 1).values(tier="free")
+ )
+ await session.commit()
+ asyncio.run(_reset())
+
+ r = _post_webhook(client, body={
+ "id": "evt_upd",
+ "type": "customer.subscription.updated",
+ "data": {"object": {
+ "id": "sub_xyz",
+ "customer": "cus_abc",
+ "status": "active",
+ }},
+ })
+ assert r.status_code == 200
+
+ async def _check_tier():
+ from sqlalchemy import select
+ from app.models import User
+ async with factory() as session:
+ return (await session.execute(
+ select(User.tier).where(User.id == 1)
+ )).scalar_one()
+ assert asyncio.run(_check_tier()) == "paid"
+
+
+# --- idempotency + unknown ------------------------------------------------
+
+
+def test_replayed_event_id_is_a_noop(tmp_path):
+ client, factory, _ = _build_app(tmp_path)
+ body = {
+ "id": "evt_dup",
+ "type": "checkout.session.completed",
+ "data": {"object": {
+ "client_reference_id": "1",
+ "customer": "cus_abc",
+ "subscription": "sub_xyz",
+ }},
+ }
+ r1 = _post_webhook(client, body=body)
+ r2 = _post_webhook(client, body=body)
+ assert r1.json()["status"] == "ok"
+ assert r2.json()["status"] == "duplicate"
+
+ async def _count_rows():
+ from sqlalchemy import select, func
+ from app.models import StripeEvent
+ async with factory() as session:
+ n = (await session.execute(
+ select(func.count(StripeEvent.id))
+ .where(StripeEvent.event_id == "evt_dup")
+ )).scalar_one()
+ return n
+ assert asyncio.run(_count_rows()) == 1
+
+
+def test_unknown_event_is_acked(tmp_path):
+ client, _, _ = _build_app(tmp_path)
+ r = _post_webhook(client, body={
+ "id": "evt_unknown",
+ "type": "product.something.new",
+ "data": {"object": {}},
+ })
+ assert r.status_code == 200
+ assert r.json()["status"] == "ignored"
+
+
+# --- /api/stripe/checkout (with Stripe SDK mocked) ------------------------
+
+
+def test_checkout_endpoint_creates_session_and_returns_url(tmp_path):
+ client, _, session_cookie = _build_app(tmp_path)
+ # Mock the Stripe SDK call so no real HTTP goes out.
+ fake_session = SimpleNamespace(
+ id="cs_test_123", url="https://checkout.stripe.com/test",
+ )
+
+ class _FakeSessions:
+ @staticmethod
+ def create(params): # noqa: ANN001
+ assert params["mode"] == "subscription"
+ assert params["line_items"][0]["price"] == _PRICE_MONTHLY
+ assert params["client_reference_id"] == "1"
+ assert params["customer_email"] == "buyer@x"
+ return fake_session
+
+ class _FakeCheckout:
+ sessions = _FakeSessions()
+
+ class _FakeClient:
+ checkout = _FakeCheckout()
+
+ with patch("app.routers.stripe_billing._stripe_client",
+ return_value=_FakeClient()):
+ r = client.post(
+ "/api/stripe/checkout",
+ json={"cadence": "monthly"},
+ cookies={"cassandra_session": session_cookie},
+ )
+ assert r.status_code == 200, r.text
+ assert r.json()["url"] == "https://checkout.stripe.com/test"
+
+
+def test_checkout_endpoint_requires_login(tmp_path):
+ client, _, _ = _build_app(tmp_path)
+ r = client.post("/api/stripe/checkout", json={"cadence": "monthly"})
+ # No session cookie → require_auth bounces with 401.
+ assert r.status_code == 401, r.text