diff --git a/app/routers/stripe_billing.py b/app/routers/stripe_billing.py index 60bc7f7..bfdeed0 100644 --- a/app/routers/stripe_billing.py +++ b/app/routers/stripe_billing.py @@ -19,7 +19,7 @@ from __future__ import annotations import asyncio import json -from typing import Any, Literal +from typing import Any, Literal, Optional import stripe from fastapi import APIRouter, Body, Depends, HTTPException, Request @@ -69,6 +69,53 @@ def _price_for(cadence: str) -> str: raise HTTPException(status_code=400, detail="cadence must be 'monthly' or 'annual'") +# Rough country → currency mapping. Covers the markets we have a stated +# rate for; everything else falls back to GBP (the home currency) and +# Stripe handles the FX at checkout. Configure the per-currency +# unit_amount on each Price's `currency_options` in the Stripe Dashboard +# — we just signal which option to use here. +_COUNTRY_CURRENCY: dict[str, str] = { + "US": "usd", "CA": "usd", + "GB": "gbp", "IM": "gbp", "JE": "gbp", "GG": "gbp", + **dict.fromkeys(( + "DE", "FR", "IT", "ES", "PT", "NL", "BE", "IE", "AT", "FI", + "GR", "LU", "MT", "CY", "EE", "LV", "LT", "SI", "SK", "HR", + ), "eur"), +} + +# Accept-Language locale → currency, used when CF-IPCountry is absent. +# Ambiguous locales (e.g. plain "fr" without region) get EUR because +# that's the majority outcome. +_LOCALE_CURRENCY: dict[str, str] = { + "en-gb": "gbp", "en": "gbp", + "en-us": "usd", "en-ca": "usd", + "fr": "eur", "de": "eur", "it": "eur", "es": "eur", + "pt": "eur", "nl": "eur", +} + + +def _sniff_currency(request: Request) -> str: + """Best-effort currency detection for new-customer checkouts. + + Order: explicit Cloudflare country header, then Accept-Language + (exact match then language-only). GBP as the final fallback. Only + consulted when the user has no Stripe customer record yet — Stripe + locks currency at customer creation, so an existing customer's + currency wins regardless of the request locale. + """ + cc = (request.headers.get("cf-ipcountry") or "").upper() + if cc in _COUNTRY_CURRENCY: + return _COUNTRY_CURRENCY[cc] + al = (request.headers.get("accept-language") or "").lower() + first = al.split(",", 1)[0].split(";", 1)[0].strip() + if first in _LOCALE_CURRENCY: + return _LOCALE_CURRENCY[first] + short = first.split("-", 1)[0] + if short in _LOCALE_CURRENCY: + return _LOCALE_CURRENCY[short] + return "gbp" + + def _stripe_client() -> stripe.StripeClient: """Per-call client so we read the secret at request time (lets us rotate the key by editing .env + reloading without rebuilding any @@ -83,6 +130,10 @@ def _stripe_client() -> stripe.StripeClient: class CheckoutRequest(BaseModel): cadence: Literal["monthly", "annual"] + # Optional override; when omitted we sniff from request headers. + # Honoured only for first-time checkouts (Stripe locks currency + # to the customer at creation). + currency: Optional[Literal["gbp", "usd", "eur"]] = None class CheckoutResponse(BaseModel): @@ -92,6 +143,7 @@ class CheckoutResponse(BaseModel): @router.post("/api/stripe/checkout", response_model=CheckoutResponse) async def create_checkout( body: CheckoutRequest, + request: Request, session: AsyncSession = Depends(get_session), cu: CurrentUser = Depends(require_auth), ) -> CheckoutResponse: @@ -120,6 +172,13 @@ async def create_checkout( # referral redemption flow ships. "allow_promotion_codes": True, } + # Multi-currency: for first-time buyers (no stripe_customer_id yet) + # we pass the detected/requested currency. Stripe picks the matching + # `currency_options` rate configured on the Price in the Dashboard, + # then locks that currency to the new customer record. Existing + # customers keep their original currency regardless. + if not user.stripe_customer_id: + create_kwargs["currency"] = body.currency or _sniff_currency(request) # Per-cadence cooling-off treatment: # # - Annual gets a 14-day free trial. No money moves during the diff --git a/tests/test_stripe_billing.py b/tests/test_stripe_billing.py index f00e72d..d231cd2 100644 --- a/tests/test_stripe_billing.py +++ b/tests/test_stripe_billing.py @@ -463,3 +463,97 @@ def test_checkout_endpoint_requires_login(tmp_path): r = client.post("/api/stripe/checkout", json={"cadence": "monthly"}) # No session cookie → require_auth bounces with 401. assert r.status_code == 401, r.text + + +def test_checkout_passes_sniffed_currency_for_new_customer(tmp_path): + """First-time buyer (no stripe_customer_id yet) gets the currency + sniffed from the request. CF-IPCountry=US → 'usd', and Stripe will + look up the USD currency_option on the Price.""" + client, _, session_cookie = _build_app(tmp_path) + + def asserter(params): + assert params["currency"] == "usd" + + with patch("app.routers.stripe_billing._stripe_client", + return_value=_fake_checkout_client(asserter)): + r = client.post( + "/api/stripe/checkout", + json={"cadence": "monthly"}, + cookies={"cassandra_session": session_cookie}, + headers={"cf-ipcountry": "US"}, + ) + assert r.status_code == 200, r.text + + +def test_checkout_body_currency_overrides_sniff(tmp_path): + """Explicit `currency` in the request body beats header sniffing — + lets a UK-based buyer choose EUR if they want to.""" + client, _, session_cookie = _build_app(tmp_path) + + def asserter(params): + assert params["currency"] == "eur" + + with patch("app.routers.stripe_billing._stripe_client", + return_value=_fake_checkout_client(asserter)): + r = client.post( + "/api/stripe/checkout", + json={"cadence": "monthly", "currency": "eur"}, + cookies={"cassandra_session": session_cookie}, + headers={"cf-ipcountry": "GB"}, + ) + assert r.status_code == 200, r.text + + +def test_checkout_omits_currency_for_existing_customer(tmp_path): + """Existing customer: Stripe locked their currency at first + checkout, so passing `currency` again would error. Verify we omit + it (and also use the existing `customer` ref instead of + customer_email).""" + import asyncio + + from app.models import User + + client, factory, session_cookie = _build_app(tmp_path) + + async def _link(): + async with factory() as s: + u = await s.get(User, 1) + u.stripe_customer_id = "cus_existing_xxxxxxxxxxxxxx" + await s.commit() + + asyncio.run(_link()) + + def asserter(params): + assert "currency" not in params, ( + "currency must not be passed once a customer exists — " + "Stripe rejects mismatches against the locked customer currency" + ) + assert params["customer"] == "cus_existing_xxxxxxxxxxxxxx" + + with patch("app.routers.stripe_billing._stripe_client", + return_value=_fake_checkout_client(asserter)): + r = client.post( + "/api/stripe/checkout", + json={"cadence": "monthly", "currency": "usd"}, + cookies={"cassandra_session": session_cookie}, + headers={"cf-ipcountry": "US"}, + ) + assert r.status_code == 200, r.text + + +def test_sniff_currency_fallback_chain(): + """Unit-test the header-sniffing helper: CF country wins, then + Accept-Language exact, then language-only, then GBP default.""" + from types import SimpleNamespace + + from app.routers.stripe_billing import _sniff_currency + + def _req(headers): + return SimpleNamespace(headers=headers) + + assert _sniff_currency(_req({"cf-ipcountry": "DE"})) == "eur" + assert _sniff_currency(_req({"cf-ipcountry": "us"})) == "usd" # case-insensitive + assert _sniff_currency(_req({"accept-language": "fr-FR,fr;q=0.9"})) == "eur" + assert _sniff_currency(_req({"accept-language": "en-US,en;q=0.5"})) == "usd" + assert _sniff_currency(_req({"accept-language": "ja,ja-JP;q=0.5"})) == "gbp" + assert _sniff_currency(_req({})) == "gbp"