"""Stripe billing endpoints: signature verification, idempotency, tier flips, and checkout creation. Same integration-style scaffold as test_polar_webhook.py — real router over in-memory aiosqlite. Stripe SDK calls (sessions.create, portal sessions.create) are mocked so the suite never makes a real HTTP call. """ from __future__ import annotations import asyncio import hashlib import hmac import json import time from types import SimpleNamespace from unittest.mock import patch import pytest _API_KEY = "sk_test_dummy_for_unit_tests" _WEBHOOK_SECRET = "whsec_dummy_test_secret_for_unit_tests" _PRICE_MONTHLY = "price_test_monthly_xxxxxxxxxxxxxxxxxxxx" _PRICE_ANNUAL = "price_test_annual_xxxxxxxxxxxxxxxxxxxxx" def _stripe_sig(body: bytes, secret: str, ts: int | None = None) -> str: """Produce a Stripe-Signature header matching the bytes signed. Format: `t=,v1=` over `.`.""" ts = ts if ts is not None else int(time.time()) signed = f"{ts}.{body.decode('utf-8')}" mac = hmac.new(secret.encode("utf-8"), signed.encode("utf-8"), hashlib.sha256).hexdigest() return f"t={ts},v1={mac}" def _build_app(tmp_path): from fastapi import FastAPI from fastapi.testclient import TestClient from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from app import db as db_mod from app.auth import sign_session from app.config import get_settings from app.db import Base from app.models import User from app.routers import stripe_billing as stripe_router s = get_settings() s.STRIPE_API_KEY = _API_KEY # type: ignore[misc] s.STRIPE_WEBHOOK_SECRET = _WEBHOOK_SECRET # type: ignore[misc] s.STRIPE_PRICE_MONTHLY = _PRICE_MONTHLY # type: ignore[misc] s.STRIPE_PRICE_ANNUAL = _PRICE_ANNUAL # type: ignore[misc] engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/stripe.db") factory = async_sessionmaker(engine, expire_on_commit=False) db_mod._engine = engine db_mod._session_factory = factory async def _seed(): async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) async with factory() as session: session.add(User(id=1, email="buyer@x", tier="free")) await session.commit() asyncio.run(_seed()) app = FastAPI() app.include_router(stripe_router.router) return TestClient(app), factory, sign_session(1) def _post_webhook(client, *, body: dict, secret: str = _WEBHOOK_SECRET, sig: str | None = None): # Stripe's SDK requires a top-level `object: "event"` field to know # this is a v1 webhook envelope — tests that omit it fail in # construct_event before the signature check matters. We inject the # default here so individual tests can stay terse. body.setdefault("object", "event") raw = json.dumps(body).encode("utf-8") sig = sig if sig is not None else _stripe_sig(raw, secret) return client.post( "/api/stripe/webhook", content=raw, headers={"stripe-signature": sig, "content-type": "application/json"}, ) # --- signature gate -------------------------------------------------------- def test_webhook_rejects_bad_signature(tmp_path): client, _, _ = _build_app(tmp_path) raw = json.dumps({"id": "evt_x", "type": "invoice.paid", "data": {"object": {}}}).encode("utf-8") r = client.post( "/api/stripe/webhook", content=raw, headers={ "stripe-signature": "t=0,v1=deadbeef", "content-type": "application/json", }, ) assert r.status_code == 401, r.text def test_webhook_rejects_missing_signature(tmp_path): client, _, _ = _build_app(tmp_path) r = client.post( "/api/stripe/webhook", content=b"{}", headers={"content-type": "application/json"}, ) assert r.status_code == 400, r.text # --- happy paths ----------------------------------------------------------- def test_checkout_session_completed_flips_tier_to_paid(tmp_path): client, factory, _ = _build_app(tmp_path) body = { "id": "evt_checkout_1", "type": "checkout.session.completed", "data": { "object": { "client_reference_id": "1", "customer": "cus_abc", "subscription": "sub_xyz", } }, } r = _post_webhook(client, body=body) assert r.status_code == 200, r.text assert r.json()["status"] == "ok" async def _check(): from sqlalchemy import select from app.models import User async with factory() as session: u = (await session.execute( select(User).where(User.id == 1) )).scalar_one() return u.tier, u.stripe_customer_id, u.stripe_subscription_id tier, cid, sid = asyncio.run(_check()) assert tier == "paid" assert cid == "cus_abc" assert sid == "sub_xyz" def test_subscription_deleted_drops_tier_to_free(tmp_path): client, factory, _ = _build_app(tmp_path) # First, activate. _post_webhook(client, body={ "id": "evt_act", "type": "checkout.session.completed", "data": {"object": { "client_reference_id": "1", "customer": "cus_abc", "subscription": "sub_xyz", }}, }) # Then, delete the subscription. r = _post_webhook(client, body={ "id": "evt_del", "type": "customer.subscription.deleted", "data": {"object": { "id": "sub_xyz", "customer": "cus_abc", "status": "canceled", }}, }) assert r.status_code == 200, r.text async def _check(): from sqlalchemy import select from app.models import User async with factory() as session: u = (await session.execute( select(User).where(User.id == 1) )).scalar_one() return u.tier, u.stripe_customer_id, u.stripe_subscription_id tier, cid, sid = asyncio.run(_check()) assert tier == "free" # Customer linkage preserved so a future resub matches this row. assert cid == "cus_abc" assert sid is None def test_subscription_active_grants_paid(tmp_path): """customer.subscription.updated with status=active should also grant paid — covers the case where checkout.session.completed arrives after subscription.created and we want either to work.""" client, factory, _ = _build_app(tmp_path) # Seed the linkage first via checkout (so customer_id is known). _post_webhook(client, body={ "id": "evt_ck", "type": "checkout.session.completed", "data": {"object": { "client_reference_id": "1", "customer": "cus_abc", "subscription": "sub_xyz", }}, }) # Drop to free manually so we can prove the updated event re-grants. async def _reset(): from sqlalchemy import update from app.models import User async with factory() as session: await session.execute( update(User).where(User.id == 1).values(tier="free") ) await session.commit() asyncio.run(_reset()) r = _post_webhook(client, body={ "id": "evt_upd", "type": "customer.subscription.updated", "data": {"object": { "id": "sub_xyz", "customer": "cus_abc", "status": "active", }}, }) assert r.status_code == 200 async def _check_tier(): from sqlalchemy import select from app.models import User async with factory() as session: return (await session.execute( select(User.tier).where(User.id == 1) )).scalar_one() assert asyncio.run(_check_tier()) == "paid" # --- idempotency + unknown ------------------------------------------------ def test_replayed_event_id_is_a_noop(tmp_path): client, factory, _ = _build_app(tmp_path) body = { "id": "evt_dup", "type": "checkout.session.completed", "data": {"object": { "client_reference_id": "1", "customer": "cus_abc", "subscription": "sub_xyz", }}, } r1 = _post_webhook(client, body=body) r2 = _post_webhook(client, body=body) assert r1.json()["status"] == "ok" assert r2.json()["status"] == "duplicate" async def _count_rows(): from sqlalchemy import select, func from app.models import StripeEvent async with factory() as session: n = (await session.execute( select(func.count(StripeEvent.id)) .where(StripeEvent.event_id == "evt_dup") )).scalar_one() return n assert asyncio.run(_count_rows()) == 1 def test_unknown_event_is_acked(tmp_path): client, _, _ = _build_app(tmp_path) r = _post_webhook(client, body={ "id": "evt_unknown", "type": "product.something.new", "data": {"object": {}}, }) assert r.status_code == 200 assert r.json()["status"] == "ignored" # --- /api/stripe/checkout (with Stripe SDK mocked) ------------------------ def test_checkout_endpoint_creates_session_and_returns_url(tmp_path): client, _, session_cookie = _build_app(tmp_path) # Mock the Stripe SDK call so no real HTTP goes out. fake_session = SimpleNamespace( id="cs_test_123", url="https://checkout.stripe.com/test", ) class _FakeSessions: @staticmethod def create(params): # noqa: ANN001 assert params["mode"] == "subscription" assert params["line_items"][0]["price"] == _PRICE_MONTHLY assert params["client_reference_id"] == "1" assert params["customer_email"] == "buyer@x" return fake_session class _FakeCheckout: sessions = _FakeSessions() class _FakeClient: checkout = _FakeCheckout() with patch("app.routers.stripe_billing._stripe_client", return_value=_FakeClient()): r = client.post( "/api/stripe/checkout", json={"cadence": "monthly"}, cookies={"cassandra_session": session_cookie}, ) assert r.status_code == 200, r.text assert r.json()["url"] == "https://checkout.stripe.com/test" def test_checkout_endpoint_requires_login(tmp_path): client, _, _ = _build_app(tmp_path) r = client.post("/api/stripe/checkout", json={"cadence": "monthly"}) # No session cookie → require_auth bounces with 401. assert r.status_code == 401, r.text