read.markets/tests/test_stripe_billing.py
Giorgio Gilestro 83995e96c8 stripe: detect buyer currency at checkout (GBP/USD/EUR)
Pass `currency` to Stripe checkout for first-time buyers so Stripe
picks the matching `currency_options` rate configured on the Price
in the Dashboard (multi-currency Prices: one Price, per-currency
unit_amount). Operator configures the rates on existing Prices
prod_UaZ0xCpCboUGCN/price_*; this commit is the application-side
signal.

Currency precedence: explicit request body > Cloudflare cf-ipcountry
header > Accept-Language locale > GBP fallback. Only honoured when
the user has no stripe_customer_id yet — Stripe locks currency to
the customer record at first checkout, so existing customers keep
their original currency (they can switch via the portal).

Adds 4 tests: sniffed currency for a new customer, body override beats
sniff, currency omitted for an existing customer, and a unit test for
the sniffing fallback chain.
2026-05-28 12:42:40 +02:00

559 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Stripe billing endpoints: signature verification, idempotency, tier
flips, and checkout creation.
Same integration-style scaffold as test_polar_webhook.py — real router
over a per-test, file-backed aiosqlite database (under tmp_path). Stripe
SDK calls (sessions.create, portal sessions.create) are mocked so the
suite never makes a real HTTP call.
"""
from __future__ import annotations
import asyncio
import hashlib
import hmac
import json
import time
from types import SimpleNamespace
from unittest.mock import patch
import pytest
# Placeholder Stripe credentials and Price IDs. _build_app writes them onto
# the app settings, and the checkout tests assert the router forwards the
# matching Price ID. None of these are real Stripe values.
_API_KEY = "sk_test_dummy_for_unit_tests"
# Secret used both to configure the app and to sign test webhook payloads.
_WEBHOOK_SECRET = "whsec_dummy_test_secret_for_unit_tests"
_PRICE_MONTHLY = "price_test_monthly_xxxxxxxxxxxxxxxxxxxx"
_PRICE_ANNUAL = "price_test_annual_xxxxxxxxxxxxxxxxxxxxx"
def _stripe_sig(body: bytes, secret: str, ts: int | None = None) -> str:
    """Produce a Stripe-Signature header matching the bytes signed.

    Format: ``t=<ts>,v1=<hex hmac sha256>`` computed over
    ``b"<ts>." + body``. The body is signed as raw bytes (which is what
    Stripe signs) instead of being round-tripped through ``str``, so
    payloads that are not valid UTF-8 can still be signed.

    Args:
        body: Exact request bytes that will be POSTed.
        secret: Webhook signing secret.
        ts: Unix timestamp to embed; defaults to the current time.

    Returns:
        The header value, e.g. ``"t=1700000000,v1=ab12..."``.
    """
    if ts is None:
        ts = int(time.time())
    # Prefix is ASCII-only, so encoding it cannot fail; the body is used
    # verbatim to match the bytes actually sent on the wire.
    signed = f"{ts}.".encode("utf-8") + body
    mac = hmac.new(secret.encode("utf-8"), signed, hashlib.sha256).hexdigest()
    return f"t={ts},v1={mac}"
def _build_app(tmp_path):
    """Wire the real Stripe router to a throwaway SQLite database.

    Writes the dummy Stripe credentials onto the object returned by
    ``get_settings()``. It also replaces ``app.db``'s module-level engine
    and session factory with a file-backed aiosqlite database under
    ``tmp_path``. Finally it creates the schema and seeds one free-tier
    user (id=1, email "buyer@x").

    NOTE(review): neither the settings overrides nor the db globals are
    restored afterwards; presumably harmless because every test rebuilds
    them, but confirm if other modules share the settings object.

    Returns:
        ``(TestClient, session_factory, signed_session_for_user_1)``. The
        last item is what tests send as the ``cassandra_session`` cookie.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
    from app import db as db_mod
    from app.auth import sign_session
    from app.config import get_settings
    from app.db import Base
    from app.models import User
    from app.routers import stripe_billing as stripe_router

    settings = get_settings()
    overrides = {
        "STRIPE_API_KEY": _API_KEY,
        "STRIPE_WEBHOOK_SECRET": _WEBHOOK_SECRET,
        "STRIPE_PRICE_MONTHLY": _PRICE_MONTHLY,
        "STRIPE_PRICE_ANNUAL": _PRICE_ANNUAL,
    }
    for attr_name, value in overrides.items():
        setattr(settings, attr_name, value)

    db_url = f"sqlite+aiosqlite:///{tmp_path}/stripe.db"
    engine = create_async_engine(db_url)
    session_factory = async_sessionmaker(engine, expire_on_commit=False)
    # Point the app's own DB accessors at this per-test database.
    db_mod._engine = engine
    db_mod._session_factory = session_factory

    async def _create_schema_and_seed():
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        async with session_factory() as session:
            session.add(User(id=1, email="buyer@x", tier="free"))
            await session.commit()

    asyncio.run(_create_schema_and_seed())

    api = FastAPI()
    api.include_router(stripe_router.router)
    test_client = TestClient(api)
    return test_client, session_factory, sign_session(1)
def _post_webhook(client, *, body: dict, secret: str = _WEBHOOK_SECRET,
                  sig: str | None = None):
    """Serialize ``body``, sign it and POST it to the Stripe webhook route.

    Args:
        client: TestClient wrapping the app built by ``_build_app``.
        body: Event payload. It is copied, never mutated, so callers can
            safely reuse the same dict (e.g. to replay an event).
        secret: Secret used to sign the payload; defaults to the secret
            the app is configured with.
        sig: Explicit ``stripe-signature`` header value. When omitted, a
            valid signature over the serialized payload is generated.

    Returns:
        The TestClient response.
    """
    # Stripe's SDK requires a top-level `object: "event"` field to know
    # this is a v1 webhook envelope — tests that omit it fail in
    # construct_event before the signature check matters. We inject the
    # default here so individual tests can stay terse. A shallow copy keeps
    # the caller's dict untouched while preserving key order, so the
    # serialized bytes match what setdefault on the original would produce.
    payload = dict(body)
    payload.setdefault("object", "event")
    raw = json.dumps(payload).encode("utf-8")
    sig = sig if sig is not None else _stripe_sig(raw, secret)
    return client.post(
        "/api/stripe/webhook",
        content=raw,
        headers={"stripe-signature": sig, "content-type": "application/json"},
    )
# --- signature gate --------------------------------------------------------
def test_webhook_rejects_bad_signature(tmp_path):
    """Payloads whose signature does not verify are rejected with 401.

    Two cases are covered:
    - a well-formed, freshly timestamped signature computed with the
      wrong secret, so only the HMAC itself can be at fault;
    - a garbage header with a stale timestamp.

    The payload carries the `object: "event"` envelope field. Without it,
    Stripe's construct_event can fail before the signature is considered,
    and the test would not prove that signatures are checked.
    """
    client, _, _ = _build_app(tmp_path)
    raw = json.dumps({"id": "evt_x", "type": "invoice.paid",
                      "data": {"object": {}},
                      "object": "event"}).encode("utf-8")
    bad_signatures = (
        _stripe_sig(raw, "whsec_wrong_secret_for_unit_tests"),
        "t=0,v1=deadbeef",
    )
    for bad_sig in bad_signatures:
        r = client.post(
            "/api/stripe/webhook",
            content=raw,
            headers={
                "stripe-signature": bad_sig,
                "content-type": "application/json",
            },
        )
        assert r.status_code == 401, (bad_sig, r.text)
def test_webhook_rejects_missing_signature(tmp_path):
    """A webhook POST without any stripe-signature header gets a 400."""
    client = _build_app(tmp_path)[0]
    unsigned_headers = {"content-type": "application/json"}
    response = client.post(
        "/api/stripe/webhook",
        content=b"{}",
        headers=unsigned_headers,
    )
    assert response.status_code == 400, response.text
# --- happy paths -----------------------------------------------------------
def test_checkout_session_completed_flips_tier_to_paid(tmp_path):
    """checkout.session.completed links the Stripe customer and
    subscription to the user named by client_reference_id and moves
    that user to the paid tier."""
    client, factory, _ = _build_app(tmp_path)
    session_obj = {
        "client_reference_id": "1",
        "customer": "cus_abc",
        "subscription": "sub_xyz",
    }
    event = {
        "id": "evt_checkout_1",
        "type": "checkout.session.completed",
        "data": {"object": session_obj},
    }

    response = _post_webhook(client, body=event)
    assert response.status_code == 200, response.text
    assert response.json()["status"] == "ok"

    async def _load_billing_columns():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            result = await session.execute(select(User).where(User.id == 1))
            user = result.scalar_one()
            return (user.tier, user.stripe_customer_id,
                    user.stripe_subscription_id)

    tier, customer_id, subscription_id = asyncio.run(_load_billing_columns())
    assert tier == "paid"
    assert customer_id == "cus_abc"
    assert subscription_id == "sub_xyz"
def test_subscription_deleted_drops_tier_to_free(tmp_path):
    """customer.subscription.deleted demotes the user to free and clears
    the subscription id, but keeps the customer id."""
    client, factory, _ = _build_app(tmp_path)

    activation = {
        "id": "evt_act",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_abc",
            "subscription": "sub_xyz",
        }},
    }
    deletion = {
        "id": "evt_del",
        "type": "customer.subscription.deleted",
        "data": {"object": {
            "id": "sub_xyz",
            "customer": "cus_abc",
            "status": "canceled",
        }},
    }

    # Activate first, then cancel the same subscription.
    _post_webhook(client, body=activation)
    response = _post_webhook(client, body=deletion)
    assert response.status_code == 200, response.text

    async def _load_billing_columns():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            result = await session.execute(select(User).where(User.id == 1))
            user = result.scalar_one()
            return (user.tier, user.stripe_customer_id,
                    user.stripe_subscription_id)

    tier, customer_id, subscription_id = asyncio.run(_load_billing_columns())
    assert tier == "free"
    # Customer linkage preserved so a future resub matches this row.
    assert customer_id == "cus_abc"
    assert subscription_id is None
def test_subscription_trialing_stores_trial_end(tmp_path):
    """A trialing customer.subscription.created event grants paid and
    persists trial_end.

    The stored timestamp lets the settings page show 'N days remaining'.
    The flow mirrors production ordering: checkout.session.completed first
    links customer_id to user.id via client_reference_id, then
    subscription.created follows carrying trial_end.
    """
    import datetime as _dt
    utc = _dt.timezone.utc
    client, factory, _ = _build_app(tmp_path)

    # Link the user to the Stripe customer via checkout.
    _post_webhook(client, body={
        "id": "evt_link",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_trial",
            "subscription": "sub_trial",
        }},
    })

    # Subscription event carrying a trial_end 12 days in the future.
    twelve_days_out = _dt.datetime.now(utc) + _dt.timedelta(days=12)
    trial_end_ts = int(twelve_days_out.timestamp())
    trial_event = {
        "id": "evt_trial",
        "type": "customer.subscription.created",
        "data": {"object": {
            "id": "sub_trial",
            "customer": "cus_trial",
            "status": "trialing",
            "trial_end": trial_end_ts,
        }},
    }
    response = _post_webhook(client, body=trial_event)
    assert response.status_code == 200, response.text

    async def _load_trial_state():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            result = await session.execute(select(User).where(User.id == 1))
            user = result.scalar_one()
            return user.tier, user.stripe_trial_end_at

    tier, stored_end = asyncio.run(_load_trial_state())
    assert tier == "paid", "trial users must have paid features"
    assert stored_end is not None

    # SQLite may hand back a naive datetime; treat it as UTC before
    # comparing with the timestamp we sent (tolerance: under 2 seconds).
    if stored_end.tzinfo is None:
        stored_end = stored_end.replace(tzinfo=utc)
    sent_end = _dt.datetime.fromtimestamp(trial_end_ts, tz=utc)
    assert abs((stored_end - sent_end).total_seconds()) < 2
def test_subscription_active_clears_trial_end(tmp_path):
    """When the subscription transitions trialing -> active (day 15),
    the trial_end marker should be cleared so settings stops showing
    'trial — N days remaining'.

    The trial state is checked before the transition. Otherwise a handler
    that never stored trial_end would make the final "cleared" assertion
    pass vacuously.
    """
    import datetime as _dt
    client, factory, _ = _build_app(tmp_path)

    async def _load_trial_state():
        from sqlalchemy import select
        from app.models import User
        async with factory() as session:
            u = (await session.execute(
                select(User).where(User.id == 1)
            )).scalar_one()
            return u.tier, u.stripe_trial_end_at

    # Link the customer first via checkout, then plant a trial state.
    r = _post_webhook(client, body={
        "id": "evt_link2",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_t",
            "subscription": "sub_t",
        }},
    })
    assert r.status_code == 200, r.text

    trial_end_ts = int((_dt.datetime.now(_dt.timezone.utc)
                        + _dt.timedelta(days=12)).timestamp())
    r = _post_webhook(client, body={
        "id": "evt_t1",
        "type": "customer.subscription.created",
        "data": {"object": {
            "id": "sub_t", "customer": "cus_t",
            "status": "trialing", "trial_end": trial_end_ts,
        }},
    })
    assert r.status_code == 200, r.text
    _, trial_end_before = asyncio.run(_load_trial_state())
    assert trial_end_before is not None, (
        "precondition: trial_end_at should be set while trialing"
    )

    # Now transition to active.
    r = _post_webhook(client, body={
        "id": "evt_t2",
        "type": "customer.subscription.updated",
        "data": {"object": {
            "id": "sub_t", "customer": "cus_t",
            "status": "active",
        }},
    })
    assert r.status_code == 200, r.text

    tier, end = asyncio.run(_load_trial_state())
    assert tier == "paid"
    assert end is None, "trial_end_at must be cleared once active"
def test_subscription_active_grants_paid(tmp_path):
    """An active customer.subscription.updated event grants paid on its
    own.

    This covers the case where checkout.session.completed arrives after
    subscription.created; either event must be able to grant access.
    """
    client, factory, _ = _build_app(tmp_path)
    from sqlalchemy import select, update
    from app.models import User

    # Establish the customer linkage via checkout so customer_id is known.
    _post_webhook(client, body={
        "id": "evt_ck",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_abc",
            "subscription": "sub_xyz",
        }},
    })

    async def _force_free():
        # Manually demote so the update event has to re-grant paid.
        async with factory() as session:
            await session.execute(
                update(User).where(User.id == 1).values(tier="free")
            )
            await session.commit()

    async def _current_tier():
        async with factory() as session:
            result = await session.execute(
                select(User.tier).where(User.id == 1)
            )
            return result.scalar_one()

    asyncio.run(_force_free())

    update_event = {
        "id": "evt_upd",
        "type": "customer.subscription.updated",
        "data": {"object": {
            "id": "sub_xyz",
            "customer": "cus_abc",
            "status": "active",
        }},
    }
    response = _post_webhook(client, body=update_event)
    assert response.status_code == 200
    assert asyncio.run(_current_tier()) == "paid"
# --- idempotency + unknown ------------------------------------------------
def test_replayed_event_id_is_a_noop(tmp_path):
    """Posting the same event id twice processes it once: the replay is
    reported as a duplicate and only one StripeEvent row exists."""
    client, factory, _ = _build_app(tmp_path)
    event = {
        "id": "evt_dup",
        "type": "checkout.session.completed",
        "data": {"object": {
            "client_reference_id": "1",
            "customer": "cus_abc",
            "subscription": "sub_xyz",
        }},
    }

    first, replay = [_post_webhook(client, body=event) for _ in range(2)]
    assert first.json()["status"] == "ok"
    assert replay.json()["status"] == "duplicate"

    async def _stored_event_count():
        from sqlalchemy import select, func
        from app.models import StripeEvent
        query = (
            select(func.count(StripeEvent.id))
            .where(StripeEvent.event_id == "evt_dup")
        )
        async with factory() as session:
            return (await session.execute(query)).scalar_one()

    assert asyncio.run(_stored_event_count()) == 1
def test_unknown_event_is_acked(tmp_path):
    """Unrecognised event types are acknowledged (200) but reported as
    'ignored'."""
    client = _build_app(tmp_path)[0]
    unknown_event = {
        "id": "evt_unknown",
        "type": "product.something.new",
        "data": {"object": {}},
    }
    response = _post_webhook(client, body=unknown_event)
    assert response.status_code == 200
    assert response.json()["status"] == "ignored"
# --- /api/stripe/checkout (with Stripe SDK mocked) ------------------------
def _fake_checkout_client(asserter):
    """Return a stand-in for the Stripe client used by the checkout route.

    It only provides ``client.checkout.sessions.create(params)``. That call
    hands ``params`` to ``asserter`` (which may raise to fail the test) and
    then returns the same stub session every time, with a fixed id and URL.
    """
    stub_session = SimpleNamespace(
        id="cs_test_123", url="https://checkout.stripe.com/test",
    )

    def _create(params):  # noqa: ANN001
        asserter(params)
        return stub_session

    sessions_api = SimpleNamespace(create=_create)
    return SimpleNamespace(checkout=SimpleNamespace(sessions=sessions_api))
def test_checkout_monthly_has_no_trial_and_no_stripe_consent(tmp_path):
    """Monthly checkout sends neither a free trial nor Stripe's
    account-wide consent_collection.

    A trial on the £7 monthly plan would halve cycle-1 revenue. The
    Reg-36 waiver is gathered on /pricing, so each product can point at
    its own Terms URL.
    """
    client, _, session_cookie = _build_app(tmp_path)

    def check_params(params):
        assert params["mode"] == "subscription"
        assert params["line_items"][0]["price"] == _PRICE_MONTHLY
        assert params["client_reference_id"] == "1"
        assert params["customer_email"] == "buyer@x"
        assert "subscription_data" not in params, "no trial on monthly"
        assert "consent_collection" not in params, (
            "consent is collected on /pricing, not via Stripe's account-wide setting"
        )

    fake_stripe = _fake_checkout_client(check_params)
    with patch("app.routers.stripe_billing._stripe_client",
               return_value=fake_stripe):
        response = client.post(
            "/api/stripe/checkout",
            json={"cadence": "monthly"},
            cookies={"cassandra_session": session_cookie},
        )
    assert response.status_code == 200, response.text
    assert response.json()["url"] == "https://checkout.stripe.com/test"
def test_checkout_annual_uses_trial_not_consent_collection(tmp_path):
    """Annual checkout carries a 14-day free trial and no
    consent_collection.

    The trial stands in for the statutory cooling-off right, since no
    money moves during it.
    """
    client, _, session_cookie = _build_app(tmp_path)

    def check_params(params):
        assert params["mode"] == "subscription"
        assert params["line_items"][0]["price"] == _PRICE_ANNUAL
        assert params["subscription_data"]["trial_period_days"] == 14
        assert "consent_collection" not in params, "annual relies on trial, not consent"

    fake_stripe = _fake_checkout_client(check_params)
    with patch("app.routers.stripe_billing._stripe_client",
               return_value=fake_stripe):
        response = client.post(
            "/api/stripe/checkout",
            json={"cadence": "annual"},
            cookies={"cassandra_session": session_cookie},
        )
    assert response.status_code == 200, response.text
    assert response.json()["url"] == "https://checkout.stripe.com/test"
def test_checkout_endpoint_requires_login(tmp_path):
    """Checkout without a session cookie is refused with 401."""
    client = _build_app(tmp_path)[0]
    # No session cookie → require_auth bounces with 401.
    response = client.post(
        "/api/stripe/checkout", json={"cadence": "monthly"},
    )
    assert response.status_code == 401, response.text
def test_checkout_passes_sniffed_currency_for_new_customer(tmp_path):
    """First-time buyer (no stripe_customer_id yet) gets the currency
    sniffed from the request. CF-IPCountry=US → 'usd', and Stripe will
    look up the USD currency_option on the Price.

    The asserter only runs if the route actually calls
    checkout.sessions.create, so the call is recorded and checked to
    prevent a vacuous pass.
    """
    client, _, session_cookie = _build_app(tmp_path)
    seen_params = []

    def asserter(params):
        seen_params.append(params)
        assert params["currency"] == "usd"

    with patch("app.routers.stripe_billing._stripe_client",
               return_value=_fake_checkout_client(asserter)):
        r = client.post(
            "/api/stripe/checkout",
            json={"cadence": "monthly"},
            cookies={"cassandra_session": session_cookie},
            headers={"cf-ipcountry": "US"},
        )
    assert r.status_code == 200, r.text
    assert len(seen_params) == 1, "checkout.sessions.create was not called"
    assert r.json()["url"] == "https://checkout.stripe.com/test"
def test_checkout_body_currency_overrides_sniff(tmp_path):
    """Explicit `currency` in the request body beats header sniffing —
    lets a UK-based buyer choose EUR if they want to.

    The call to checkout.sessions.create is recorded so the test cannot
    pass without the asserter having run.
    """
    client, _, session_cookie = _build_app(tmp_path)
    seen_params = []

    def asserter(params):
        seen_params.append(params)
        assert params["currency"] == "eur"

    with patch("app.routers.stripe_billing._stripe_client",
               return_value=_fake_checkout_client(asserter)):
        r = client.post(
            "/api/stripe/checkout",
            json={"cadence": "monthly", "currency": "eur"},
            cookies={"cassandra_session": session_cookie},
            headers={"cf-ipcountry": "GB"},
        )
    assert r.status_code == 200, r.text
    assert len(seen_params) == 1, "checkout.sessions.create was not called"
    assert r.json()["url"] == "https://checkout.stripe.com/test"
def test_checkout_omits_currency_for_existing_customer(tmp_path):
    """Existing customer: Stripe locked their currency at first
    checkout, so passing `currency` again would error. Verify we omit
    it (and also use the existing `customer` ref instead of
    customer_email), even when the body and headers both ask for USD.

    The call to checkout.sessions.create is recorded so the test cannot
    pass without the asserter having run.
    """
    from app.models import User
    client, factory, session_cookie = _build_app(tmp_path)

    async def _link():
        async with factory() as s:
            u = await s.get(User, 1)
            u.stripe_customer_id = "cus_existing_xxxxxxxxxxxxxx"
            await s.commit()

    # Uses the module-level asyncio import (the former local re-import
    # was redundant).
    asyncio.run(_link())
    seen_params = []

    def asserter(params):
        seen_params.append(params)
        assert "currency" not in params, (
            "currency must not be passed once a customer exists — "
            "Stripe rejects mismatches against the locked customer currency"
        )
        assert params["customer"] == "cus_existing_xxxxxxxxxxxxxx"

    with patch("app.routers.stripe_billing._stripe_client",
               return_value=_fake_checkout_client(asserter)):
        r = client.post(
            "/api/stripe/checkout",
            json={"cadence": "monthly", "currency": "usd"},
            cookies={"cassandra_session": session_cookie},
            headers={"cf-ipcountry": "US"},
        )
    assert r.status_code == 200, r.text
    assert len(seen_params) == 1, "checkout.sessions.create was not called"
    assert r.json()["url"] == "https://checkout.stripe.com/test"
def test_sniff_currency_fallback_chain():
    """Unit-test the header-sniffing helper: CF country wins, then
    Accept-Language exact, then language-only, then GBP default.

    Includes a case with both headers present, so the claimed precedence
    (CF-IPCountry over Accept-Language) is actually exercised rather than
    only each header in isolation.
    """
    from app.routers.stripe_billing import _sniff_currency

    def _req(headers):
        # Minimal stand-in: _sniff_currency is only expected to read
        # `.headers`, looked up with lowercase keys here.
        return SimpleNamespace(headers=headers)

    assert _sniff_currency(_req({"cf-ipcountry": "DE"})) == "eur"
    assert _sniff_currency(_req({"cf-ipcountry": "us"})) == "usd"  # case-insensitive
    assert _sniff_currency(_req({"accept-language": "fr-FR,fr;q=0.9"})) == "eur"
    assert _sniff_currency(_req({"accept-language": "en-US,en;q=0.5"})) == "usd"
    assert _sniff_currency(_req({"accept-language": "ja,ja-JP;q=0.5"})) == "gbp"
    assert _sniff_currency(_req({})) == "gbp"
    # Precedence: the Cloudflare country beats a conflicting locale.
    assert _sniff_currency(_req({
        "cf-ipcountry": "US",
        "accept-language": "fr-FR,fr;q=0.9",
    })) == "usd"