Bundles three related pieces that came out of the operator's first
end-to-end test of the paid flow:
1. Manage subscription button on /settings (paid users with a real
Stripe sub — i.e. not credit-granted access). POSTs to the existing
/api/stripe/portal endpoint; Stripe-hosted customer portal handles
card updates, cancellation, monthly↔annual switch, invoice history.
Replaces the stale "Paid features unlock with Paddle (D.3) or
invite credits" hint for free users with a live link to /pricing.
2. Per-cadence cooling-off treatment:
- **Annual £70**: 14-day free trial via
subscription_data.trial_period_days=14. No money moves during
the trial, so the CCR 2013 14-day refund question doesn't arise
(nothing paid = nothing to refund). Card is still required at
checkout so Stripe can charge on day 15.
- **Monthly £7**: bills immediately. A 14-day trial there would
give away ~50% of cycle one. Instead, /pricing now carries a
required tick-box above the Subscribe buttons (subscribe stays
disabled until checked) — by ticking, the user expressly
consents to begin performance immediately and acknowledges that
this extinguishes their statutory 14-day right under Reg 36
CCR 2013. Consent collected on our own page (not via Stripe's
account-wide consent_collection.terms_of_service) so each
product can keep its own Terms URL as we add more.
3. T&C §6 clause 1 split into 1a (annual / trial substitute) +
1b (monthly / Reg 36 waiver via on-page tick-box). Clause 2
(post-cooling-off cancellation) unchanged.
Settings page shows "Free trial — N days remaining" while the
sub is in `trialing` status, falling back to "Paid subscription
active." once it transitions to active. Countdown is computed
server-side from User.stripe_trial_end_at (new column, migration
0020) populated by the subscription.created/updated webhook from
the Stripe trial_end timestamp; cleared on the trialing→active
transition and on revoke.
Drive-by: fixed a structlog kwarg-name collision on
`log.warning(..., event=event_type, ...)` in both polar_webhook.py
and stripe_billing.py — `event` is the name of structlog's positional
event-message argument, so also passing `event=` raised
"TypeError: ... got multiple values for argument 'event'" and crashed
the user-not-found log path. Renamed the kwarg to `event_type=`
everywhere it appeared. Caught by the new trialing-stores-trial-end test.
Tests
- 4 new tests in test_stripe_billing.py: monthly checkout (no trial,
  no consent_collection), annual checkout (trial, no
  consent_collection), trialing stores trial_end, and trialing→active
  clears trial_end.
- 1 existing test renamed + reworked for the consent split.
- Full suite: 224 passed, 5 skipped.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
465 lines
16 KiB
Python
"""Stripe billing endpoints: signature verification, idempotency, tier
|
||
flips, and checkout creation.
|
||
|
||
Same integration-style scaffold as test_polar_webhook.py — real router
|
||
over in-memory aiosqlite. Stripe SDK calls (sessions.create, portal
|
||
sessions.create) are mocked so the suite never makes a real HTTP call.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import hashlib
|
||
import hmac
|
||
import json
|
||
import time
|
||
from types import SimpleNamespace
|
||
from unittest.mock import patch
|
||
|
||
import pytest
|
||
|
||
|
||
# Dummy Stripe configuration. _build_app writes these into the app settings,
# webhook tests sign payloads with _WEBHOOK_SECRET, and the checkout tests
# assert the price IDs. None of them are real credentials.
_API_KEY: str = "sk_test_dummy_for_unit_tests"
_WEBHOOK_SECRET: str = "whsec_dummy_test_secret_for_unit_tests"
_PRICE_MONTHLY: str = "price_test_monthly_xxxxxxxxxxxxxxxxxxxx"
_PRICE_ANNUAL: str = "price_test_annual_xxxxxxxxxxxxxxxxxxxxx"
|
||
|
||
|
||
def _stripe_sig(body: bytes, secret: str, ts: int | None = None) -> str:
|
||
"""Produce a Stripe-Signature header matching the bytes signed.
|
||
Format: `t=<ts>,v1=<hex hmac sha256>` over `<ts>.<body>`."""
|
||
ts = ts if ts is not None else int(time.time())
|
||
signed = f"{ts}.{body.decode('utf-8')}"
|
||
mac = hmac.new(secret.encode("utf-8"), signed.encode("utf-8"),
|
||
hashlib.sha256).hexdigest()
|
||
return f"t={ts},v1={mac}"
|
||
|
||
|
||
def _build_app(tmp_path):
    """Assemble a TestClient around the Stripe router, backed by a fresh
    aiosqlite database file under ``tmp_path``.

    Side effects, not undone afterwards:
    - the settings object returned by ``get_settings()`` is mutated in
      place to carry this module's dummy Stripe key/secret/price IDs;
    - ``app.db``'s module-level ``_engine`` and ``_session_factory`` are
      repointed at the new database (presumably what the router's DB
      access reads — confirm in app.db).

    The database gets the full schema and one seeded user
    (id=1, email "buyer@x", tier "free").

    Returns:
        ``(client, session_factory, session_cookie)`` where the cookie is
        ``sign_session(1)``, i.e. a signed session for the seeded user.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

    from app import db as db_mod
    from app.auth import sign_session
    from app.config import get_settings
    from app.db import Base
    from app.models import User
    from app.routers import stripe_billing as stripe_router

    settings = get_settings()
    overrides = {
        "STRIPE_API_KEY": _API_KEY,
        "STRIPE_WEBHOOK_SECRET": _WEBHOOK_SECRET,
        "STRIPE_PRICE_MONTHLY": _PRICE_MONTHLY,
        "STRIPE_PRICE_ANNUAL": _PRICE_ANNUAL,
    }
    for field_name, value in overrides.items():
        setattr(settings, field_name, value)

    db_url = f"sqlite+aiosqlite:///{tmp_path}/stripe.db"
    db_engine = create_async_engine(db_url)
    session_factory = async_sessionmaker(db_engine, expire_on_commit=False)
    db_mod._engine = db_engine
    db_mod._session_factory = session_factory

    async def _create_schema_and_seed_user():
        async with db_engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        async with session_factory() as session:
            session.add(User(id=1, email="buyer@x", tier="free"))
            await session.commit()

    asyncio.run(_create_schema_and_seed_user())

    app = FastAPI()
    app.include_router(stripe_router.router)
    return TestClient(app), session_factory, sign_session(1)
|
||
|
||
|
||
def _post_webhook(client, *, body: dict, secret: str = _WEBHOOK_SECRET,
                  sig: str | None = None):
    """POST ``body`` as JSON to /api/stripe/webhook with a Stripe signature.

    Args:
        client: the TestClient from ``_build_app``.
        body: the event envelope. If it has no ``object`` key, one is
            added as ``"event"`` on a copy; the caller's dict is never
            mutated, so tests can safely reuse a body (e.g. to replay an
            event).
        secret: key used to sign the payload when ``sig`` is not given.
        sig: explicit Stripe-Signature header value; overrides signing.

    Returns:
        The TestClient response.
    """
    # Stripe's SDK requires a top-level `object: "event"` field to know
    # this is a v1 webhook envelope — tests that omit it fail in
    # construct_event before the signature check matters. We inject the
    # default here so individual tests can stay terse.
    payload = dict(body)
    payload.setdefault("object", "event")
    raw = json.dumps(payload).encode("utf-8")
    if sig is None:
        # Sign the exact bytes we send so the router's check sees a match.
        sig = _stripe_sig(raw, secret)
    return client.post(
        "/api/stripe/webhook",
        content=raw,
        headers={"stripe-signature": sig, "content-type": "application/json"},
    )
|
||
|
||
|
||
# --- signature gate --------------------------------------------------------
|
||
|
||
|
||
def test_webhook_rejects_bad_signature(tmp_path):
    """A forged Stripe-Signature (t=0, bogus v1 MAC) is rejected with 401."""
    client, _, _ = _build_app(tmp_path)

    event = {"id": "evt_x", "type": "invoice.paid", "data": {"object": {}}}
    payload = json.dumps(event).encode("utf-8")
    forged_headers = {
        "stripe-signature": "t=0,v1=deadbeef",
        "content-type": "application/json",
    }

    response = client.post(
        "/api/stripe/webhook", content=payload, headers=forged_headers,
    )
    assert response.status_code == 401, response.text
|
||
|
||
|
||
def test_webhook_rejects_missing_signature(tmp_path):
    """A request with no Stripe-Signature header at all gets 400."""
    client, _, _ = _build_app(tmp_path)
    response = client.post(
        "/api/stripe/webhook",
        headers={"content-type": "application/json"},
        content=b"{}",
    )
    assert response.status_code == 400, response.text
|
||
|
||
|
||
# --- happy paths -----------------------------------------------------------
|
||
|
||
|
||
def test_checkout_session_completed_flips_tier_to_paid(tmp_path):
    """checkout.session.completed upgrades the user named by
    client_reference_id to paid and stores the Stripe customer and
    subscription ids on that row."""
    client, factory, _ = _build_app(tmp_path)

    checkout_session = {
        "client_reference_id": "1",
        "customer": "cus_abc",
        "subscription": "sub_xyz",
    }
    response = _post_webhook(client, body={
        "id": "evt_checkout_1",
        "type": "checkout.session.completed",
        "data": {"object": checkout_session},
    })
    assert response.status_code == 200, response.text
    assert response.json()["status"] == "ok"

    async def _load_user_1():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            result = await session.execute(select(User).where(User.id == 1))
            user = result.scalar_one()
            return user.tier, user.stripe_customer_id, user.stripe_subscription_id

    tier, customer_id, subscription_id = asyncio.run(_load_user_1())
    assert tier == "paid"
    assert customer_id == "cus_abc"
    assert subscription_id == "sub_xyz"
|
||
|
||
|
||
def test_subscription_deleted_drops_tier_to_free(tmp_path):
    """customer.subscription.deleted downgrades the user to free, clears
    the subscription id and keeps the customer id.

    The activation step is asserted as a precondition (200 and
    tier/customer/subscription set). That way the final ``tier == "free"``
    proves the *deletion* performed the downgrade. Without it, the check
    would also pass if activation had silently never happened.
    """
    client, factory, _ = _build_app(tmp_path)

    async def _snapshot():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            u = (await session.execute(
                select(User).where(User.id == 1)
            )).scalar_one()
            return u.tier, u.stripe_customer_id, u.stripe_subscription_id

    # First, activate.
    r = _post_webhook(client, body={
        "id": "evt_act",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_abc",
            "subscription": "sub_xyz",
        }},
    })
    assert r.status_code == 200, r.text
    assert asyncio.run(_snapshot()) == ("paid", "cus_abc", "sub_xyz")

    # Then, delete the subscription.
    r = _post_webhook(client, body={
        "id": "evt_del",
        "type": "customer.subscription.deleted",
        "data": {"object": {
            "id": "sub_xyz",
            "customer": "cus_abc",
            "status": "canceled",
        }},
    })
    assert r.status_code == 200, r.text

    tier, cid, sid = asyncio.run(_snapshot())
    assert tier == "free"
    # Customer linkage preserved so a future resub matches this row.
    assert cid == "cus_abc"
    assert sid is None
|
||
|
||
|
||
def test_subscription_trialing_stores_trial_end(tmp_path):
    """A trialing customer.subscription.created grants paid and persists
    trial_end, so the settings page can show 'N days remaining'.

    This mirrors the real ordering. checkout.session.completed links
    customer_id to user.id via client_reference_id first, and
    subscription.created follows moments later carrying trial_end.
    """
    from datetime import datetime, timedelta, timezone

    client, factory, _ = _build_app(tmp_path)

    # Link the user to the Stripe customer via checkout.
    _post_webhook(client, body={
        "id": "evt_link",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_trial",
            "subscription": "sub_trial",
        }},
    })

    # The subscription event carrying trial_end, 12 days from now.
    twelve_days_out = datetime.now(timezone.utc) + timedelta(days=12)
    trial_end_ts = int(twelve_days_out.timestamp())
    response = _post_webhook(client, body={
        "id": "evt_trial",
        "type": "customer.subscription.created",
        "data": {"object": {
            "id": "sub_trial",
            "customer": "cus_trial",
            "status": "trialing",
            "trial_end": trial_end_ts,
        }},
    })
    assert response.status_code == 200, response.text

    async def _tier_and_trial_end():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            result = await session.execute(select(User).where(User.id == 1))
            user = result.scalar_one()
            return user.tier, user.stripe_trial_end_at

    tier, stored_end = asyncio.run(_tier_and_trial_end())
    assert tier == "paid", "trial users must have paid features"
    assert stored_end is not None

    # The column may come back naive; interpret it as UTC before comparing
    # against the timestamp we sent (allowing up to a second of rounding).
    if stored_end.tzinfo is None:
        stored_end = stored_end.replace(tzinfo=timezone.utc)
    sent_end = datetime.fromtimestamp(trial_end_ts, tz=timezone.utc)
    assert abs((stored_end - sent_end).total_seconds()) < 2
|
||
|
||
|
||
def test_subscription_active_clears_trial_end(tmp_path):
    """When the subscription transitions trialing -> active (day 15),
    the trial_end marker is cleared so settings stops showing
    'trial — N days remaining'.

    The trialing state is verified *before* the transition, along with the
    status of every webhook delivery. Otherwise ``end is None`` at the end
    would pass vacuously if the trialing event never stored trial_end in
    the first place.
    """
    import datetime as _dt

    client, factory, _ = _build_app(tmp_path)

    async def _tier_and_trial_end():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            u = (await session.execute(
                select(User).where(User.id == 1)
            )).scalar_one()
            return u.tier, u.stripe_trial_end_at

    # Link the customer first via checkout, then plant a trial state.
    r = _post_webhook(client, body={
        "id": "evt_link2",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_t",
            "subscription": "sub_t",
        }},
    })
    assert r.status_code == 200, r.text

    trial_end_ts = int((_dt.datetime.now(_dt.timezone.utc)
                        + _dt.timedelta(days=12)).timestamp())
    r = _post_webhook(client, body={
        "id": "evt_t1",
        "type": "customer.subscription.created",
        "data": {"object": {
            "id": "sub_t", "customer": "cus_t",
            "status": "trialing", "trial_end": trial_end_ts,
        }},
    })
    assert r.status_code == 200, r.text

    # Precondition: the trial marker really is set before we transition.
    tier, end = asyncio.run(_tier_and_trial_end())
    assert tier == "paid"
    assert end is not None, "trialing event must store trial_end first"

    # Now transition to active.
    r = _post_webhook(client, body={
        "id": "evt_t2",
        "type": "customer.subscription.updated",
        "data": {"object": {
            "id": "sub_t", "customer": "cus_t",
            "status": "active",
        }},
    })
    assert r.status_code == 200, r.text

    tier, end = asyncio.run(_tier_and_trial_end())
    assert tier == "paid"
    assert end is None, "trial_end_at must be cleared once active"
|
||
|
||
|
||
def test_subscription_active_grants_paid(tmp_path):
    """customer.subscription.updated with status=active grants paid on its
    own, without relying on checkout.session.completed to have set the tier.

    Setup: checkout.session.completed links the Stripe customer to user 1,
    so the subscription event can find the row by customer id. The tier is
    then forced back to "free" directly in the DB, and the update event
    must restore "paid".

    Note: checkout still arrives *first* here. This test does not cover
    checkout.session.completed landing after the subscription event.
    """
    client, factory, _ = _build_app(tmp_path)

    # Seed the linkage first via checkout (so customer_id is known).
    r = _post_webhook(client, body={
        "id": "evt_ck",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_abc",
            "subscription": "sub_xyz",
        }},
    })
    assert r.status_code == 200, r.text

    # Drop to free manually so we can prove the updated event re-grants.
    async def _reset():
        from sqlalchemy import update
        from app.models import User
        async with factory() as session:
            await session.execute(
                update(User).where(User.id == 1).values(tier="free")
            )
            await session.commit()

    asyncio.run(_reset())

    r = _post_webhook(client, body={
        "id": "evt_upd",
        "type": "customer.subscription.updated",
        "data": {"object": {
            "id": "sub_xyz",
            "customer": "cus_abc",
            "status": "active",
        }},
    })
    assert r.status_code == 200, r.text

    async def _check_tier():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            return (await session.execute(
                select(User.tier).where(User.id == 1)
            )).scalar_one()

    assert asyncio.run(_check_tier()) == "paid"
|
||
|
||
|
||
# --- idempotency + unknown ------------------------------------------------
|
||
|
||
|
||
def test_replayed_event_id_is_a_noop(tmp_path):
    """Delivering the same event id twice is processed once. The first
    delivery reports "ok", the second "duplicate", and exactly one
    StripeEvent row exists for that id."""
    client, factory, _ = _build_app(tmp_path)

    event = {
        "id": "evt_dup",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_abc",
            "subscription": "sub_xyz",
        }},
    }
    first = _post_webhook(client, body=event)
    second = _post_webhook(client, body=event)
    assert first.json()["status"] == "ok"
    assert second.json()["status"] == "duplicate"

    async def _rows_for_evt_dup():
        from sqlalchemy import func, select
        from app.models import StripeEvent
        stmt = (
            select(func.count(StripeEvent.id))
            .where(StripeEvent.event_id == "evt_dup")
        )
        async with factory() as session:
            return (await session.execute(stmt)).scalar_one()

    assert asyncio.run(_rows_for_evt_dup()) == 1
|
||
|
||
|
||
def test_unknown_event_is_acked(tmp_path):
    """Event types the router has no handler for are acknowledged with
    200 and status "ignored"."""
    client, _, _ = _build_app(tmp_path)
    unknown_event = {
        "id": "evt_unknown",
        "type": "product.something.new",
        "data": {"object": {}},
    }
    response = _post_webhook(client, body=unknown_event)
    assert response.status_code == 200
    assert response.json()["status"] == "ignored"
|
||
|
||
|
||
# --- /api/stripe/checkout (with Stripe SDK mocked) ------------------------
|
||
|
||
|
||
def _fake_checkout_client(asserter):
    """Return a stand-in Stripe client exposing only
    ``checkout.sessions.create``.

    Every ``create(params)`` call first passes ``params`` to ``asserter``,
    then returns the same stub session object (id ``cs_test_123``,
    url ``https://checkout.stripe.com/test``).
    """
    stub_session = SimpleNamespace(
        id="cs_test_123", url="https://checkout.stripe.com/test",
    )

    def create(params):  # noqa: ANN001
        asserter(params)
        return stub_session

    return SimpleNamespace(
        checkout=SimpleNamespace(sessions=SimpleNamespace(create=create)),
    )
|
||
|
||
|
||
def test_checkout_monthly_has_no_trial_and_no_stripe_consent(tmp_path):
    """Monthly checkout must NOT carry a free trial (£7 × 14 days would
    halve cycle-1 revenue) AND must NOT use Stripe's account-wide
    consent_collection — the Reg-36 waiver is collected on /pricing
    so each product can use its own Terms URL."""
    client, _, session_cookie = _build_app(tmp_path)

    # Record the params instead of asserting inside the fake. An assertion
    # raised there runs inside the router's call path. If the router turns
    # unexpected exceptions into an HTTP error (NOTE(review): not visible
    # from here — confirm), the failure would show up only as an opaque
    # non-200. Capturing also proves create() was actually called.
    captured: list[dict] = []

    with patch("app.routers.stripe_billing._stripe_client",
               return_value=_fake_checkout_client(captured.append)):
        r = client.post(
            "/api/stripe/checkout",
            json={"cadence": "monthly"},
            cookies={"cassandra_session": session_cookie},
        )

    assert r.status_code == 200, r.text
    assert r.json()["url"] == "https://checkout.stripe.com/test"

    assert len(captured) == 1, "checkout.sessions.create must be called once"
    params = captured[0]
    assert params["mode"] == "subscription"
    assert params["line_items"][0]["price"] == _PRICE_MONTHLY
    assert params["client_reference_id"] == "1"
    assert params["customer_email"] == "buyer@x"
    assert "subscription_data" not in params, "no trial on monthly"
    assert "consent_collection" not in params, (
        "consent is collected on /pricing, not via Stripe's account-wide setting"
    )
|
||
|
||
|
||
def test_checkout_annual_uses_trial_not_consent_collection(tmp_path):
    """Annual checkout gets the 14-day free trial (substitutes for the
    statutory cooling-off right; no money moves during the trial)."""
    client, _, session_cookie = _build_app(tmp_path)

    # Record the params instead of asserting inside the fake, so a failed
    # check is reported here rather than only as whatever HTTP response the
    # router produces when the Stripe call raises (NOTE(review): router's
    # exception handling not visible here). Capturing also proves create()
    # ran exactly once.
    captured: list[dict] = []

    with patch("app.routers.stripe_billing._stripe_client",
               return_value=_fake_checkout_client(captured.append)):
        r = client.post(
            "/api/stripe/checkout",
            json={"cadence": "annual"},
            cookies={"cassandra_session": session_cookie},
        )

    assert r.status_code == 200, r.text
    assert r.json()["url"] == "https://checkout.stripe.com/test"

    assert len(captured) == 1, "checkout.sessions.create must be called once"
    params = captured[0]
    assert params["mode"] == "subscription"
    assert params["line_items"][0]["price"] == _PRICE_ANNUAL
    assert params["subscription_data"]["trial_period_days"] == 14
    assert "consent_collection" not in params, "annual relies on trial, not consent"
|
||
|
||
|
||
def test_checkout_endpoint_requires_login(tmp_path):
    """Starting a checkout without a session cookie is refused with 401."""
    client, _, _ = _build_app(tmp_path)
    anonymous = client.post("/api/stripe/checkout", json={"cadence": "monthly"})
    # No session cookie → require_auth bounces with 401.
    assert anonymous.status_code == 401, anonymous.text
|