"""Stripe billing endpoints: signature verification, idempotency, tier flips, and checkout creation. Same integration-style scaffold as test_polar_webhook.py — real router over in-memory aiosqlite. Stripe SDK calls (sessions.create, portal sessions.create) are mocked so the suite never makes a real HTTP call. """ from __future__ import annotations import asyncio import hashlib import hmac import json import time from types import SimpleNamespace from unittest.mock import patch import pytest _API_KEY = "sk_test_dummy_for_unit_tests" _WEBHOOK_SECRET = "whsec_dummy_test_secret_for_unit_tests" _PRICE_MONTHLY = "price_test_monthly_xxxxxxxxxxxxxxxxxxxx" _PRICE_ANNUAL = "price_test_annual_xxxxxxxxxxxxxxxxxxxxx" def _stripe_sig(body: bytes, secret: str, ts: int | None = None) -> str: """Produce a Stripe-Signature header matching the bytes signed. Format: `t=,v1=` over `.`.""" ts = ts if ts is not None else int(time.time()) signed = f"{ts}.{body.decode('utf-8')}" mac = hmac.new(secret.encode("utf-8"), signed.encode("utf-8"), hashlib.sha256).hexdigest() return f"t={ts},v1={mac}" def _build_app(tmp_path): from fastapi import FastAPI from fastapi.testclient import TestClient from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from app import db as db_mod from app.auth import sign_session from app.config import get_settings from app.db import Base from app.models import User from app.routers import stripe_billing as stripe_router s = get_settings() s.STRIPE_API_KEY = _API_KEY # type: ignore[misc] s.STRIPE_WEBHOOK_SECRET = _WEBHOOK_SECRET # type: ignore[misc] s.STRIPE_PRICE_MONTHLY = _PRICE_MONTHLY # type: ignore[misc] s.STRIPE_PRICE_ANNUAL = _PRICE_ANNUAL # type: ignore[misc] engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/stripe.db") factory = async_sessionmaker(engine, expire_on_commit=False) db_mod._engine = engine db_mod._session_factory = factory async def _seed(): async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) async with factory() as session: session.add(User(id=1, email="buyer@x", tier="free")) await session.commit() asyncio.run(_seed()) app = FastAPI() app.include_router(stripe_router.router) return TestClient(app), factory, sign_session(1) def _post_webhook(client, *, body: dict, secret: str = _WEBHOOK_SECRET, sig: str | None = None): # Stripe's SDK requires a top-level `object: "event"` field to know # this is a v1 webhook envelope — tests that omit it fail in # construct_event before the signature check matters. We inject the # default here so individual tests can stay terse. body.setdefault("object", "event") raw = json.dumps(body).encode("utf-8") sig = sig if sig is not None else _stripe_sig(raw, secret) return client.post( "/api/stripe/webhook", content=raw, headers={"stripe-signature": sig, "content-type": "application/json"}, ) # --- signature gate -------------------------------------------------------- def test_webhook_rejects_bad_signature(tmp_path): client, _, _ = _build_app(tmp_path) raw = json.dumps({"id": "evt_x", "type": "invoice.paid", "data": {"object": {}}}).encode("utf-8") r = client.post( "/api/stripe/webhook", content=raw, headers={ "stripe-signature": "t=0,v1=deadbeef", "content-type": "application/json", }, ) assert r.status_code == 401, r.text def test_webhook_rejects_missing_signature(tmp_path): client, _, _ = _build_app(tmp_path) r = client.post( "/api/stripe/webhook", content=b"{}", headers={"content-type": "application/json"}, ) assert r.status_code == 400, r.text # --- happy paths ----------------------------------------------------------- def test_checkout_session_completed_flips_tier_to_paid(tmp_path): client, factory, _ = _build_app(tmp_path) body = { "id": "evt_checkout_1", "type": "checkout.session.completed", "data": { "object": { "client_reference_id": "1", "customer": "cus_abc", "subscription": "sub_xyz", } }, } r = _post_webhook(client, body=body) assert r.status_code == 200, r.text assert r.json()["status"] == "ok" async def _check(): from sqlalchemy import select from app.models import User async with factory() as session: u = (await session.execute( select(User).where(User.id == 1) )).scalar_one() return u.tier, u.stripe_customer_id, u.stripe_subscription_id tier, cid, sid = asyncio.run(_check()) assert tier == "paid" assert cid == "cus_abc" assert sid == "sub_xyz" def test_subscription_deleted_drops_tier_to_free(tmp_path): client, factory, _ = _build_app(tmp_path) # First, activate. _post_webhook(client, body={ "id": "evt_act", "type": "checkout.session.completed", "data": {"object": { "client_reference_id": "1", "customer": "cus_abc", "subscription": "sub_xyz", }}, }) # Then, delete the subscription. r = _post_webhook(client, body={ "id": "evt_del", "type": "customer.subscription.deleted", "data": {"object": { "id": "sub_xyz", "customer": "cus_abc", "status": "canceled", }}, }) assert r.status_code == 200, r.text async def _check(): from sqlalchemy import select from app.models import User async with factory() as session: u = (await session.execute( select(User).where(User.id == 1) )).scalar_one() return u.tier, u.stripe_customer_id, u.stripe_subscription_id tier, cid, sid = asyncio.run(_check()) assert tier == "free" # Customer linkage preserved so a future resub matches this row. assert cid == "cus_abc" assert sid is None def test_subscription_trialing_stores_trial_end(tmp_path): """customer.subscription.created with status=trialing + trial_end should grant paid AND persist the trial_end timestamp so the settings page can show 'N days remaining'. Realistic flow: checkout.session.completed fires first (linking customer_id to user.id via client_reference_id), then subscription.created fires moments later carrying trial_end.""" import datetime as _dt client, factory, _ = _build_app(tmp_path) # First: link the user to the Stripe customer via checkout. _post_webhook(client, body={ "id": "evt_link", "type": "checkout.session.completed", "data": {"object": { "client_reference_id": "1", "customer": "cus_trial", "subscription": "sub_trial", }}, }) # Then: the subscription event carrying trial_end (12 days out). trial_end_ts = int((_dt.datetime.now(_dt.timezone.utc) + _dt.timedelta(days=12)).timestamp()) r = _post_webhook(client, body={ "id": "evt_trial", "type": "customer.subscription.created", "data": {"object": { "id": "sub_trial", "customer": "cus_trial", "status": "trialing", "trial_end": trial_end_ts, }}, }) assert r.status_code == 200, r.text async def _check(): from sqlalchemy import select from app.models import User async with factory() as session: u = (await session.execute( select(User).where(User.id == 1) )).scalar_one() return u.tier, u.stripe_trial_end_at tier, end = asyncio.run(_check()) assert tier == "paid", "trial users must have paid features" assert end is not None # Stored value should match the trial_end we sent (within a second). expected = _dt.datetime.fromtimestamp(trial_end_ts, tz=_dt.timezone.utc) if end.tzinfo is None: end = end.replace(tzinfo=_dt.timezone.utc) assert abs((end - expected).total_seconds()) < 2 def test_subscription_active_clears_trial_end(tmp_path): """When the subscription transitions trialing -> active (day 15), the trial_end marker should be cleared so settings stops showing 'trial — N days remaining'.""" import datetime as _dt client, factory, _ = _build_app(tmp_path) # Link the customer first via checkout, then plant a trial state. _post_webhook(client, body={ "id": "evt_link2", "type": "checkout.session.completed", "data": {"object": { "client_reference_id": "1", "customer": "cus_t", "subscription": "sub_t", }}, }) trial_end_ts = int((_dt.datetime.now(_dt.timezone.utc) + _dt.timedelta(days=12)).timestamp()) _post_webhook(client, body={ "id": "evt_t1", "type": "customer.subscription.created", "data": {"object": { "id": "sub_t", "customer": "cus_t", "status": "trialing", "trial_end": trial_end_ts, }}, }) # Now transition to active. _post_webhook(client, body={ "id": "evt_t2", "type": "customer.subscription.updated", "data": {"object": { "id": "sub_t", "customer": "cus_t", "status": "active", }}, }) async def _check(): from sqlalchemy import select from app.models import User async with factory() as session: u = (await session.execute( select(User).where(User.id == 1) )).scalar_one() return u.tier, u.stripe_trial_end_at tier, end = asyncio.run(_check()) assert tier == "paid" assert end is None, "trial_end_at must be cleared once active" def test_subscription_active_grants_paid(tmp_path): """customer.subscription.updated with status=active should also grant paid — covers the case where checkout.session.completed arrives after subscription.created and we want either to work.""" client, factory, _ = _build_app(tmp_path) # Seed the linkage first via checkout (so customer_id is known). _post_webhook(client, body={ "id": "evt_ck", "type": "checkout.session.completed", "data": {"object": { "client_reference_id": "1", "customer": "cus_abc", "subscription": "sub_xyz", }}, }) # Drop to free manually so we can prove the updated event re-grants. async def _reset(): from sqlalchemy import update from app.models import User async with factory() as session: await session.execute( update(User).where(User.id == 1).values(tier="free") ) await session.commit() asyncio.run(_reset()) r = _post_webhook(client, body={ "id": "evt_upd", "type": "customer.subscription.updated", "data": {"object": { "id": "sub_xyz", "customer": "cus_abc", "status": "active", }}, }) assert r.status_code == 200 async def _check_tier(): from sqlalchemy import select from app.models import User async with factory() as session: return (await session.execute( select(User.tier).where(User.id == 1) )).scalar_one() assert asyncio.run(_check_tier()) == "paid" # --- idempotency + unknown ------------------------------------------------ def test_replayed_event_id_is_a_noop(tmp_path): client, factory, _ = _build_app(tmp_path) body = { "id": "evt_dup", "type": "checkout.session.completed", "data": {"object": { "client_reference_id": "1", "customer": "cus_abc", "subscription": "sub_xyz", }}, } r1 = _post_webhook(client, body=body) r2 = _post_webhook(client, body=body) assert r1.json()["status"] == "ok" assert r2.json()["status"] == "duplicate" async def _count_rows(): from sqlalchemy import select, func from app.models import StripeEvent async with factory() as session: n = (await session.execute( select(func.count(StripeEvent.id)) .where(StripeEvent.event_id == "evt_dup") )).scalar_one() return n assert asyncio.run(_count_rows()) == 1 def test_unknown_event_is_acked(tmp_path): client, _, _ = _build_app(tmp_path) r = _post_webhook(client, body={ "id": "evt_unknown", "type": "product.something.new", "data": {"object": {}}, }) assert r.status_code == 200 assert r.json()["status"] == "ignored" # --- /api/stripe/checkout (with Stripe SDK mocked) ------------------------ def _fake_checkout_client(asserter): """Build a fake Stripe client whose checkout.sessions.create calls the supplied asserter on the params dict and returns a stub URL.""" fake_session = SimpleNamespace( id="cs_test_123", url="https://checkout.stripe.com/test", ) class _FakeSessions: @staticmethod def create(params): # noqa: ANN001 asserter(params) return fake_session class _FakeCheckout: sessions = _FakeSessions() class _FakeClient: checkout = _FakeCheckout() return _FakeClient() def test_checkout_monthly_has_no_trial_and_no_stripe_consent(tmp_path): """Monthly checkout must NOT carry a free trial (£7 × 14 days would halve cycle-1 revenue) AND must NOT use Stripe's account-wide consent_collection — the Reg-36 waiver is collected on /pricing so each product can use its own Terms URL.""" client, _, session_cookie = _build_app(tmp_path) def asserter(params): assert params["mode"] == "subscription" assert params["line_items"][0]["price"] == _PRICE_MONTHLY assert params["client_reference_id"] == "1" assert params["customer_email"] == "buyer@x" assert "subscription_data" not in params, "no trial on monthly" assert "consent_collection" not in params, ( "consent is collected on /pricing, not via Stripe's account-wide setting" ) with patch("app.routers.stripe_billing._stripe_client", return_value=_fake_checkout_client(asserter)): r = client.post( "/api/stripe/checkout", json={"cadence": "monthly"}, cookies={"cassandra_session": session_cookie}, ) assert r.status_code == 200, r.text assert r.json()["url"] == "https://checkout.stripe.com/test" def test_checkout_annual_uses_trial_not_consent_collection(tmp_path): """Annual checkout gets the 14-day free trial (substitutes for the statutory cooling-off right; no money moves during the trial).""" client, _, session_cookie = _build_app(tmp_path) def asserter(params): assert params["mode"] == "subscription" assert params["line_items"][0]["price"] == _PRICE_ANNUAL assert params["subscription_data"]["trial_period_days"] == 14 assert "consent_collection" not in params, "annual relies on trial, not consent" with patch("app.routers.stripe_billing._stripe_client", return_value=_fake_checkout_client(asserter)): r = client.post( "/api/stripe/checkout", json={"cadence": "annual"}, cookies={"cassandra_session": session_cookie}, ) assert r.status_code == 200, r.text assert r.json()["url"] == "https://checkout.stripe.com/test" def test_checkout_endpoint_requires_login(tmp_path): client, _, _ = _build_app(tmp_path) r = client.post("/api/stripe/checkout", json={"cadence": "monthly"}) # No session cookie → require_auth bounces with 401. assert r.status_code == 401, r.text def test_checkout_passes_sniffed_currency_for_new_customer(tmp_path): """First-time buyer (no stripe_customer_id yet) gets the currency sniffed from the request. CF-IPCountry=US → 'usd', and Stripe will look up the USD currency_option on the Price.""" client, _, session_cookie = _build_app(tmp_path) def asserter(params): assert params["currency"] == "usd" with patch("app.routers.stripe_billing._stripe_client", return_value=_fake_checkout_client(asserter)): r = client.post( "/api/stripe/checkout", json={"cadence": "monthly"}, cookies={"cassandra_session": session_cookie}, headers={"cf-ipcountry": "US"}, ) assert r.status_code == 200, r.text def test_checkout_body_currency_overrides_sniff(tmp_path): """Explicit `currency` in the request body beats header sniffing — lets a UK-based buyer choose EUR if they want to.""" client, _, session_cookie = _build_app(tmp_path) def asserter(params): assert params["currency"] == "eur" with patch("app.routers.stripe_billing._stripe_client", return_value=_fake_checkout_client(asserter)): r = client.post( "/api/stripe/checkout", json={"cadence": "monthly", "currency": "eur"}, cookies={"cassandra_session": session_cookie}, headers={"cf-ipcountry": "GB"}, ) assert r.status_code == 200, r.text def test_checkout_omits_currency_for_existing_customer(tmp_path): """Existing customer: Stripe locked their currency at first checkout, so passing `currency` again would error. Verify we omit it (and also use the existing `customer` ref instead of customer_email).""" import asyncio from app.models import User client, factory, session_cookie = _build_app(tmp_path) async def _link(): async with factory() as s: u = await s.get(User, 1) u.stripe_customer_id = "cus_existing_xxxxxxxxxxxxxx" await s.commit() asyncio.run(_link()) def asserter(params): assert "currency" not in params, ( "currency must not be passed once a customer exists — " "Stripe rejects mismatches against the locked customer currency" ) assert params["customer"] == "cus_existing_xxxxxxxxxxxxxx" with patch("app.routers.stripe_billing._stripe_client", return_value=_fake_checkout_client(asserter)): r = client.post( "/api/stripe/checkout", json={"cadence": "monthly", "currency": "usd"}, cookies={"cassandra_session": session_cookie}, headers={"cf-ipcountry": "US"}, ) assert r.status_code == 200, r.text def test_sniff_currency_fallback_chain(): """Unit-test the header-sniffing helper: CF country wins, then Accept-Language exact, then language-only, then GBP default.""" from types import SimpleNamespace from app.routers.stripe_billing import _sniff_currency def _req(headers): return SimpleNamespace(headers=headers) assert _sniff_currency(_req({"cf-ipcountry": "DE"})) == "eur" assert _sniff_currency(_req({"cf-ipcountry": "us"})) == "usd" # case-insensitive assert _sniff_currency(_req({"accept-language": "fr-FR,fr;q=0.9"})) == "eur" assert _sniff_currency(_req({"accept-language": "en-US,en;q=0.5"})) == "usd" assert _sniff_currency(_req({"accept-language": "ja,ja-JP;q=0.5"})) == "gbp" assert _sniff_currency(_req({})) == "gbp"