stripe: wire checkout, customer portal, and webhook for read.markets

Stripe is the merchant-on-record for read.markets after Polar/Paddle
both declined the financial-media category. This commit lands the
full subscription flow: an "Upgrade" button on /pricing now opens a
real Stripe-hosted Checkout, completes the subscription, and the
webhook flips user.tier to "paid" idempotently.

Endpoints
- POST /api/stripe/checkout (require_auth) — creates a hosted
  Checkout Session in subscription mode, passes user.id as
  client_reference_id + email as customer_email, returns the URL
  for the page-side JS to redirect to. Reuses an existing
  stripe_customer_id to avoid duplicate Stripe customers on repeat
  checkouts. allow_promotion_codes=True so the referral-credit
  redemption can attach a coupon at checkout once that flow ships.
- POST /api/stripe/portal (require_auth) — mints a Stripe Customer
  Portal session. Used by /settings; returns 404 until the user has
  a stripe_customer_id (i.e. completed at least one checkout).
- POST /api/stripe/webhook — signature-verified via
  stripe.Webhook.construct_event. Idempotent via UNIQUE on
  stripe_events.event_id. Event dispatch:
    checkout.session.completed       → grant paid, store IDs
    customer.subscription.created    → grant paid (active/trialing)
    customer.subscription.updated    → grant paid (active/trialing)
    customer.subscription.deleted    → drop to free, clear sub id
    invoice.paid / failed            → audit only
    charge.refunded                  → audit only
  Stripe-SDK objects don't expose dict.get(); we use the SDK for
  signature verification then re-parse the JSON body for handler
  dispatch — cleaner than reaching into StripeObject internals.

Schema (migration 0019)
- users.stripe_customer_id, users.stripe_subscription_id (nullable
  String(64), UNIQUE on customer_id).
- stripe_events table mirroring polar_events: event_id (unique),
  event_type, received_at, processed_at, error, raw payload
  (truncated to 16 KiB).

Settings (.env)
- STRIPE_API_KEY            (rk_test_… for dev, rk_live_… for GA)
- STRIPE_WEBHOOK_SECRET     (whsec_… from the dashboard endpoint)
- STRIPE_PRICE_MONTHLY      (price_xxx for £7/month)
- STRIPE_PRICE_ANNUAL       (price_xxx for £70/year)

Pricing page
- Free tier CTA unchanged.
- Paid CTA branches three ways: paid → "Manage subscription" to
  /settings; logged-in free → two buttons (£7/mo, £70/yr) that POST
  to /api/stripe/checkout and redirect; anonymous → /login?next=/pricing.
- Inline JS intercepts the button click, calls the checkout
  endpoint, redirects on success, surfaces errors via alert(). No
  Stripe.js dep — we use the hosted-checkout URL directly.

Polar handler stays in place for berengar.io / flyroom.net which
still ship through Polar. polar_* and stripe_* columns coexist
independently on the User row.

Tests
- 9 in tests/test_stripe_billing.py covering: bad signature → 401,
  missing signature → 400, checkout.session.completed flips tier +
  stores IDs, subscription.updated active grants paid,
  subscription.deleted drops to free with customer id preserved,
  replayed event id is no-op (one row in stripe_events),
  unknown event acked 200, checkout endpoint mocks the SDK and
  returns the hosted URL, checkout requires login.
- Full suite: 221 passed, 5 skipped.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Giorgio Gilestro 2026-05-26 18:45:13 +02:00
parent 6c13f855e9
commit 410afe0078
9 changed files with 858 additions and 7 deletions

View file

@ -0,0 +1,56 @@
"""stripe integration: users.stripe_customer_id / stripe_subscription_id,
stripe_events table.
Revision ID: 0019
Revises: 0018
Create Date: 2026-05-26
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "0019"
down_revision: Union[str, None] = "0018"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"users",
sa.Column("stripe_customer_id", sa.String(length=64), nullable=True),
)
op.add_column(
"users",
sa.Column("stripe_subscription_id", sa.String(length=64), nullable=True),
)
op.create_unique_constraint(
"uq_users_stripe_customer", "users", ["stripe_customer_id"],
)
op.create_table(
"stripe_events",
sa.Column("id", sa.BigInteger(), autoincrement=True, primary_key=True),
sa.Column("event_id", sa.String(length=128), nullable=False),
sa.Column("event_type", sa.String(length=64), nullable=False),
sa.Column("received_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("processed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("payload", sa.Text(), nullable=False),
sa.UniqueConstraint("event_id", name="uq_stripe_events_event_id"),
)
op.create_index(
"ix_stripe_events_type_received",
"stripe_events",
["event_type", "received_at"],
)
def downgrade() -> None:
op.drop_index("ix_stripe_events_type_received", table_name="stripe_events")
op.drop_table("stripe_events")
op.drop_constraint("uq_users_stripe_customer", "users", type_="unique")
op.drop_column("users", "stripe_subscription_id")
op.drop_column("users", "stripe_customer_id")

View file

@ -99,6 +99,16 @@ class Settings(BaseSettings):
POLAR_WEBHOOK_SECRET: str = ""
POLAR_API_KEY: str = ""
# Stripe (merchant-on-record for read.markets after Polar/Paddle
# both declined the financial-media category). Test-mode keys are
# `sk_test_*` / `whsec_*`; live-mode keys are `sk_live_*` — swap at
# GA cutover. Empty values make the corresponding endpoints 503 so
# a misconfig is loud rather than silently accepting unsigned events.
STRIPE_API_KEY: str = ""
STRIPE_WEBHOOK_SECRET: str = ""
STRIPE_PRICE_MONTHLY: str = "" # price_xxx for £7/month subscription
STRIPE_PRICE_ANNUAL: str = "" # price_xxx for £70/year subscription
# Config file locations (overridable for tests)
BASELINE_TOML: Path = Field(default_factory=lambda: CONFIG_DIR / "default.toml")
PORTFOLIO_TOML: Path = Field(default_factory=lambda: CONFIG_DIR / "portfolio.toml")

View file

@ -23,6 +23,7 @@ from app.routers import email as email_router
from app.routers import pages as pages_router
from app.routers import polar_webhook as polar_webhook_router
from app.routers import public as public_router
from app.routers import stripe_billing as stripe_router
from app.routers import sync as sync_router
from app.routers import universe as universe_router
from app.services.feeds_bootstrap import bootstrap_feeds
@ -93,6 +94,9 @@ app.include_router(sync_router.router, tags=["portfolio-sync"])
# `/api/polar/webhook` is set on the route itself so the URL Polar
# stores remains stable even if api_router's prefix ever moves.
app.include_router(polar_webhook_router.router, tags=["polar-webhook"])
# Stripe billing (checkout, portal, webhook). Auth lives per-route:
# checkout + portal require_auth, webhook is signature-gated.
app.include_router(stripe_router.router, tags=["stripe-billing"])
# Public router (no auth dep) before pages_router so the marketing/legal
# paths can never collide with future authenticated routes.
app.include_router(public_router.router)

View file

@ -194,11 +194,18 @@ class User(Base):
# we cancel against from /settings.
polar_customer_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
polar_subscription_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
# Stripe (merchant-on-record for read.markets). Populated on the
# first checkout.session.completed event via client_reference_id;
# used thereafter to match incoming subscription/invoice events
# back to this row.
stripe_customer_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
stripe_subscription_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
__table_args__ = (
UniqueConstraint("email", name="uq_users_email"),
UniqueConstraint("referral_code", name="uq_users_referral_code"),
UniqueConstraint("polar_customer_id", name="uq_users_polar_customer"),
UniqueConstraint("stripe_customer_id", name="uq_users_stripe_customer"),
)
@ -385,3 +392,29 @@ class PolarEvent(Base):
UniqueConstraint("event_id", name="uq_polar_events_event_id"),
Index("ix_polar_events_type_received", "event_type", "received_at"),
)
class StripeEvent(Base):
"""Audit + idempotency table for inbound Stripe webhook deliveries.
Same shape and purpose as PolarEvent Stripe's `event.id` plays the
same role as Standard Webhooks' `webhook-id`. We keep the tables
distinct (rather than a single 'webhook_events' table) so an
operator can look at the audit trail per processor without filtering
on a `source` column."""
__tablename__ = "stripe_events"
id: Mapped[int] = mapped_column(_PK, primary_key=True, autoincrement=True)
event_id: Mapped[str] = mapped_column(String(128), nullable=False)
event_type: Mapped[str] = mapped_column(String(64), nullable=False)
received_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=utcnow, nullable=False,
)
processed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
error: Mapped[str | None] = mapped_column(Text)
payload: Mapped[str] = mapped_column(Text, nullable=False)
__table_args__ = (
UniqueConstraint("event_id", name="uq_stripe_events_event_id"),
Index("ix_stripe_events_type_received", "event_type", "received_at"),
)

View file

@ -15,6 +15,7 @@ from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse
from app.auth import CurrentUser, maybe_current_user
from app.services.access import is_paid_active
from app.templates_env import templates
@ -33,7 +34,9 @@ async def pricing_page(
request: Request,
cu: CurrentUser | None = Depends(maybe_current_user),
):
return templates.TemplateResponse(request, "pricing.html", _ctx(request, cu))
ctx = _ctx(request, cu)
ctx["paid"] = is_paid_active(cu)
return templates.TemplateResponse(request, "pricing.html", ctx)
@router.get("/about", response_class=HTMLResponse)

View file

@ -0,0 +1,383 @@
"""Stripe billing endpoints — checkout, webhook, customer portal.
Stripe is the merchant-on-record for read.markets (after Polar/Paddle
both declined the financial-media category). We delegate payment UI to
Stripe-hosted Checkout and Customer Portal; the only state we keep on
our side is `users.stripe_customer_id` / `users.stripe_subscription_id`
so we can match incoming webhooks back to the right user.
The Stripe SDK is sync; we wrap calls in `asyncio.to_thread` so the
event loop doesn't block while Stripe answers. For our request volume
this is more reliable than the SDK's nascent async surface.
Routes
- POST /api/stripe/checkout logged-in user upgrades. Body: {cadence}.
- POST /api/stripe/webhook Stripe us, signature-verified.
- POST /api/stripe/portal logged-in user opens the customer portal.
"""
from __future__ import annotations
import asyncio
import json
from typing import Any, Literal
import stripe
from fastapi import APIRouter, Body, Depends, HTTPException, Request
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app import branding
from app.auth import CurrentUser, require_auth
from app.config import get_settings
from app.db import get_session, utcnow
from app.logging import get_logger
from app.models import StripeEvent, User
log = get_logger("stripe_billing")
router = APIRouter()
# Cap stored payload at 16 KiB so a hostile (or buggy) sender can't
# blow up a single row. Same pattern as polar_webhook.
_PAYLOAD_STORE_MAX = 16 * 1024
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _require_configured() -> None:
s = get_settings()
if not s.STRIPE_API_KEY:
raise HTTPException(status_code=503, detail="stripe not configured")
def _price_for(cadence: str) -> str:
s = get_settings()
if cadence == "monthly":
if not s.STRIPE_PRICE_MONTHLY:
raise HTTPException(status_code=503, detail="STRIPE_PRICE_MONTHLY not set")
return s.STRIPE_PRICE_MONTHLY
if cadence == "annual":
if not s.STRIPE_PRICE_ANNUAL:
raise HTTPException(status_code=503, detail="STRIPE_PRICE_ANNUAL not set")
return s.STRIPE_PRICE_ANNUAL
raise HTTPException(status_code=400, detail="cadence must be 'monthly' or 'annual'")
def _stripe_client() -> stripe.StripeClient:
"""Per-call client so we read the secret at request time (lets us
rotate the key by editing .env + reloading without rebuilding any
cached client)."""
return stripe.StripeClient(get_settings().STRIPE_API_KEY)
# ---------------------------------------------------------------------------
# POST /api/stripe/checkout
# ---------------------------------------------------------------------------
class CheckoutRequest(BaseModel):
cadence: Literal["monthly", "annual"]
class CheckoutResponse(BaseModel):
url: str
@router.post("/api/stripe/checkout", response_model=CheckoutResponse)
async def create_checkout(
body: CheckoutRequest,
session: AsyncSession = Depends(get_session),
cu: CurrentUser = Depends(require_auth),
) -> CheckoutResponse:
_require_configured()
if cu.user is None:
# Admin bearer token has no User row — they shouldn't be buying.
raise HTTPException(status_code=400, detail="admin token cannot purchase")
user = await session.get(User, cu.user.id)
if user is None:
raise HTTPException(status_code=404, detail="user_not_found")
price_id = _price_for(body.cadence)
client = _stripe_client()
# Pass `customer` if we already minted one for this user (avoids
# creating duplicate Stripe customers on repeat checkouts);
# otherwise let Stripe create it via `customer_email`.
create_kwargs: dict[str, Any] = {
"mode": "subscription",
"line_items": [{"price": price_id, "quantity": 1}],
"client_reference_id": str(user.id),
"success_url": f"{branding.SITE_URL}/settings?upgraded=1",
"cancel_url": f"{branding.SITE_URL}/pricing",
# Lets us paste in a referral coupon at checkout once the
# referral redemption flow ships.
"allow_promotion_codes": True,
}
if user.stripe_customer_id:
create_kwargs["customer"] = user.stripe_customer_id
else:
create_kwargs["customer_email"] = user.email
try:
sess = await asyncio.to_thread(
client.checkout.sessions.create, params=create_kwargs,
)
except stripe.StripeError as e:
log.error("stripe.checkout.create_failed", user_id=user.id, error=str(e))
raise HTTPException(status_code=502, detail=f"stripe error: {e.user_message or str(e)}")
if not sess.url:
raise HTTPException(status_code=502, detail="stripe returned no checkout URL")
log.info("stripe.checkout.created", user_id=user.id, session_id=sess.id,
cadence=body.cadence)
return CheckoutResponse(url=sess.url)
# ---------------------------------------------------------------------------
# POST /api/stripe/portal
# ---------------------------------------------------------------------------
class PortalResponse(BaseModel):
url: str
@router.post("/api/stripe/portal", response_model=PortalResponse)
async def create_portal_session(
session: AsyncSession = Depends(get_session),
cu: CurrentUser = Depends(require_auth),
) -> PortalResponse:
_require_configured()
if cu.user is None:
raise HTTPException(status_code=400, detail="admin token has no portal")
user = await session.get(User, cu.user.id)
if user is None or not user.stripe_customer_id:
raise HTTPException(
status_code=404,
detail="no_stripe_customer — start a subscription first",
)
client = _stripe_client()
try:
portal = await asyncio.to_thread(
client.billing_portal.sessions.create,
params={
"customer": user.stripe_customer_id,
"return_url": f"{branding.SITE_URL}/settings",
},
)
except stripe.StripeError as e:
log.error("stripe.portal.create_failed", user_id=user.id, error=str(e))
raise HTTPException(status_code=502, detail=f"stripe error: {e.user_message or str(e)}")
return PortalResponse(url=portal.url)
# ---------------------------------------------------------------------------
# POST /api/stripe/webhook
# ---------------------------------------------------------------------------
async def _find_user(
session: AsyncSession,
*,
client_ref: str | None = None,
customer_id: str | None = None,
) -> User | None:
"""Find the User row this event belongs to.
`client_reference_id` is the most reliable join key we set it
to `str(user.id)` at checkout creation. After the first event we
also know `stripe_customer_id`, which subsequent subscription /
invoice events arrive carrying."""
if client_ref:
try:
uid = int(client_ref)
except ValueError:
uid = None
if uid is not None:
u = await session.get(User, uid)
if u is not None:
return u
if customer_id:
row = (await session.execute(
select(User).where(User.stripe_customer_id == customer_id)
)).scalar_one_or_none()
return row
return None
async def _grant_paid(
user: User,
*,
customer_id: str | None,
subscription_id: str | None,
) -> None:
user.tier = "paid"
if customer_id and user.stripe_customer_id != customer_id:
user.stripe_customer_id = customer_id
if subscription_id and user.stripe_subscription_id != subscription_id:
user.stripe_subscription_id = subscription_id
async def _revoke_paid(user: User) -> None:
user.tier = "free"
user.stripe_subscription_id = None
# Keep stripe_customer_id so a re-subscription matches this row.
async def _handle_checkout_completed(
session: AsyncSession, event_type: str, obj: dict[str, Any],
) -> None:
user = await _find_user(
session,
client_ref=obj.get("client_reference_id"),
customer_id=obj.get("customer"),
)
if user is None:
log.warning("stripe.user_not_found", event=event_type)
return
await _grant_paid(
user,
customer_id=obj.get("customer"),
subscription_id=obj.get("subscription"),
)
async def _handle_subscription_event(
session: AsyncSession, event_type: str, obj: dict[str, Any],
) -> None:
"""customer.subscription.created / .updated — flip to paid if the
Stripe-side status says the subscription is active/trialing; drop
to free if it's an end-state."""
user = await _find_user(session, customer_id=obj.get("customer"))
if user is None:
log.warning("stripe.user_not_found", event=event_type,
customer_id=obj.get("customer"))
return
status = obj.get("status")
# Stripe statuses: trialing, active, past_due, canceled, unpaid,
# incomplete, incomplete_expired, paused. Treat trialing/active as
# paid; everything else holds tier the same until we get an explicit
# subscription.deleted (which fires after the final state lands).
if status in ("trialing", "active"):
await _grant_paid(
user,
customer_id=obj.get("customer"),
subscription_id=obj.get("id"),
)
async def _handle_subscription_deleted(
session: AsyncSession, event_type: str, obj: dict[str, Any],
) -> None:
user = await _find_user(session, customer_id=obj.get("customer"))
if user is None:
log.warning("stripe.user_not_found", event=event_type,
customer_id=obj.get("customer"))
return
await _revoke_paid(user)
async def _handle_audit_only(
session: AsyncSession, event_type: str, obj: dict[str, Any],
) -> None:
"""invoice.paid / invoice.payment_failed / charge.refunded — we
record these in stripe_events for the audit log but the tier doesn't
move until subscription.deleted fires."""
return None
_HANDLERS = {
"checkout.session.completed": _handle_checkout_completed,
"customer.subscription.created": _handle_subscription_event,
"customer.subscription.updated": _handle_subscription_event,
"customer.subscription.deleted": _handle_subscription_deleted,
"invoice.paid": _handle_audit_only,
"invoice.payment_failed": _handle_audit_only,
"charge.refunded": _handle_audit_only,
}
@router.post("/api/stripe/webhook")
async def stripe_webhook(
request: Request,
session: AsyncSession = Depends(get_session),
) -> dict[str, str]:
s = get_settings()
if not s.STRIPE_WEBHOOK_SECRET:
raise HTTPException(status_code=503, detail="stripe webhook not configured")
sig = request.headers.get("stripe-signature", "")
if not sig:
raise HTTPException(status_code=400, detail="missing stripe-signature header")
body = await request.body()
# construct_event handles HMAC verification + timestamp tolerance.
# We then re-parse the body as plain JSON for handler dispatch —
# the Stripe SDK's StripeObject doesn't expose dict.get(), and
# round-tripping through json gives us simple, typed-dict access.
try:
stripe.Webhook.construct_event(
payload=body, sig_header=sig, secret=s.STRIPE_WEBHOOK_SECRET,
)
except stripe.SignatureVerificationError:
raise HTTPException(status_code=401, detail="bad signature")
except ValueError:
raise HTTPException(status_code=400, detail="invalid payload")
envelope = json.loads(body)
event_id = envelope.get("id") or ""
event_type = envelope.get("type") or "unknown"
obj = (envelope.get("data") or {}).get("object") or {}
if not event_id:
raise HTTPException(status_code=400, detail="event missing id")
# Idempotency: insert audit row first. UNIQUE on event_id makes a
# replay of the same Stripe event id a no-op (Stripe retries on
# non-2xx, so always 2xx after first successful processing).
audit = StripeEvent(
event_id=event_id,
event_type=event_type,
received_at=utcnow(),
payload=body.decode("utf-8", errors="replace")[:_PAYLOAD_STORE_MAX],
)
session.add(audit)
try:
await session.flush()
except IntegrityError:
await session.rollback()
log.info("stripe.duplicate_delivery", event_id=event_id, type=event_type)
return {"status": "duplicate"}
handler = _HANDLERS.get(event_type)
if handler is None:
audit.processed_at = utcnow()
await session.commit()
log.info("stripe.event_unhandled", type=event_type, id=event_id)
return {"status": "ignored"}
try:
await handler(session, event_type, obj)
except Exception as e:
audit.error = str(e)[:1024]
await session.commit()
log.exception("stripe.handler_error", type=event_type, id=event_id)
# Ack 200 — we don't want Stripe retrying a handler that broke
# the same way on every delivery. An operator triages from the
# `error` column.
return {"status": "handler_error"}
audit.processed_at = utcnow()
await session.commit()
log.info("stripe.processed", type=event_type, id=event_id)
return {"status": "ok"}

View file

@ -50,9 +50,8 @@
<div class="tier-card__tagline">Full-day news feed, hourly strategic log, follow-up chat, and AI portfolio analysis.</div>
<div class="tier-card__price">&pound;7<span class="tier-card__price-unit"> / month</span></div>
<div class="tier-card__price-hint">
Or <strong>&pound;70 / year</strong> &mdash; two months free. Prices
in GBP, VAT where applicable. Checkout opens with the payments
rollout.
Or <strong>&pound;70 / year</strong> &mdash; two months free.
Prices in GBP, VAT where applicable.
</div>
<div class="tier-card__divider"></div>
<div class="tier-card__list-head">Everything in Free, plus</div>
@ -72,16 +71,56 @@
or a personal recommendation under FSMA / FCA COBS.
</p>
<div class="tier-card__cta">
{% if cu and (cu.user or cu.is_admin) %}
<a class="btn-secondary btn-block" href="/settings">Manage account</a>
{% if paid %}
<a class="btn-secondary btn-block" href="/settings">Manage subscription</a>
{% elif cu and cu.user %}
<button class="btn-primary btn-block" type="button"
data-stripe-checkout="monthly">Subscribe &mdash; &pound;7/month</button>
<button class="btn-secondary btn-block" type="button"
data-stripe-checkout="annual"
style="margin-top:10px;">or &pound;70/year (two months free)</button>
{% else %}
<a class="btn-primary btn-block" href="/login">Sign up &mdash; paid unlocks soon</a>
<a class="btn-primary btn-block" href="/login?next=/pricing">Sign in to subscribe</a>
{% endif %}
</div>
</div>
</section>
<script>
(function () {
// Wire the two upgrade buttons to /api/stripe/checkout. Stripe returns
// a hosted-checkout URL; we just redirect there. No Stripe.js needed.
document.querySelectorAll('[data-stripe-checkout]').forEach(function (btn) {
btn.addEventListener('click', async function () {
var cadence = btn.getAttribute('data-stripe-checkout');
btn.disabled = true;
var prev = btn.textContent;
btn.textContent = 'Opening checkout…';
try {
var r = await fetch('/api/stripe/checkout', {
method: 'POST',
headers: {'content-type': 'application/json'},
body: JSON.stringify({cadence: cadence}),
credentials: 'same-origin',
});
if (!r.ok) {
var detail = '';
try { detail = (await r.json()).detail || ''; } catch (e) {}
throw new Error('Checkout failed: ' + (detail || r.status));
}
var data = await r.json();
window.location.href = data.url;
} catch (e) {
alert(e.message || 'Could not start checkout. Please try again.');
btn.disabled = false;
btn.textContent = prev;
}
});
});
})();
</script>
<section class="public-section">
<h2 class="public-section__head">Free vs Paid at a glance</h2>
<table class="compare-table">

View file

@ -23,6 +23,7 @@ dependencies = [
"email-validator>=2.2",
"aiosmtplib>=3.0",
"redis[hiredis]>=5.2",
"stripe>=11.0",
]
[project.optional-dependencies]

View file

@ -0,0 +1,322 @@
"""Stripe billing endpoints: signature verification, idempotency, tier
flips, and checkout creation.
Same integration-style scaffold as test_polar_webhook.py real router
over in-memory aiosqlite. Stripe SDK calls (sessions.create, portal
sessions.create) are mocked so the suite never makes a real HTTP call.
"""
from __future__ import annotations
import asyncio
import hashlib
import hmac
import json
import time
from types import SimpleNamespace
from unittest.mock import patch
import pytest
_API_KEY = "sk_test_dummy_for_unit_tests"
_WEBHOOK_SECRET = "whsec_dummy_test_secret_for_unit_tests"
_PRICE_MONTHLY = "price_test_monthly_xxxxxxxxxxxxxxxxxxxx"
_PRICE_ANNUAL = "price_test_annual_xxxxxxxxxxxxxxxxxxxxx"
def _stripe_sig(body: bytes, secret: str, ts: int | None = None) -> str:
"""Produce a Stripe-Signature header matching the bytes signed.
Format: `t=<ts>,v1=<hex hmac sha256>` over `<ts>.<body>`."""
ts = ts if ts is not None else int(time.time())
signed = f"{ts}.{body.decode('utf-8')}"
mac = hmac.new(secret.encode("utf-8"), signed.encode("utf-8"),
hashlib.sha256).hexdigest()
return f"t={ts},v1={mac}"
def _build_app(tmp_path):
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app import db as db_mod
from app.auth import sign_session
from app.config import get_settings
from app.db import Base
from app.models import User
from app.routers import stripe_billing as stripe_router
s = get_settings()
s.STRIPE_API_KEY = _API_KEY # type: ignore[misc]
s.STRIPE_WEBHOOK_SECRET = _WEBHOOK_SECRET # type: ignore[misc]
s.STRIPE_PRICE_MONTHLY = _PRICE_MONTHLY # type: ignore[misc]
s.STRIPE_PRICE_ANNUAL = _PRICE_ANNUAL # type: ignore[misc]
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/stripe.db")
factory = async_sessionmaker(engine, expire_on_commit=False)
db_mod._engine = engine
db_mod._session_factory = factory
async def _seed():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with factory() as session:
session.add(User(id=1, email="buyer@x", tier="free"))
await session.commit()
asyncio.run(_seed())
app = FastAPI()
app.include_router(stripe_router.router)
return TestClient(app), factory, sign_session(1)
def _post_webhook(client, *, body: dict, secret: str = _WEBHOOK_SECRET,
sig: str | None = None):
# Stripe's SDK requires a top-level `object: "event"` field to know
# this is a v1 webhook envelope — tests that omit it fail in
# construct_event before the signature check matters. We inject the
# default here so individual tests can stay terse.
body.setdefault("object", "event")
raw = json.dumps(body).encode("utf-8")
sig = sig if sig is not None else _stripe_sig(raw, secret)
return client.post(
"/api/stripe/webhook",
content=raw,
headers={"stripe-signature": sig, "content-type": "application/json"},
)
# --- signature gate --------------------------------------------------------
def test_webhook_rejects_bad_signature(tmp_path):
client, _, _ = _build_app(tmp_path)
raw = json.dumps({"id": "evt_x", "type": "invoice.paid",
"data": {"object": {}}}).encode("utf-8")
r = client.post(
"/api/stripe/webhook",
content=raw,
headers={
"stripe-signature": "t=0,v1=deadbeef",
"content-type": "application/json",
},
)
assert r.status_code == 401, r.text
def test_webhook_rejects_missing_signature(tmp_path):
client, _, _ = _build_app(tmp_path)
r = client.post(
"/api/stripe/webhook",
content=b"{}",
headers={"content-type": "application/json"},
)
assert r.status_code == 400, r.text
# --- happy paths -----------------------------------------------------------
def test_checkout_session_completed_flips_tier_to_paid(tmp_path):
client, factory, _ = _build_app(tmp_path)
body = {
"id": "evt_checkout_1",
"type": "checkout.session.completed",
"data": {
"object": {
"client_reference_id": "1",
"customer": "cus_abc",
"subscription": "sub_xyz",
}
},
}
r = _post_webhook(client, body=body)
assert r.status_code == 200, r.text
assert r.json()["status"] == "ok"
async def _check():
from sqlalchemy import select
from app.models import User
async with factory() as session:
u = (await session.execute(
select(User).where(User.id == 1)
)).scalar_one()
return u.tier, u.stripe_customer_id, u.stripe_subscription_id
tier, cid, sid = asyncio.run(_check())
assert tier == "paid"
assert cid == "cus_abc"
assert sid == "sub_xyz"
def test_subscription_deleted_drops_tier_to_free(tmp_path):
client, factory, _ = _build_app(tmp_path)
# First, activate.
_post_webhook(client, body={
"id": "evt_act",
"type": "checkout.session.completed",
"data": {"object": {
"client_reference_id": "1",
"customer": "cus_abc",
"subscription": "sub_xyz",
}},
})
# Then, delete the subscription.
r = _post_webhook(client, body={
"id": "evt_del",
"type": "customer.subscription.deleted",
"data": {"object": {
"id": "sub_xyz",
"customer": "cus_abc",
"status": "canceled",
}},
})
assert r.status_code == 200, r.text
async def _check():
from sqlalchemy import select
from app.models import User
async with factory() as session:
u = (await session.execute(
select(User).where(User.id == 1)
)).scalar_one()
return u.tier, u.stripe_customer_id, u.stripe_subscription_id
tier, cid, sid = asyncio.run(_check())
assert tier == "free"
# Customer linkage preserved so a future resub matches this row.
assert cid == "cus_abc"
assert sid is None
def test_subscription_active_grants_paid(tmp_path):
"""customer.subscription.updated with status=active should also
grant paid covers the case where checkout.session.completed
arrives after subscription.created and we want either to work."""
client, factory, _ = _build_app(tmp_path)
# Seed the linkage first via checkout (so customer_id is known).
_post_webhook(client, body={
"id": "evt_ck",
"type": "checkout.session.completed",
"data": {"object": {
"client_reference_id": "1",
"customer": "cus_abc",
"subscription": "sub_xyz",
}},
})
# Drop to free manually so we can prove the updated event re-grants.
async def _reset():
from sqlalchemy import update
from app.models import User
async with factory() as session:
await session.execute(
update(User).where(User.id == 1).values(tier="free")
)
await session.commit()
asyncio.run(_reset())
r = _post_webhook(client, body={
"id": "evt_upd",
"type": "customer.subscription.updated",
"data": {"object": {
"id": "sub_xyz",
"customer": "cus_abc",
"status": "active",
}},
})
assert r.status_code == 200
async def _check_tier():
from sqlalchemy import select
from app.models import User
async with factory() as session:
return (await session.execute(
select(User.tier).where(User.id == 1)
)).scalar_one()
assert asyncio.run(_check_tier()) == "paid"
# --- idempotency + unknown ------------------------------------------------
def test_replayed_event_id_is_a_noop(tmp_path):
client, factory, _ = _build_app(tmp_path)
body = {
"id": "evt_dup",
"type": "checkout.session.completed",
"data": {"object": {
"client_reference_id": "1",
"customer": "cus_abc",
"subscription": "sub_xyz",
}},
}
r1 = _post_webhook(client, body=body)
r2 = _post_webhook(client, body=body)
assert r1.json()["status"] == "ok"
assert r2.json()["status"] == "duplicate"
async def _count_rows():
from sqlalchemy import select, func
from app.models import StripeEvent
async with factory() as session:
n = (await session.execute(
select(func.count(StripeEvent.id))
.where(StripeEvent.event_id == "evt_dup")
)).scalar_one()
return n
assert asyncio.run(_count_rows()) == 1
def test_unknown_event_is_acked(tmp_path):
client, _, _ = _build_app(tmp_path)
r = _post_webhook(client, body={
"id": "evt_unknown",
"type": "product.something.new",
"data": {"object": {}},
})
assert r.status_code == 200
assert r.json()["status"] == "ignored"
# --- /api/stripe/checkout (with Stripe SDK mocked) ------------------------
def test_checkout_endpoint_creates_session_and_returns_url(tmp_path):
client, _, session_cookie = _build_app(tmp_path)
# Mock the Stripe SDK call so no real HTTP goes out.
fake_session = SimpleNamespace(
id="cs_test_123", url="https://checkout.stripe.com/test",
)
class _FakeSessions:
@staticmethod
def create(params): # noqa: ANN001
assert params["mode"] == "subscription"
assert params["line_items"][0]["price"] == _PRICE_MONTHLY
assert params["client_reference_id"] == "1"
assert params["customer_email"] == "buyer@x"
return fake_session
class _FakeCheckout:
sessions = _FakeSessions()
class _FakeClient:
checkout = _FakeCheckout()
with patch("app.routers.stripe_billing._stripe_client",
return_value=_FakeClient()):
r = client.post(
"/api/stripe/checkout",
json={"cadence": "monthly"},
cookies={"cassandra_session": session_cookie},
)
assert r.status_code == 200, r.text
assert r.json()["url"] == "https://checkout.stripe.com/test"
def test_checkout_endpoint_requires_login(tmp_path):
client, _, _ = _build_app(tmp_path)
r = client.post("/api/stripe/checkout", json={"cadence": "monthly"})
# No session cookie → require_auth bounces with 401.
assert r.status_code == 401, r.text