From 410afe007897af61af9f3f84b935d7663d8281a6 Mon Sep 17 00:00:00 2001 From: Giorgio Gilestro Date: Tue, 26 May 2026 18:45:13 +0200 Subject: [PATCH] stripe: wire checkout, customer portal, and webhook for read.markets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stripe is the merchant-on-record for read.markets after Polar/Paddle both declined the financial-media category. This commit lands the full subscription flow: an "Upgrade" button on /pricing now opens a real Stripe-hosted Checkout, completes the subscription, and the webhook flips user.tier to "paid" idempotently. Endpoints - POST /api/stripe/checkout (require_auth) — creates a hosted Checkout Session in subscription mode, passes user.id as client_reference_id + email as customer_email, returns the URL for the page-side JS to redirect to. Reuses an existing stripe_customer_id to avoid duplicate Stripe customers on repeat checkouts. allow_promotion_codes=True so the referral-credit redemption can attach a coupon at checkout once that flow ships. - POST /api/stripe/portal (require_auth) — mints a Stripe Customer Portal session. Used by /settings; returns 404 until the user has a stripe_customer_id (i.e. completed at least one checkout). - POST /api/stripe/webhook — signature-verified via stripe.Webhook.construct_event. Idempotent via UNIQUE on stripe_events.event_id. Event dispatch: checkout.session.completed → grant paid, store IDs customer.subscription.created → grant paid (active/trialing) customer.subscription.updated → grant paid (active/trialing) customer.subscription.deleted → drop to free, clear sub id invoice.paid / failed → audit only charge.refunded → audit only Stripe-SDK objects don't expose dict.get(); we use the SDK for signature verification then re-parse the JSON body for handler dispatch — cleaner than reaching into StripeObject internals. Schema (migration 0019) - users.stripe_customer_id, users.stripe_subscription_id (nullable String(64), UNIQUE on customer_id). - stripe_events table mirroring polar_events: event_id (unique), event_type, received_at, processed_at, error, raw payload (truncated to 16 KiB). Settings (.env) - STRIPE_API_KEY (rk_test_… for dev, rk_live_… for GA) - STRIPE_WEBHOOK_SECRET (whsec_… from the dashboard endpoint) - STRIPE_PRICE_MONTHLY (price_xxx for £7/month) - STRIPE_PRICE_ANNUAL (price_xxx for £70/year) Pricing page - Free tier CTA unchanged. - Paid CTA branches three ways: paid → "Manage subscription" to /settings; logged-in free → two buttons (£7/mo, £70/yr) that POST to /api/stripe/checkout and redirect; anonymous → /login?next=/pricing. - Inline JS intercepts the button click, calls the checkout endpoint, redirects on success, surfaces errors via alert(). No Stripe.js dep — we use the hosted-checkout URL directly. Polar handler stays in place for berengar.io / flyroom.net which still ship through Polar. polar_* and stripe_* columns coexist independently on the User row. Tests - 9 in tests/test_stripe_billing.py covering: bad signature → 401, missing signature → 400, checkout.session.completed flips tier + stores IDs, subscription.updated active grants paid, subscription.deleted drops to free with customer id preserved, replayed event id is no-op (one row in stripe_events), unknown event acked 200, checkout endpoint mocks the SDK and returns the hosted URL, checkout requires login. - Full suite: 221 passed, 5 skipped. Co-Authored-By: Claude Opus 4.7 --- alembic/versions/0019_stripe.py | 56 +++++ app/config.py | 10 + app/main.py | 4 + app/models.py | 33 +++ app/routers/public.py | 5 +- app/routers/stripe_billing.py | 383 ++++++++++++++++++++++++++++++++ app/templates/pricing.html | 51 ++++- pyproject.toml | 1 + tests/test_stripe_billing.py | 322 +++++++++++++++++++++++++++ 9 files changed, 858 insertions(+), 7 deletions(-) create mode 100644 alembic/versions/0019_stripe.py create mode 100644 app/routers/stripe_billing.py create mode 100644 tests/test_stripe_billing.py diff --git a/alembic/versions/0019_stripe.py b/alembic/versions/0019_stripe.py new file mode 100644 index 0000000..3ea4018 --- /dev/null +++ b/alembic/versions/0019_stripe.py @@ -0,0 +1,56 @@ +"""stripe integration: users.stripe_customer_id / stripe_subscription_id, +stripe_events table. + +Revision ID: 0019 +Revises: 0018 +Create Date: 2026-05-26 +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + + +revision: str = "0019" +down_revision: Union[str, None] = "0018" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "users", + sa.Column("stripe_customer_id", sa.String(length=64), nullable=True), + ) + op.add_column( + "users", + sa.Column("stripe_subscription_id", sa.String(length=64), nullable=True), + ) + op.create_unique_constraint( + "uq_users_stripe_customer", "users", ["stripe_customer_id"], + ) + + op.create_table( + "stripe_events", + sa.Column("id", sa.BigInteger(), autoincrement=True, primary_key=True), + sa.Column("event_id", sa.String(length=128), nullable=False), + sa.Column("event_type", sa.String(length=64), nullable=False), + sa.Column("received_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("processed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("error", sa.Text(), nullable=True), + sa.Column("payload", sa.Text(), nullable=False), + sa.UniqueConstraint("event_id", name="uq_stripe_events_event_id"), + ) + op.create_index( + "ix_stripe_events_type_received", + "stripe_events", + ["event_type", "received_at"], + ) + + +def downgrade() -> None: + op.drop_index("ix_stripe_events_type_received", table_name="stripe_events") + op.drop_table("stripe_events") + op.drop_constraint("uq_users_stripe_customer", "users", type_="unique") + op.drop_column("users", "stripe_subscription_id") + op.drop_column("users", "stripe_customer_id") diff --git a/app/config.py b/app/config.py index 4ff67d6..0aabb1d 100644 --- a/app/config.py +++ b/app/config.py @@ -99,6 +99,16 @@ class Settings(BaseSettings): POLAR_WEBHOOK_SECRET: str = "" POLAR_API_KEY: str = "" + # Stripe (merchant-on-record for read.markets after Polar/Paddle + # both declined the financial-media category). Test-mode keys are + # `sk_test_*` / `whsec_*`; live-mode keys are `sk_live_*` — swap at + # GA cutover. Empty values make the corresponding endpoints 503 so + # a misconfig is loud rather than silently accepting unsigned events. + STRIPE_API_KEY: str = "" + STRIPE_WEBHOOK_SECRET: str = "" + STRIPE_PRICE_MONTHLY: str = "" # price_xxx for £7/month subscription + STRIPE_PRICE_ANNUAL: str = "" # price_xxx for £70/year subscription + # Config file locations (overridable for tests) BASELINE_TOML: Path = Field(default_factory=lambda: CONFIG_DIR / "default.toml") PORTFOLIO_TOML: Path = Field(default_factory=lambda: CONFIG_DIR / "portfolio.toml") diff --git a/app/main.py b/app/main.py index d9c491e..9499a34 100644 --- a/app/main.py +++ b/app/main.py @@ -23,6 +23,7 @@ from app.routers import email as email_router from app.routers import pages as pages_router from app.routers import polar_webhook as polar_webhook_router from app.routers import public as public_router +from app.routers import stripe_billing as stripe_router from app.routers import sync as sync_router from app.routers import universe as universe_router from app.services.feeds_bootstrap import bootstrap_feeds @@ -93,6 +94,9 @@ app.include_router(sync_router.router, tags=["portfolio-sync"]) # `/api/polar/webhook` is set on the route itself so the URL Polar # stores remains stable even if api_router's prefix ever moves. app.include_router(polar_webhook_router.router, tags=["polar-webhook"]) +# Stripe billing (checkout, portal, webhook). Auth lives per-route: +# checkout + portal require_auth, webhook is signature-gated. +app.include_router(stripe_router.router, tags=["stripe-billing"]) # Public router (no auth dep) before pages_router so the marketing/legal # paths can never collide with future authenticated routes. app.include_router(public_router.router) diff --git a/app/models.py b/app/models.py index 643b8d8..2140f6b 100644 --- a/app/models.py +++ b/app/models.py @@ -194,11 +194,18 @@ class User(Base): # we cancel against from /settings. polar_customer_id: Mapped[str | None] = mapped_column(String(64), nullable=True) polar_subscription_id: Mapped[str | None] = mapped_column(String(64), nullable=True) + # Stripe (merchant-on-record for read.markets). Populated on the + # first checkout.session.completed event via client_reference_id; + # used thereafter to match incoming subscription/invoice events + # back to this row. + stripe_customer_id: Mapped[str | None] = mapped_column(String(64), nullable=True) + stripe_subscription_id: Mapped[str | None] = mapped_column(String(64), nullable=True) __table_args__ = ( UniqueConstraint("email", name="uq_users_email"), UniqueConstraint("referral_code", name="uq_users_referral_code"), UniqueConstraint("polar_customer_id", name="uq_users_polar_customer"), + UniqueConstraint("stripe_customer_id", name="uq_users_stripe_customer"), ) @@ -385,3 +392,29 @@ class PolarEvent(Base): UniqueConstraint("event_id", name="uq_polar_events_event_id"), Index("ix_polar_events_type_received", "event_type", "received_at"), ) + + +class StripeEvent(Base): + """Audit + idempotency table for inbound Stripe webhook deliveries. + + Same shape and purpose as PolarEvent — Stripe's `event.id` plays the + same role as Standard Webhooks' `webhook-id`. We keep the tables + distinct (rather than a single 'webhook_events' table) so an + operator can look at the audit trail per processor without filtering + on a `source` column.""" + __tablename__ = "stripe_events" + + id: Mapped[int] = mapped_column(_PK, primary_key=True, autoincrement=True) + event_id: Mapped[str] = mapped_column(String(128), nullable=False) + event_type: Mapped[str] = mapped_column(String(64), nullable=False) + received_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=utcnow, nullable=False, + ) + processed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + error: Mapped[str | None] = mapped_column(Text) + payload: Mapped[str] = mapped_column(Text, nullable=False) + + __table_args__ = ( + UniqueConstraint("event_id", name="uq_stripe_events_event_id"), + Index("ix_stripe_events_type_received", "event_type", "received_at"), + ) diff --git a/app/routers/public.py b/app/routers/public.py index a040ccd..33bd245 100644 --- a/app/routers/public.py +++ b/app/routers/public.py @@ -15,6 +15,7 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import HTMLResponse from app.auth import CurrentUser, maybe_current_user +from app.services.access import is_paid_active from app.templates_env import templates @@ -33,7 +34,9 @@ async def pricing_page( request: Request, cu: CurrentUser | None = Depends(maybe_current_user), ): - return templates.TemplateResponse(request, "pricing.html", _ctx(request, cu)) + ctx = _ctx(request, cu) + ctx["paid"] = is_paid_active(cu) + return templates.TemplateResponse(request, "pricing.html", ctx) @router.get("/about", response_class=HTMLResponse) diff --git a/app/routers/stripe_billing.py b/app/routers/stripe_billing.py new file mode 100644 index 0000000..e896f74 --- /dev/null +++ b/app/routers/stripe_billing.py @@ -0,0 +1,383 @@ +"""Stripe billing endpoints — checkout, webhook, customer portal. + +Stripe is the merchant-on-record for read.markets (after Polar/Paddle +both declined the financial-media category). We delegate payment UI to +Stripe-hosted Checkout and Customer Portal; the only state we keep on +our side is `users.stripe_customer_id` / `users.stripe_subscription_id` +so we can match incoming webhooks back to the right user. + +The Stripe SDK is sync; we wrap calls in `asyncio.to_thread` so the +event loop doesn't block while Stripe answers. For our request volume +this is more reliable than the SDK's nascent async surface. + +Routes +- POST /api/stripe/checkout — logged-in user upgrades. Body: {cadence}. +- POST /api/stripe/webhook — Stripe → us, signature-verified. +- POST /api/stripe/portal — logged-in user opens the customer portal. +""" +from __future__ import annotations + +import asyncio +import json +from typing import Any, Literal + +import stripe +from fastapi import APIRouter, Body, Depends, HTTPException, Request +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app import branding +from app.auth import CurrentUser, require_auth +from app.config import get_settings +from app.db import get_session, utcnow +from app.logging import get_logger +from app.models import StripeEvent, User + + +log = get_logger("stripe_billing") +router = APIRouter() + + +# Cap stored payload at 16 KiB so a hostile (or buggy) sender can't +# blow up a single row. Same pattern as polar_webhook. +_PAYLOAD_STORE_MAX = 16 * 1024 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _require_configured() -> None: + s = get_settings() + if not s.STRIPE_API_KEY: + raise HTTPException(status_code=503, detail="stripe not configured") + + +def _price_for(cadence: str) -> str: + s = get_settings() + if cadence == "monthly": + if not s.STRIPE_PRICE_MONTHLY: + raise HTTPException(status_code=503, detail="STRIPE_PRICE_MONTHLY not set") + return s.STRIPE_PRICE_MONTHLY + if cadence == "annual": + if not s.STRIPE_PRICE_ANNUAL: + raise HTTPException(status_code=503, detail="STRIPE_PRICE_ANNUAL not set") + return s.STRIPE_PRICE_ANNUAL + raise HTTPException(status_code=400, detail="cadence must be 'monthly' or 'annual'") + + +def _stripe_client() -> stripe.StripeClient: + """Per-call client so we read the secret at request time (lets us + rotate the key by editing .env + reloading without rebuilding any + cached client).""" + return stripe.StripeClient(get_settings().STRIPE_API_KEY) + + +# --------------------------------------------------------------------------- +# POST /api/stripe/checkout +# --------------------------------------------------------------------------- + + +class CheckoutRequest(BaseModel): + cadence: Literal["monthly", "annual"] + + +class CheckoutResponse(BaseModel): + url: str + + +@router.post("/api/stripe/checkout", response_model=CheckoutResponse) +async def create_checkout( + body: CheckoutRequest, + session: AsyncSession = Depends(get_session), + cu: CurrentUser = Depends(require_auth), +) -> CheckoutResponse: + _require_configured() + if cu.user is None: + # Admin bearer token has no User row — they shouldn't be buying. + raise HTTPException(status_code=400, detail="admin token cannot purchase") + + user = await session.get(User, cu.user.id) + if user is None: + raise HTTPException(status_code=404, detail="user_not_found") + + price_id = _price_for(body.cadence) + client = _stripe_client() + + # Pass `customer` if we already minted one for this user (avoids + # creating duplicate Stripe customers on repeat checkouts); + # otherwise let Stripe create it via `customer_email`. + create_kwargs: dict[str, Any] = { + "mode": "subscription", + "line_items": [{"price": price_id, "quantity": 1}], + "client_reference_id": str(user.id), + "success_url": f"{branding.SITE_URL}/settings?upgraded=1", + "cancel_url": f"{branding.SITE_URL}/pricing", + # Lets us paste in a referral coupon at checkout once the + # referral redemption flow ships. + "allow_promotion_codes": True, + } + if user.stripe_customer_id: + create_kwargs["customer"] = user.stripe_customer_id + else: + create_kwargs["customer_email"] = user.email + + try: + sess = await asyncio.to_thread( + client.checkout.sessions.create, params=create_kwargs, + ) + except stripe.StripeError as e: + log.error("stripe.checkout.create_failed", user_id=user.id, error=str(e)) + raise HTTPException(status_code=502, detail=f"stripe error: {e.user_message or str(e)}") + + if not sess.url: + raise HTTPException(status_code=502, detail="stripe returned no checkout URL") + log.info("stripe.checkout.created", user_id=user.id, session_id=sess.id, + cadence=body.cadence) + return CheckoutResponse(url=sess.url) + + +# --------------------------------------------------------------------------- +# POST /api/stripe/portal +# --------------------------------------------------------------------------- + + +class PortalResponse(BaseModel): + url: str + + +@router.post("/api/stripe/portal", response_model=PortalResponse) +async def create_portal_session( + session: AsyncSession = Depends(get_session), + cu: CurrentUser = Depends(require_auth), +) -> PortalResponse: + _require_configured() + if cu.user is None: + raise HTTPException(status_code=400, detail="admin token has no portal") + + user = await session.get(User, cu.user.id) + if user is None or not user.stripe_customer_id: + raise HTTPException( + status_code=404, + detail="no_stripe_customer — start a subscription first", + ) + + client = _stripe_client() + try: + portal = await asyncio.to_thread( + client.billing_portal.sessions.create, + params={ + "customer": user.stripe_customer_id, + "return_url": f"{branding.SITE_URL}/settings", + }, + ) + except stripe.StripeError as e: + log.error("stripe.portal.create_failed", user_id=user.id, error=str(e)) + raise HTTPException(status_code=502, detail=f"stripe error: {e.user_message or str(e)}") + + return PortalResponse(url=portal.url) + + +# --------------------------------------------------------------------------- +# POST /api/stripe/webhook +# --------------------------------------------------------------------------- + + +async def _find_user( + session: AsyncSession, + *, + client_ref: str | None = None, + customer_id: str | None = None, +) -> User | None: + """Find the User row this event belongs to. + + `client_reference_id` is the most reliable join key — we set it + to `str(user.id)` at checkout creation. After the first event we + also know `stripe_customer_id`, which subsequent subscription / + invoice events arrive carrying.""" + if client_ref: + try: + uid = int(client_ref) + except ValueError: + uid = None + if uid is not None: + u = await session.get(User, uid) + if u is not None: + return u + if customer_id: + row = (await session.execute( + select(User).where(User.stripe_customer_id == customer_id) + )).scalar_one_or_none() + return row + return None + + +async def _grant_paid( + user: User, + *, + customer_id: str | None, + subscription_id: str | None, +) -> None: + user.tier = "paid" + if customer_id and user.stripe_customer_id != customer_id: + user.stripe_customer_id = customer_id + if subscription_id and user.stripe_subscription_id != subscription_id: + user.stripe_subscription_id = subscription_id + + +async def _revoke_paid(user: User) -> None: + user.tier = "free" + user.stripe_subscription_id = None + # Keep stripe_customer_id so a re-subscription matches this row. + + +async def _handle_checkout_completed( + session: AsyncSession, event_type: str, obj: dict[str, Any], +) -> None: + user = await _find_user( + session, + client_ref=obj.get("client_reference_id"), + customer_id=obj.get("customer"), + ) + if user is None: + log.warning("stripe.user_not_found", event=event_type) + return + await _grant_paid( + user, + customer_id=obj.get("customer"), + subscription_id=obj.get("subscription"), + ) + + +async def _handle_subscription_event( + session: AsyncSession, event_type: str, obj: dict[str, Any], +) -> None: + """customer.subscription.created / .updated — flip to paid if the + Stripe-side status says the subscription is active/trialing; drop + to free if it's an end-state.""" + user = await _find_user(session, customer_id=obj.get("customer")) + if user is None: + log.warning("stripe.user_not_found", event=event_type, + customer_id=obj.get("customer")) + return + status = obj.get("status") + # Stripe statuses: trialing, active, past_due, canceled, unpaid, + # incomplete, incomplete_expired, paused. Treat trialing/active as + # paid; everything else holds tier the same until we get an explicit + # subscription.deleted (which fires after the final state lands). + if status in ("trialing", "active"): + await _grant_paid( + user, + customer_id=obj.get("customer"), + subscription_id=obj.get("id"), + ) + + +async def _handle_subscription_deleted( + session: AsyncSession, event_type: str, obj: dict[str, Any], +) -> None: + user = await _find_user(session, customer_id=obj.get("customer")) + if user is None: + log.warning("stripe.user_not_found", event=event_type, + customer_id=obj.get("customer")) + return + await _revoke_paid(user) + + +async def _handle_audit_only( + session: AsyncSession, event_type: str, obj: dict[str, Any], +) -> None: + """invoice.paid / invoice.payment_failed / charge.refunded — we + record these in stripe_events for the audit log but the tier doesn't + move until subscription.deleted fires.""" + return None + + +_HANDLERS = { + "checkout.session.completed": _handle_checkout_completed, + "customer.subscription.created": _handle_subscription_event, + "customer.subscription.updated": _handle_subscription_event, + "customer.subscription.deleted": _handle_subscription_deleted, + "invoice.paid": _handle_audit_only, + "invoice.payment_failed": _handle_audit_only, + "charge.refunded": _handle_audit_only, +} + + +@router.post("/api/stripe/webhook") +async def stripe_webhook( + request: Request, + session: AsyncSession = Depends(get_session), +) -> dict[str, str]: + s = get_settings() + if not s.STRIPE_WEBHOOK_SECRET: + raise HTTPException(status_code=503, detail="stripe webhook not configured") + + sig = request.headers.get("stripe-signature", "") + if not sig: + raise HTTPException(status_code=400, detail="missing stripe-signature header") + + body = await request.body() + # construct_event handles HMAC verification + timestamp tolerance. + # We then re-parse the body as plain JSON for handler dispatch — + # the Stripe SDK's StripeObject doesn't expose dict.get(), and + # round-tripping through json gives us simple, typed-dict access. + try: + stripe.Webhook.construct_event( + payload=body, sig_header=sig, secret=s.STRIPE_WEBHOOK_SECRET, + ) + except stripe.SignatureVerificationError: + raise HTTPException(status_code=401, detail="bad signature") + except ValueError: + raise HTTPException(status_code=400, detail="invalid payload") + + envelope = json.loads(body) + event_id = envelope.get("id") or "" + event_type = envelope.get("type") or "unknown" + obj = (envelope.get("data") or {}).get("object") or {} + + if not event_id: + raise HTTPException(status_code=400, detail="event missing id") + + # Idempotency: insert audit row first. UNIQUE on event_id makes a + # replay of the same Stripe event id a no-op (Stripe retries on + # non-2xx, so always 2xx after first successful processing). + audit = StripeEvent( + event_id=event_id, + event_type=event_type, + received_at=utcnow(), + payload=body.decode("utf-8", errors="replace")[:_PAYLOAD_STORE_MAX], + ) + session.add(audit) + try: + await session.flush() + except IntegrityError: + await session.rollback() + log.info("stripe.duplicate_delivery", event_id=event_id, type=event_type) + return {"status": "duplicate"} + + handler = _HANDLERS.get(event_type) + if handler is None: + audit.processed_at = utcnow() + await session.commit() + log.info("stripe.event_unhandled", type=event_type, id=event_id) + return {"status": "ignored"} + + try: + await handler(session, event_type, obj) + except Exception as e: + audit.error = str(e)[:1024] + await session.commit() + log.exception("stripe.handler_error", type=event_type, id=event_id) + # Ack 200 — we don't want Stripe retrying a handler that broke + # the same way on every delivery. An operator triages from the + # `error` column. + return {"status": "handler_error"} + + audit.processed_at = utcnow() + await session.commit() + log.info("stripe.processed", type=event_type, id=event_id) + return {"status": "ok"} diff --git a/app/templates/pricing.html b/app/templates/pricing.html index a26e106..6bb4411 100644 --- a/app/templates/pricing.html +++ b/app/templates/pricing.html @@ -50,9 +50,8 @@
Full-day news feed, hourly strategic log, follow-up chat, and AI portfolio analysis.
£7 / month
- Or £70 / year — two months free. Prices - in GBP, VAT where applicable. Checkout opens with the payments - rollout. + Or £70 / year — two months free. + Prices in GBP, VAT where applicable.
Everything in Free, plus
@@ -72,16 +71,56 @@ or a personal recommendation under FSMA / FCA COBS.

- {% if cu and (cu.user or cu.is_admin) %} - Manage account + {% if paid %} + Manage subscription + {% elif cu and cu.user %} + + {% else %} - Sign up — paid unlocks soon + Sign in to subscribe {% endif %}
+ +

Free vs Paid at a glance

diff --git a/pyproject.toml b/pyproject.toml index cfa65d6..5a0b50f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "email-validator>=2.2", "aiosmtplib>=3.0", "redis[hiredis]>=5.2", + "stripe>=11.0", ] [project.optional-dependencies] diff --git a/tests/test_stripe_billing.py b/tests/test_stripe_billing.py new file mode 100644 index 0000000..61bf410 --- /dev/null +++ b/tests/test_stripe_billing.py @@ -0,0 +1,322 @@ +"""Stripe billing endpoints: signature verification, idempotency, tier +flips, and checkout creation. + +Same integration-style scaffold as test_polar_webhook.py — real router +over in-memory aiosqlite. Stripe SDK calls (sessions.create, portal +sessions.create) are mocked so the suite never makes a real HTTP call. +""" +from __future__ import annotations + +import asyncio +import hashlib +import hmac +import json +import time +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + + +_API_KEY = "sk_test_dummy_for_unit_tests" +_WEBHOOK_SECRET = "whsec_dummy_test_secret_for_unit_tests" +_PRICE_MONTHLY = "price_test_monthly_xxxxxxxxxxxxxxxxxxxx" +_PRICE_ANNUAL = "price_test_annual_xxxxxxxxxxxxxxxxxxxxx" + + +def _stripe_sig(body: bytes, secret: str, ts: int | None = None) -> str: + """Produce a Stripe-Signature header matching the bytes signed. + Format: `t=,v1=` over `.`.""" + ts = ts if ts is not None else int(time.time()) + signed = f"{ts}.{body.decode('utf-8')}" + mac = hmac.new(secret.encode("utf-8"), signed.encode("utf-8"), + hashlib.sha256).hexdigest() + return f"t={ts},v1={mac}" + + +def _build_app(tmp_path): + from fastapi import FastAPI + from fastapi.testclient import TestClient + from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + + from app import db as db_mod + from app.auth import sign_session + from app.config import get_settings + from app.db import Base + from app.models import User + from app.routers import stripe_billing as stripe_router + + s = get_settings() + s.STRIPE_API_KEY = _API_KEY # type: ignore[misc] + s.STRIPE_WEBHOOK_SECRET = _WEBHOOK_SECRET # type: ignore[misc] + s.STRIPE_PRICE_MONTHLY = _PRICE_MONTHLY # type: ignore[misc] + s.STRIPE_PRICE_ANNUAL = _PRICE_ANNUAL # type: ignore[misc] + + engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/stripe.db") + factory = async_sessionmaker(engine, expire_on_commit=False) + db_mod._engine = engine + db_mod._session_factory = factory + + async def _seed(): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + async with factory() as session: + session.add(User(id=1, email="buyer@x", tier="free")) + await session.commit() + + asyncio.run(_seed()) + + app = FastAPI() + app.include_router(stripe_router.router) + return TestClient(app), factory, sign_session(1) + + +def _post_webhook(client, *, body: dict, secret: str = _WEBHOOK_SECRET, + sig: str | None = None): + # Stripe's SDK requires a top-level `object: "event"` field to know + # this is a v1 webhook envelope — tests that omit it fail in + # construct_event before the signature check matters. We inject the + # default here so individual tests can stay terse. + body.setdefault("object", "event") + raw = json.dumps(body).encode("utf-8") + sig = sig if sig is not None else _stripe_sig(raw, secret) + return client.post( + "/api/stripe/webhook", + content=raw, + headers={"stripe-signature": sig, "content-type": "application/json"}, + ) + + +# --- signature gate -------------------------------------------------------- + + +def test_webhook_rejects_bad_signature(tmp_path): + client, _, _ = _build_app(tmp_path) + raw = json.dumps({"id": "evt_x", "type": "invoice.paid", + "data": {"object": {}}}).encode("utf-8") + r = client.post( + "/api/stripe/webhook", + content=raw, + headers={ + "stripe-signature": "t=0,v1=deadbeef", + "content-type": "application/json", + }, + ) + assert r.status_code == 401, r.text + + +def test_webhook_rejects_missing_signature(tmp_path): + client, _, _ = _build_app(tmp_path) + r = client.post( + "/api/stripe/webhook", + content=b"{}", + headers={"content-type": "application/json"}, + ) + assert r.status_code == 400, r.text + + +# --- happy paths ----------------------------------------------------------- + + +def test_checkout_session_completed_flips_tier_to_paid(tmp_path): + client, factory, _ = _build_app(tmp_path) + body = { + "id": "evt_checkout_1", + "type": "checkout.session.completed", + "data": { + "object": { + "client_reference_id": "1", + "customer": "cus_abc", + "subscription": "sub_xyz", + } + }, + } + r = _post_webhook(client, body=body) + assert r.status_code == 200, r.text + assert r.json()["status"] == "ok" + + async def _check(): + from sqlalchemy import select + from app.models import User + async with factory() as session: + u = (await session.execute( + select(User).where(User.id == 1) + )).scalar_one() + return u.tier, u.stripe_customer_id, u.stripe_subscription_id + + tier, cid, sid = asyncio.run(_check()) + assert tier == "paid" + assert cid == "cus_abc" + assert sid == "sub_xyz" + + +def test_subscription_deleted_drops_tier_to_free(tmp_path): + client, factory, _ = _build_app(tmp_path) + # First, activate. + _post_webhook(client, body={ + "id": "evt_act", + "type": "checkout.session.completed", + "data": {"object": { + "client_reference_id": "1", + "customer": "cus_abc", + "subscription": "sub_xyz", + }}, + }) + # Then, delete the subscription. + r = _post_webhook(client, body={ + "id": "evt_del", + "type": "customer.subscription.deleted", + "data": {"object": { + "id": "sub_xyz", + "customer": "cus_abc", + "status": "canceled", + }}, + }) + assert r.status_code == 200, r.text + + async def _check(): + from sqlalchemy import select + from app.models import User + async with factory() as session: + u = (await session.execute( + select(User).where(User.id == 1) + )).scalar_one() + return u.tier, u.stripe_customer_id, u.stripe_subscription_id + + tier, cid, sid = asyncio.run(_check()) + assert tier == "free" + # Customer linkage preserved so a future resub matches this row. + assert cid == "cus_abc" + assert sid is None + + +def test_subscription_active_grants_paid(tmp_path): + """customer.subscription.updated with status=active should also + grant paid — covers the case where checkout.session.completed + arrives after subscription.created and we want either to work.""" + client, factory, _ = _build_app(tmp_path) + # Seed the linkage first via checkout (so customer_id is known). + _post_webhook(client, body={ + "id": "evt_ck", + "type": "checkout.session.completed", + "data": {"object": { + "client_reference_id": "1", + "customer": "cus_abc", + "subscription": "sub_xyz", + }}, + }) + # Drop to free manually so we can prove the updated event re-grants. + async def _reset(): + from sqlalchemy import update + from app.models import User + async with factory() as session: + await session.execute( + update(User).where(User.id == 1).values(tier="free") + ) + await session.commit() + asyncio.run(_reset()) + + r = _post_webhook(client, body={ + "id": "evt_upd", + "type": "customer.subscription.updated", + "data": {"object": { + "id": "sub_xyz", + "customer": "cus_abc", + "status": "active", + }}, + }) + assert r.status_code == 200 + + async def _check_tier(): + from sqlalchemy import select + from app.models import User + async with factory() as session: + return (await session.execute( + select(User.tier).where(User.id == 1) + )).scalar_one() + assert asyncio.run(_check_tier()) == "paid" + + +# --- idempotency + unknown ------------------------------------------------ + + +def test_replayed_event_id_is_a_noop(tmp_path): + client, factory, _ = _build_app(tmp_path) + body = { + "id": "evt_dup", + "type": "checkout.session.completed", + "data": {"object": { + "client_reference_id": "1", + "customer": "cus_abc", + "subscription": "sub_xyz", + }}, + } + r1 = _post_webhook(client, body=body) + r2 = _post_webhook(client, body=body) + assert r1.json()["status"] == "ok" + assert r2.json()["status"] == "duplicate" + + async def _count_rows(): + from sqlalchemy import select, func + from app.models import StripeEvent + async with factory() as session: + n = (await session.execute( + select(func.count(StripeEvent.id)) + .where(StripeEvent.event_id == "evt_dup") + )).scalar_one() + return n + assert asyncio.run(_count_rows()) == 1 + + +def test_unknown_event_is_acked(tmp_path): + client, _, _ = _build_app(tmp_path) + r = _post_webhook(client, body={ + "id": "evt_unknown", + "type": "product.something.new", + "data": {"object": {}}, + }) + assert r.status_code == 200 + assert r.json()["status"] == "ignored" + + +# --- /api/stripe/checkout (with Stripe SDK mocked) ------------------------ + + +def test_checkout_endpoint_creates_session_and_returns_url(tmp_path): + client, _, session_cookie = _build_app(tmp_path) + # Mock the Stripe SDK call so no real HTTP goes out. + fake_session = SimpleNamespace( + id="cs_test_123", url="https://checkout.stripe.com/test", + ) + + class _FakeSessions: + @staticmethod + def create(params): # noqa: ANN001 + assert params["mode"] == "subscription" + assert params["line_items"][0]["price"] == _PRICE_MONTHLY + assert params["client_reference_id"] == "1" + assert params["customer_email"] == "buyer@x" + return fake_session + + class _FakeCheckout: + sessions = _FakeSessions() + + class _FakeClient: + checkout = _FakeCheckout() + + with patch("app.routers.stripe_billing._stripe_client", + return_value=_FakeClient()): + r = client.post( + "/api/stripe/checkout", + json={"cadence": "monthly"}, + cookies={"cassandra_session": session_cookie}, + ) + assert r.status_code == 200, r.text + assert r.json()["url"] == "https://checkout.stripe.com/test" + + +def test_checkout_endpoint_requires_login(tmp_path): + client, _, _ = _build_app(tmp_path) + r = client.post("/api/stripe/checkout", json={"cadence": "monthly"}) + # No session cookie → require_auth bounces with 401. + assert r.status_code == 401, r.text