read.markets/tests/test_stripe_billing.py
Giorgio Gilestro a07fd144ea stripe: per-cadence cooling-off + manage-subscription button
Bundles three related pieces that came out of the operator's first
end-to-end test of the paid flow:

1. Manage subscription button on /settings, shown to paid users with a
   real Stripe sub (i.e. not credit-granted access). POSTs to the
   existing /api/stripe/portal endpoint; the Stripe-hosted customer
   portal handles card updates, cancellation, monthly↔annual switch and
   invoice history. For free users, the stale "Paid features unlock with
   Paddle (D.3) or invite credits" hint is replaced with a live link to
   /pricing.

2. Per-cadence cooling-off treatment:

   - **Annual £70**: 14-day free trial via
     subscription_data.trial_period_days=14. No money moves during
     the trial, so the CCR 2013 14-day refund question doesn't arise
     (nothing paid = nothing to refund). Card is still required at
     checkout so Stripe can charge on day 15.

   - **Monthly £7**: bills immediately. A 14-day trial there would
     give away ~50% of cycle one. Instead, /pricing now carries a
     required tick-box above the Subscribe buttons (subscribe stays
     disabled until checked) — by ticking, the user expressly
     consents to begin performance immediately and acknowledges that
     this extinguishes their statutory 14-day right under Reg 36
     CCR 2013. Consent collected on our own page (not via Stripe's
     account-wide consent_collection.terms_of_service) so each
     product can keep its own Terms URL as we add more.

3. T&C §6 clause 1 split into 1a (annual / trial substitute) +
   1b (monthly / Reg 36 waiver via on-page tick-box). Clause 2
   (post-cooling-off cancellation) unchanged.
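A minimal sketch of how the per-cadence rules could map onto Stripe Checkout Session parameters. It mirrors what the two checkout tests in this file assert: annual carries `subscription_data.trial_period_days=14`, and neither cadence sets `consent_collection`. `build_checkout_params`, `PRICES` and the placeholder price IDs are hypothetical and are not the router's actual code.

```python
from __future__ import annotations

from typing import Any

# Hypothetical placeholders; the real IDs come from settings
# (STRIPE_PRICE_MONTHLY / STRIPE_PRICE_ANNUAL).
PRICES = {
    "monthly": "price_monthly_placeholder",
    "annual": "price_annual_placeholder",
}
ANNUAL_TRIAL_DAYS = 14


def build_checkout_params(cadence: str, user_id: int, email: str) -> dict[str, Any]:
    """Assemble Checkout Session params for one cadence (hypothetical helper)."""
    if cadence not in PRICES:
        raise ValueError(f"unknown cadence: {cadence!r}")
    params: dict[str, Any] = {
        "mode": "subscription",
        "line_items": [{"price": PRICES[cadence], "quantity": 1}],
        # Lets the checkout.session.completed webhook find the user row.
        "client_reference_id": str(user_id),
        "customer_email": email,
    }
    if cadence == "annual":
        # The trial stands in for the cooling-off period: nothing is charged
        # until it ends, but Checkout still collects a card up front.
        params["subscription_data"] = {"trial_period_days": ANNUAL_TRIAL_DAYS}
    # Neither cadence sets consent_collection: the monthly Reg 36 waiver
    # is a tick-box on /pricing, outside Stripe.
    return params
```

In this sketch the tick-box never reaches the server. It only gates the Subscribe button, which is why the monthly branch adds nothing beyond the price.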

Settings page shows "Free trial — N days remaining" while the
sub is in `trialing` status, falling back to "Paid subscription
active." once it transitions to active. Countdown is computed
server-side from User.stripe_trial_end_at (new column, migration
0020) populated by the subscription.created/updated webhook from
the Stripe trial_end timestamp; cleared on the trialing→active
transition and on revoke.
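One way the webhook-to-countdown path could look, as a stdlib-only sketch. `UserState` stands in for the User row, and `apply_subscription_status` and `trial_days_remaining` are hypothetical names. Two choices here are assumptions: rounding partial days up, and treating `canceled` as the revoke case. The source only says the column is set from Stripe's `trial_end` and cleared on trialing→active and on revoke.

```python
from __future__ import annotations

import math
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class UserState:
    """Hypothetical stand-in for the User row (only the fields used here)."""
    tier: str = "free"
    stripe_trial_end_at: datetime | None = None


def apply_subscription_status(user: UserState, sub: dict) -> None:
    """Update the user from a customer.subscription.* payload's `object`."""
    status = sub.get("status")
    if status == "trialing":
        user.tier = "paid"  # trial users get paid features
        trial_end = sub.get("trial_end")  # Unix seconds, or None
        user.stripe_trial_end_at = (
            datetime.fromtimestamp(trial_end, tz=timezone.utc)
            if trial_end is not None else None
        )
    elif status == "active":
        user.tier = "paid"
        user.stripe_trial_end_at = None  # trialing -> active: stop the countdown
    elif status == "canceled":
        # Assumed revoke path: back to free, no trial marker left behind.
        user.tier = "free"
        user.stripe_trial_end_at = None


def trial_days_remaining(user: UserState, now: datetime) -> int | None:
    """Whole days left in the trial (rounded up), or None if not trialing."""
    if user.stripe_trial_end_at is None:
        return None
    seconds = (user.stripe_trial_end_at - now).total_seconds()
    return max(0, math.ceil(seconds / 86400))
```

The settings view would then show the "N days remaining" text whenever `trial_days_remaining` returns a number, and the "Paid subscription active." text otherwise.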

Drive-by: fixed a structlog kwarg-name collision on
`log.warning(..., event=event_type, ...)` in both polar_webhook.py
and stripe_billing.py. `event` is the name of the first positional
parameter of structlog's log methods (the event message), so passing
it again as a keyword raised "got multiple values for argument
'event'" and crashed the user-not-found log path. Renamed to
`event_type=` everywhere it appeared. Caught by the new
trialing-stores-trial-end test.
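The collision comes from plain Python argument binding, so it can be reproduced without structlog. Below is a minimal sketch using a hypothetical stand-in `warning` function with the same first-parameter name. The `"user_not_found"` message is a placeholder.

```python
def warning(event=None, *args, **kw):
    """Hypothetical stand-in shaped like structlog's log methods:
    the first parameter is named `event` and receives the message."""
    return event, kw


def log_user_not_found_buggy(event_type: str):
    # The positional message AND event= both try to bind to `event`.
    return warning("user_not_found", event=event_type)


def log_user_not_found_fixed(event_type: str):
    # Renamed kwarg: no clash, the value lands in **kw.
    return warning("user_not_found", event_type=event_type)
```

Calling the buggy variant raises `TypeError` before any log output is produced. That is why the crash only surfaced once a test exercised that code path.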

Tests
- 4 new in test_stripe_billing.py covering monthly (no trial, no
  consent_collection), annual (trial, no consent), trialing stores
  trial_end, trialing→active clears trial_end.
- 1 existing test renamed + reworked for the consent split.
- Full suite: 224 passed, 5 skipped.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-26 20:06:19 +02:00


"""Stripe billing endpoints: signature verification, idempotency, tier
flips, and checkout creation.

Same integration-style scaffold as test_polar_webhook.py: the real router
over a per-test aiosqlite database file under tmp_path. Stripe SDK calls
(sessions.create, portal sessions.create) are mocked so the suite never
makes a real HTTP call.
"""

from __future__ import annotations

import asyncio
import hashlib
import hmac
import json
import time
from types import SimpleNamespace
from unittest.mock import patch

import pytest

_API_KEY = "sk_test_dummy_for_unit_tests"
_WEBHOOK_SECRET = "whsec_dummy_test_secret_for_unit_tests"
_PRICE_MONTHLY = "price_test_monthly_xxxxxxxxxxxxxxxxxxxx"
_PRICE_ANNUAL = "price_test_annual_xxxxxxxxxxxxxxxxxxxxx"


def _stripe_sig(body: bytes, secret: str, ts: int | None = None) -> str:
    """Produce a Stripe-Signature header matching the bytes signed.

    Format: `t=<ts>,v1=<hex hmac sha256>` over `<ts>.<body>`."""
    ts = ts if ts is not None else int(time.time())
    signed = f"{ts}.{body.decode('utf-8')}"
    mac = hmac.new(secret.encode("utf-8"), signed.encode("utf-8"),
                   hashlib.sha256).hexdigest()
    return f"t={ts},v1={mac}"


def _build_app(tmp_path):
    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

    from app import db as db_mod
    from app.auth import sign_session
    from app.config import get_settings
    from app.db import Base
    from app.models import User
    from app.routers import stripe_billing as stripe_router

    s = get_settings()
    s.STRIPE_API_KEY = _API_KEY  # type: ignore[misc]
    s.STRIPE_WEBHOOK_SECRET = _WEBHOOK_SECRET  # type: ignore[misc]
    s.STRIPE_PRICE_MONTHLY = _PRICE_MONTHLY  # type: ignore[misc]
    s.STRIPE_PRICE_ANNUAL = _PRICE_ANNUAL  # type: ignore[misc]

    engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/stripe.db")
    factory = async_sessionmaker(engine, expire_on_commit=False)
    db_mod._engine = engine
    db_mod._session_factory = factory

    async def _seed():
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        async with factory() as session:
            session.add(User(id=1, email="buyer@x", tier="free"))
            await session.commit()

    asyncio.run(_seed())

    app = FastAPI()
    app.include_router(stripe_router.router)
    return TestClient(app), factory, sign_session(1)


def _post_webhook(client, *, body: dict, secret: str = _WEBHOOK_SECRET,
                  sig: str | None = None):
    # Stripe's SDK requires a top-level `object: "event"` field to know
    # this is a v1 webhook envelope; tests that omit it fail in
    # construct_event before the signature check matters. We inject the
    # default here so individual tests can stay terse.
    body.setdefault("object", "event")
    raw = json.dumps(body).encode("utf-8")
    sig = sig if sig is not None else _stripe_sig(raw, secret)
    return client.post(
        "/api/stripe/webhook",
        content=raw,
        headers={"stripe-signature": sig, "content-type": "application/json"},
    )


# --- signature gate --------------------------------------------------------

def test_webhook_rejects_bad_signature(tmp_path):
    client, _, _ = _build_app(tmp_path)
    raw = json.dumps({"id": "evt_x", "type": "invoice.paid",
                      "data": {"object": {}}}).encode("utf-8")
    r = client.post(
        "/api/stripe/webhook",
        content=raw,
        headers={
            "stripe-signature": "t=0,v1=deadbeef",
            "content-type": "application/json",
        },
    )
    assert r.status_code == 401, r.text


def test_webhook_rejects_missing_signature(tmp_path):
    client, _, _ = _build_app(tmp_path)
    r = client.post(
        "/api/stripe/webhook",
        content=b"{}",
        headers={"content-type": "application/json"},
    )
    assert r.status_code == 400, r.text


# --- happy paths -----------------------------------------------------------

def test_checkout_session_completed_flips_tier_to_paid(tmp_path):
    client, factory, _ = _build_app(tmp_path)
    body = {
        "id": "evt_checkout_1",
        "type": "checkout.session.completed",
        "data": {
            "object": {
                "client_reference_id": "1",
                "customer": "cus_abc",
                "subscription": "sub_xyz",
            }
        },
    }
    r = _post_webhook(client, body=body)
    assert r.status_code == 200, r.text
    assert r.json()["status"] == "ok"

    async def _check():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            u = (await session.execute(
                select(User).where(User.id == 1)
            )).scalar_one()
            return u.tier, u.stripe_customer_id, u.stripe_subscription_id

    tier, cid, sid = asyncio.run(_check())
    assert tier == "paid"
    assert cid == "cus_abc"
    assert sid == "sub_xyz"


def test_subscription_deleted_drops_tier_to_free(tmp_path):
    client, factory, _ = _build_app(tmp_path)
    # First, activate.
    _post_webhook(client, body={
        "id": "evt_act",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_abc",
            "subscription": "sub_xyz",
        }},
    })
    # Then, delete the subscription.
    r = _post_webhook(client, body={
        "id": "evt_del",
        "type": "customer.subscription.deleted",
        "data": {"object": {
            "id": "sub_xyz",
            "customer": "cus_abc",
            "status": "canceled",
        }},
    })
    assert r.status_code == 200, r.text

    async def _check():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            u = (await session.execute(
                select(User).where(User.id == 1)
            )).scalar_one()
            return u.tier, u.stripe_customer_id, u.stripe_subscription_id

    tier, cid, sid = asyncio.run(_check())
    assert tier == "free"
    # Customer linkage preserved so a future resub matches this row.
    assert cid == "cus_abc"
    assert sid is None


def test_subscription_trialing_stores_trial_end(tmp_path):
    """customer.subscription.created with status=trialing + trial_end
    should grant paid AND persist the trial_end timestamp so the
    settings page can show 'N days remaining'.

    Realistic flow: checkout.session.completed fires first (linking
    customer_id to user.id via client_reference_id), then
    subscription.created fires moments later carrying trial_end."""
    import datetime as _dt

    client, factory, _ = _build_app(tmp_path)
    # First: link the user to the Stripe customer via checkout.
    _post_webhook(client, body={
        "id": "evt_link",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_trial",
            "subscription": "sub_trial",
        }},
    })
    # Then: the subscription event carrying trial_end (12 days out).
    trial_end_ts = int((_dt.datetime.now(_dt.timezone.utc)
                        + _dt.timedelta(days=12)).timestamp())
    r = _post_webhook(client, body={
        "id": "evt_trial",
        "type": "customer.subscription.created",
        "data": {"object": {
            "id": "sub_trial",
            "customer": "cus_trial",
            "status": "trialing",
            "trial_end": trial_end_ts,
        }},
    })
    assert r.status_code == 200, r.text

    async def _check():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            u = (await session.execute(
                select(User).where(User.id == 1)
            )).scalar_one()
            return u.tier, u.stripe_trial_end_at

    tier, end = asyncio.run(_check())
    assert tier == "paid", "trial users must have paid features"
    assert end is not None
    # Stored value should match the trial_end we sent (within two seconds).
    expected = _dt.datetime.fromtimestamp(trial_end_ts, tz=_dt.timezone.utc)
    # SQLite may hand back a naive datetime; treat it as UTC.
    if end.tzinfo is None:
        end = end.replace(tzinfo=_dt.timezone.utc)
    assert abs((end - expected).total_seconds()) < 2


def test_subscription_active_clears_trial_end(tmp_path):
    """When the subscription transitions trialing -> active (day 15),
    the trial_end marker should be cleared so settings stops showing
    the 'N days remaining' trial countdown."""
    import datetime as _dt

    client, factory, _ = _build_app(tmp_path)
    # Link the customer first via checkout, then plant a trial state.
    _post_webhook(client, body={
        "id": "evt_link2",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_t",
            "subscription": "sub_t",
        }},
    })
    trial_end_ts = int((_dt.datetime.now(_dt.timezone.utc)
                        + _dt.timedelta(days=12)).timestamp())
    _post_webhook(client, body={
        "id": "evt_t1",
        "type": "customer.subscription.created",
        "data": {"object": {
            "id": "sub_t", "customer": "cus_t",
            "status": "trialing", "trial_end": trial_end_ts,
        }},
    })
    # Now transition to active.
    _post_webhook(client, body={
        "id": "evt_t2",
        "type": "customer.subscription.updated",
        "data": {"object": {
            "id": "sub_t", "customer": "cus_t",
            "status": "active",
        }},
    })

    async def _check():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            u = (await session.execute(
                select(User).where(User.id == 1)
            )).scalar_one()
            return u.tier, u.stripe_trial_end_at

    tier, end = asyncio.run(_check())
    assert tier == "paid"
    assert end is None, "trial_end_at must be cleared once active"


def test_subscription_active_grants_paid(tmp_path):
    """customer.subscription.updated with status=active should also
    grant paid. This covers the case where checkout.session.completed
    arrives after subscription.created: either order should work."""
    client, factory, _ = _build_app(tmp_path)
    # Seed the linkage first via checkout (so customer_id is known).
    _post_webhook(client, body={
        "id": "evt_ck",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_abc",
            "subscription": "sub_xyz",
        }},
    })

    # Drop to free manually so we can prove the updated event re-grants.
    async def _reset():
        from sqlalchemy import update
        from app.models import User
        async with factory() as session:
            await session.execute(
                update(User).where(User.id == 1).values(tier="free")
            )
            await session.commit()

    asyncio.run(_reset())
    r = _post_webhook(client, body={
        "id": "evt_upd",
        "type": "customer.subscription.updated",
        "data": {"object": {
            "id": "sub_xyz",
            "customer": "cus_abc",
            "status": "active",
        }},
    })
    assert r.status_code == 200

    async def _check_tier():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            return (await session.execute(
                select(User.tier).where(User.id == 1)
            )).scalar_one()

    assert asyncio.run(_check_tier()) == "paid"


# --- idempotency + unknown ------------------------------------------------

def test_replayed_event_id_is_a_noop(tmp_path):
    client, factory, _ = _build_app(tmp_path)
    body = {
        "id": "evt_dup",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_abc",
            "subscription": "sub_xyz",
        }},
    }
    r1 = _post_webhook(client, body=body)
    r2 = _post_webhook(client, body=body)
    assert r1.json()["status"] == "ok"
    assert r2.json()["status"] == "duplicate"

    async def _count_rows():
        from sqlalchemy import select, func
        from app.models import StripeEvent
        async with factory() as session:
            n = (await session.execute(
                select(func.count(StripeEvent.id))
                .where(StripeEvent.event_id == "evt_dup")
            )).scalar_one()
            return n

    assert asyncio.run(_count_rows()) == 1


def test_unknown_event_is_acked(tmp_path):
    client, _, _ = _build_app(tmp_path)
    r = _post_webhook(client, body={
        "id": "evt_unknown",
        "type": "product.something.new",
        "data": {"object": {}},
    })
    assert r.status_code == 200
    assert r.json()["status"] == "ignored"


# --- /api/stripe/checkout (with Stripe SDK mocked) ------------------------

def _fake_checkout_client(asserter):
    """Build a fake Stripe client whose checkout.sessions.create calls
    the supplied asserter on the params dict and returns a stub session
    carrying a fixed test URL."""
    fake_session = SimpleNamespace(
        id="cs_test_123", url="https://checkout.stripe.com/test",
    )

    class _FakeSessions:
        @staticmethod
        def create(params):  # noqa: ANN001
            asserter(params)
            return fake_session

    class _FakeCheckout:
        sessions = _FakeSessions()

    class _FakeClient:
        checkout = _FakeCheckout()

    return _FakeClient()


def test_checkout_monthly_has_no_trial_and_no_stripe_consent(tmp_path):
    """Monthly checkout must NOT carry a free trial (a 14-day trial on the
    £7 monthly plan would give away roughly half of cycle one) AND must
    NOT use Stripe's account-wide consent_collection: the Reg-36 waiver
    is collected on /pricing so each product can use its own Terms URL."""
    client, _, session_cookie = _build_app(tmp_path)

    def asserter(params):
        assert params["mode"] == "subscription"
        assert params["line_items"][0]["price"] == _PRICE_MONTHLY
        assert params["client_reference_id"] == "1"
        assert params["customer_email"] == "buyer@x"
        assert "subscription_data" not in params, "no trial on monthly"
        assert "consent_collection" not in params, (
            "consent is collected on /pricing, not via Stripe's account-wide setting"
        )

    with patch("app.routers.stripe_billing._stripe_client",
               return_value=_fake_checkout_client(asserter)):
        r = client.post(
            "/api/stripe/checkout",
            json={"cadence": "monthly"},
            cookies={"cassandra_session": session_cookie},
        )
    assert r.status_code == 200, r.text
    assert r.json()["url"] == "https://checkout.stripe.com/test"


def test_checkout_annual_uses_trial_not_consent_collection(tmp_path):
    """Annual checkout gets the 14-day free trial (substitutes for the
    statutory cooling-off right; no money moves during the trial)."""
    client, _, session_cookie = _build_app(tmp_path)

    def asserter(params):
        assert params["mode"] == "subscription"
        assert params["line_items"][0]["price"] == _PRICE_ANNUAL
        assert params["subscription_data"]["trial_period_days"] == 14
        assert "consent_collection" not in params, "annual relies on trial, not consent"

    with patch("app.routers.stripe_billing._stripe_client",
               return_value=_fake_checkout_client(asserter)):
        r = client.post(
            "/api/stripe/checkout",
            json={"cadence": "annual"},
            cookies={"cassandra_session": session_cookie},
        )
    assert r.status_code == 200, r.text
    assert r.json()["url"] == "https://checkout.stripe.com/test"


def test_checkout_endpoint_requires_login(tmp_path):
    client, _, _ = _build_app(tmp_path)
    r = client.post("/api/stripe/checkout", json={"cadence": "monthly"})
    # No session cookie → require_auth bounces with 401.
    assert r.status_code == 401, r.text
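`_stripe_sig` produces the header on the test side. On the receiving side, the check is normally left to the Stripe SDK. For reference, here is a stdlib-only sketch of Stripe's documented manual verification:

- split the header on commas;
- rebuild the signed payload as `<t>.<raw body>`;
- HMAC-SHA256 it with the endpoint secret;
- compare the result against every `v1` entry in constant time;
- reject timestamps outside a tolerance window (300 seconds is the SDK's default).

`verify_stripe_signature` is a hypothetical name and is not the router's code.

```python
from __future__ import annotations

import hashlib
import hmac
import time


def verify_stripe_signature(payload: bytes, header: str, secret: str,
                            tolerance: int = 300, now: int | None = None) -> bool:
    """Check a Stripe-Signature header against the raw request body."""
    timestamp = None
    candidates = []
    for part in header.split(","):
        key, sep, value = part.strip().partition("=")
        if not sep:
            continue  # malformed element, ignore it
        if key == "t":
            timestamp = value
        elif key == "v1":
            candidates.append(value)  # several v1 entries may be present
    if timestamp is None or not timestamp.isdigit() or not candidates:
        return False
    now = int(time.time()) if now is None else now
    if abs(now - int(timestamp)) > tolerance:
        return False  # outside the replay window
    signed = timestamp.encode("utf-8") + b"." + payload
    expected = hmac.new(secret.encode("utf-8"), signed, hashlib.sha256).hexdigest()
    # Compare as bytes so a non-ASCII header value cannot raise.
    return any(hmac.compare_digest(expected.encode("utf-8"), c.encode("utf-8"))
               for c in candidates)
```

The check runs over the raw bytes. Re-serializing parsed JSON could change the payload and break the HMAC, which is why `_post_webhook` signs exactly the bytes it sends.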