stripe: wire checkout, customer portal, and webhook for read.markets
Stripe is the merchant-on-record for read.markets after Polar/Paddle
both declined the financial-media category. This commit lands the
full subscription flow: an "Upgrade" button on /pricing now opens a
real Stripe-hosted Checkout, completes the subscription, and the
webhook flips user.tier to "paid" idempotently.
Endpoints
- POST /api/stripe/checkout (require_auth) — creates a hosted
Checkout Session in subscription mode, passes user.id as
client_reference_id + email as customer_email, returns the URL
for the page-side JS to redirect to. Reuses an existing
stripe_customer_id to avoid duplicate Stripe customers on repeat
checkouts. allow_promotion_codes=True so the referral-credit
redemption can attach a coupon at checkout once that flow ships.
- POST /api/stripe/portal (require_auth) — mints a Stripe Customer
Portal session. Used by /settings; returns 404 until the user has
a stripe_customer_id (i.e. completed at least one checkout).
- POST /api/stripe/webhook — signature-verified via
stripe.Webhook.construct_event. Idempotent via UNIQUE on
stripe_events.event_id. Event dispatch:
checkout.session.completed → grant paid, store IDs
customer.subscription.created → grant paid (active/trialing)
customer.subscription.updated → grant paid (active/trialing)
customer.subscription.deleted → drop to free, clear sub id
invoice.paid / failed → audit only
charge.refunded → audit only
Stripe-SDK objects don't expose dict.get(); we use the SDK for
signature verification then re-parse the JSON body for handler
dispatch — cleaner than reaching into StripeObject internals.
Schema (migration 0019)
- users.stripe_customer_id, users.stripe_subscription_id (nullable
String(64), UNIQUE on customer_id).
- stripe_events table mirroring polar_events: event_id (unique),
event_type, received_at, processed_at, error, raw payload
(truncated to 16 KiB).
Settings (.env)
- STRIPE_API_KEY (rk_test_… for dev, rk_live_… for GA)
- STRIPE_WEBHOOK_SECRET (whsec_… from the dashboard endpoint)
- STRIPE_PRICE_MONTHLY (price_xxx for £7/month)
- STRIPE_PRICE_ANNUAL (price_xxx for £70/year)
Pricing page
- Free tier CTA unchanged.
- Paid CTA branches three ways: paid → "Manage subscription" to
/settings; logged-in free → two buttons (£7/mo, £70/yr) that POST
to /api/stripe/checkout and redirect; anonymous → /login?next=/pricing.
- Inline JS intercepts the button click, calls the checkout
endpoint, redirects on success, surfaces errors via alert(). No
Stripe.js dep — we use the hosted-checkout URL directly.
Polar handler stays in place for berengar.io / flyroom.net which
still ship through Polar. polar_* and stripe_* columns coexist
independently on the User row.
Tests
- 9 in tests/test_stripe_billing.py covering: bad signature → 401,
missing signature → 400, checkout.session.completed flips tier +
stores IDs, subscription.updated active grants paid,
subscription.deleted drops to free with customer id preserved,
replayed event id is no-op (one row in stripe_events),
unknown event acked 200, checkout endpoint mocks the SDK and
returns the hosted URL, checkout requires login.
- Full suite: 221 passed, 5 skipped.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
6c13f855e9
commit
410afe0078
9 changed files with 858 additions and 7 deletions
56
alembic/versions/0019_stripe.py
Normal file
56
alembic/versions/0019_stripe.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""stripe integration: users.stripe_customer_id / stripe_subscription_id,
|
||||
stripe_events table.
|
||||
|
||||
Revision ID: 0019
|
||||
Revises: 0018
|
||||
Create Date: 2026-05-26
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0019"
|
||||
down_revision: Union[str, None] = "0018"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("stripe_customer_id", sa.String(length=64), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("stripe_subscription_id", sa.String(length=64), nullable=True),
|
||||
)
|
||||
op.create_unique_constraint(
|
||||
"uq_users_stripe_customer", "users", ["stripe_customer_id"],
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"stripe_events",
|
||||
sa.Column("id", sa.BigInteger(), autoincrement=True, primary_key=True),
|
||||
sa.Column("event_id", sa.String(length=128), nullable=False),
|
||||
sa.Column("event_type", sa.String(length=64), nullable=False),
|
||||
sa.Column("received_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("processed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column("payload", sa.Text(), nullable=False),
|
||||
sa.UniqueConstraint("event_id", name="uq_stripe_events_event_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_stripe_events_type_received",
|
||||
"stripe_events",
|
||||
["event_type", "received_at"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_stripe_events_type_received", table_name="stripe_events")
|
||||
op.drop_table("stripe_events")
|
||||
op.drop_constraint("uq_users_stripe_customer", "users", type_="unique")
|
||||
op.drop_column("users", "stripe_subscription_id")
|
||||
op.drop_column("users", "stripe_customer_id")
|
||||
|
|
@ -99,6 +99,16 @@ class Settings(BaseSettings):
|
|||
POLAR_WEBHOOK_SECRET: str = ""
|
||||
POLAR_API_KEY: str = ""
|
||||
|
||||
# Stripe (merchant-on-record for read.markets after Polar/Paddle
|
||||
# both declined the financial-media category). Test-mode keys are
|
||||
# `sk_test_*` / `whsec_*`; live-mode keys are `sk_live_*` — swap at
|
||||
# GA cutover. Empty values make the corresponding endpoints 503 so
|
||||
# a misconfig is loud rather than silently accepting unsigned events.
|
||||
STRIPE_API_KEY: str = ""
|
||||
STRIPE_WEBHOOK_SECRET: str = ""
|
||||
STRIPE_PRICE_MONTHLY: str = "" # price_xxx for £7/month subscription
|
||||
STRIPE_PRICE_ANNUAL: str = "" # price_xxx for £70/year subscription
|
||||
|
||||
# Config file locations (overridable for tests)
|
||||
BASELINE_TOML: Path = Field(default_factory=lambda: CONFIG_DIR / "default.toml")
|
||||
PORTFOLIO_TOML: Path = Field(default_factory=lambda: CONFIG_DIR / "portfolio.toml")
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from app.routers import email as email_router
|
|||
from app.routers import pages as pages_router
|
||||
from app.routers import polar_webhook as polar_webhook_router
|
||||
from app.routers import public as public_router
|
||||
from app.routers import stripe_billing as stripe_router
|
||||
from app.routers import sync as sync_router
|
||||
from app.routers import universe as universe_router
|
||||
from app.services.feeds_bootstrap import bootstrap_feeds
|
||||
|
|
@ -93,6 +94,9 @@ app.include_router(sync_router.router, tags=["portfolio-sync"])
|
|||
# `/api/polar/webhook` is set on the route itself so the URL Polar
|
||||
# stores remains stable even if api_router's prefix ever moves.
|
||||
app.include_router(polar_webhook_router.router, tags=["polar-webhook"])
|
||||
# Stripe billing (checkout, portal, webhook). Auth lives per-route:
|
||||
# checkout + portal require_auth, webhook is signature-gated.
|
||||
app.include_router(stripe_router.router, tags=["stripe-billing"])
|
||||
# Public router (no auth dep) before pages_router so the marketing/legal
|
||||
# paths can never collide with future authenticated routes.
|
||||
app.include_router(public_router.router)
|
||||
|
|
|
|||
|
|
@ -194,11 +194,18 @@ class User(Base):
|
|||
# we cancel against from /settings.
|
||||
polar_customer_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
polar_subscription_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
# Stripe (merchant-on-record for read.markets). Populated on the
|
||||
# first checkout.session.completed event via client_reference_id;
|
||||
# used thereafter to match incoming subscription/invoice events
|
||||
# back to this row.
|
||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
stripe_subscription_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("email", name="uq_users_email"),
|
||||
UniqueConstraint("referral_code", name="uq_users_referral_code"),
|
||||
UniqueConstraint("polar_customer_id", name="uq_users_polar_customer"),
|
||||
UniqueConstraint("stripe_customer_id", name="uq_users_stripe_customer"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -385,3 +392,29 @@ class PolarEvent(Base):
|
|||
UniqueConstraint("event_id", name="uq_polar_events_event_id"),
|
||||
Index("ix_polar_events_type_received", "event_type", "received_at"),
|
||||
)
|
||||
|
||||
|
||||
class StripeEvent(Base):
|
||||
"""Audit + idempotency table for inbound Stripe webhook deliveries.
|
||||
|
||||
Same shape and purpose as PolarEvent — Stripe's `event.id` plays the
|
||||
same role as Standard Webhooks' `webhook-id`. We keep the tables
|
||||
distinct (rather than a single 'webhook_events' table) so an
|
||||
operator can look at the audit trail per processor without filtering
|
||||
on a `source` column."""
|
||||
__tablename__ = "stripe_events"
|
||||
|
||||
id: Mapped[int] = mapped_column(_PK, primary_key=True, autoincrement=True)
|
||||
event_id: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
event_type: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
received_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=utcnow, nullable=False,
|
||||
)
|
||||
processed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
error: Mapped[str | None] = mapped_column(Text)
|
||||
payload: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("event_id", name="uq_stripe_events_event_id"),
|
||||
Index("ix_stripe_events_type_received", "event_type", "received_at"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from fastapi import APIRouter, Depends, Request
|
|||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from app.auth import CurrentUser, maybe_current_user
|
||||
from app.services.access import is_paid_active
|
||||
from app.templates_env import templates
|
||||
|
||||
|
||||
|
|
@ -33,7 +34,9 @@ async def pricing_page(
|
|||
request: Request,
|
||||
cu: CurrentUser | None = Depends(maybe_current_user),
|
||||
):
|
||||
return templates.TemplateResponse(request, "pricing.html", _ctx(request, cu))
|
||||
ctx = _ctx(request, cu)
|
||||
ctx["paid"] = is_paid_active(cu)
|
||||
return templates.TemplateResponse(request, "pricing.html", ctx)
|
||||
|
||||
|
||||
@router.get("/about", response_class=HTMLResponse)
|
||||
|
|
|
|||
383
app/routers/stripe_billing.py
Normal file
383
app/routers/stripe_billing.py
Normal file
|
|
@ -0,0 +1,383 @@
|
|||
"""Stripe billing endpoints — checkout, webhook, customer portal.
|
||||
|
||||
Stripe is the merchant-on-record for read.markets (after Polar/Paddle
|
||||
both declined the financial-media category). We delegate payment UI to
|
||||
Stripe-hosted Checkout and Customer Portal; the only state we keep on
|
||||
our side is `users.stripe_customer_id` / `users.stripe_subscription_id`
|
||||
so we can match incoming webhooks back to the right user.
|
||||
|
||||
The Stripe SDK is sync; we wrap calls in `asyncio.to_thread` so the
|
||||
event loop doesn't block while Stripe answers. For our request volume
|
||||
this is more reliable than the SDK's nascent async surface.
|
||||
|
||||
Routes
|
||||
- POST /api/stripe/checkout — logged-in user upgrades. Body: {cadence}.
|
||||
- POST /api/stripe/webhook — Stripe → us, signature-verified.
|
||||
- POST /api/stripe/portal — logged-in user opens the customer portal.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
|
||||
import stripe
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import branding
|
||||
from app.auth import CurrentUser, require_auth
|
||||
from app.config import get_settings
|
||||
from app.db import get_session, utcnow
|
||||
from app.logging import get_logger
|
||||
from app.models import StripeEvent, User
|
||||
|
||||
|
||||
log = get_logger("stripe_billing")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Cap stored payload at 16 KiB so a hostile (or buggy) sender can't
|
||||
# blow up a single row. Same pattern as polar_webhook.
|
||||
_PAYLOAD_STORE_MAX = 16 * 1024
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _require_configured() -> None:
|
||||
s = get_settings()
|
||||
if not s.STRIPE_API_KEY:
|
||||
raise HTTPException(status_code=503, detail="stripe not configured")
|
||||
|
||||
|
||||
def _price_for(cadence: str) -> str:
|
||||
s = get_settings()
|
||||
if cadence == "monthly":
|
||||
if not s.STRIPE_PRICE_MONTHLY:
|
||||
raise HTTPException(status_code=503, detail="STRIPE_PRICE_MONTHLY not set")
|
||||
return s.STRIPE_PRICE_MONTHLY
|
||||
if cadence == "annual":
|
||||
if not s.STRIPE_PRICE_ANNUAL:
|
||||
raise HTTPException(status_code=503, detail="STRIPE_PRICE_ANNUAL not set")
|
||||
return s.STRIPE_PRICE_ANNUAL
|
||||
raise HTTPException(status_code=400, detail="cadence must be 'monthly' or 'annual'")
|
||||
|
||||
|
||||
def _stripe_client() -> stripe.StripeClient:
|
||||
"""Per-call client so we read the secret at request time (lets us
|
||||
rotate the key by editing .env + reloading without rebuilding any
|
||||
cached client)."""
|
||||
return stripe.StripeClient(get_settings().STRIPE_API_KEY)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/stripe/checkout
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CheckoutRequest(BaseModel):
|
||||
cadence: Literal["monthly", "annual"]
|
||||
|
||||
|
||||
class CheckoutResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/api/stripe/checkout", response_model=CheckoutResponse)
|
||||
async def create_checkout(
|
||||
body: CheckoutRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
cu: CurrentUser = Depends(require_auth),
|
||||
) -> CheckoutResponse:
|
||||
_require_configured()
|
||||
if cu.user is None:
|
||||
# Admin bearer token has no User row — they shouldn't be buying.
|
||||
raise HTTPException(status_code=400, detail="admin token cannot purchase")
|
||||
|
||||
user = await session.get(User, cu.user.id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="user_not_found")
|
||||
|
||||
price_id = _price_for(body.cadence)
|
||||
client = _stripe_client()
|
||||
|
||||
# Pass `customer` if we already minted one for this user (avoids
|
||||
# creating duplicate Stripe customers on repeat checkouts);
|
||||
# otherwise let Stripe create it via `customer_email`.
|
||||
create_kwargs: dict[str, Any] = {
|
||||
"mode": "subscription",
|
||||
"line_items": [{"price": price_id, "quantity": 1}],
|
||||
"client_reference_id": str(user.id),
|
||||
"success_url": f"{branding.SITE_URL}/settings?upgraded=1",
|
||||
"cancel_url": f"{branding.SITE_URL}/pricing",
|
||||
# Lets us paste in a referral coupon at checkout once the
|
||||
# referral redemption flow ships.
|
||||
"allow_promotion_codes": True,
|
||||
}
|
||||
if user.stripe_customer_id:
|
||||
create_kwargs["customer"] = user.stripe_customer_id
|
||||
else:
|
||||
create_kwargs["customer_email"] = user.email
|
||||
|
||||
try:
|
||||
sess = await asyncio.to_thread(
|
||||
client.checkout.sessions.create, params=create_kwargs,
|
||||
)
|
||||
except stripe.StripeError as e:
|
||||
log.error("stripe.checkout.create_failed", user_id=user.id, error=str(e))
|
||||
raise HTTPException(status_code=502, detail=f"stripe error: {e.user_message or str(e)}")
|
||||
|
||||
if not sess.url:
|
||||
raise HTTPException(status_code=502, detail="stripe returned no checkout URL")
|
||||
log.info("stripe.checkout.created", user_id=user.id, session_id=sess.id,
|
||||
cadence=body.cadence)
|
||||
return CheckoutResponse(url=sess.url)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/stripe/portal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PortalResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/api/stripe/portal", response_model=PortalResponse)
|
||||
async def create_portal_session(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
cu: CurrentUser = Depends(require_auth),
|
||||
) -> PortalResponse:
|
||||
_require_configured()
|
||||
if cu.user is None:
|
||||
raise HTTPException(status_code=400, detail="admin token has no portal")
|
||||
|
||||
user = await session.get(User, cu.user.id)
|
||||
if user is None or not user.stripe_customer_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="no_stripe_customer — start a subscription first",
|
||||
)
|
||||
|
||||
client = _stripe_client()
|
||||
try:
|
||||
portal = await asyncio.to_thread(
|
||||
client.billing_portal.sessions.create,
|
||||
params={
|
||||
"customer": user.stripe_customer_id,
|
||||
"return_url": f"{branding.SITE_URL}/settings",
|
||||
},
|
||||
)
|
||||
except stripe.StripeError as e:
|
||||
log.error("stripe.portal.create_failed", user_id=user.id, error=str(e))
|
||||
raise HTTPException(status_code=502, detail=f"stripe error: {e.user_message or str(e)}")
|
||||
|
||||
return PortalResponse(url=portal.url)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/stripe/webhook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _find_user(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
client_ref: str | None = None,
|
||||
customer_id: str | None = None,
|
||||
) -> User | None:
|
||||
"""Find the User row this event belongs to.
|
||||
|
||||
`client_reference_id` is the most reliable join key — we set it
|
||||
to `str(user.id)` at checkout creation. After the first event we
|
||||
also know `stripe_customer_id`, which subsequent subscription /
|
||||
invoice events arrive carrying."""
|
||||
if client_ref:
|
||||
try:
|
||||
uid = int(client_ref)
|
||||
except ValueError:
|
||||
uid = None
|
||||
if uid is not None:
|
||||
u = await session.get(User, uid)
|
||||
if u is not None:
|
||||
return u
|
||||
if customer_id:
|
||||
row = (await session.execute(
|
||||
select(User).where(User.stripe_customer_id == customer_id)
|
||||
)).scalar_one_or_none()
|
||||
return row
|
||||
return None
|
||||
|
||||
|
||||
async def _grant_paid(
|
||||
user: User,
|
||||
*,
|
||||
customer_id: str | None,
|
||||
subscription_id: str | None,
|
||||
) -> None:
|
||||
user.tier = "paid"
|
||||
if customer_id and user.stripe_customer_id != customer_id:
|
||||
user.stripe_customer_id = customer_id
|
||||
if subscription_id and user.stripe_subscription_id != subscription_id:
|
||||
user.stripe_subscription_id = subscription_id
|
||||
|
||||
|
||||
async def _revoke_paid(user: User) -> None:
|
||||
user.tier = "free"
|
||||
user.stripe_subscription_id = None
|
||||
# Keep stripe_customer_id so a re-subscription matches this row.
|
||||
|
||||
|
||||
async def _handle_checkout_completed(
|
||||
session: AsyncSession, event_type: str, obj: dict[str, Any],
|
||||
) -> None:
|
||||
user = await _find_user(
|
||||
session,
|
||||
client_ref=obj.get("client_reference_id"),
|
||||
customer_id=obj.get("customer"),
|
||||
)
|
||||
if user is None:
|
||||
log.warning("stripe.user_not_found", event=event_type)
|
||||
return
|
||||
await _grant_paid(
|
||||
user,
|
||||
customer_id=obj.get("customer"),
|
||||
subscription_id=obj.get("subscription"),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_subscription_event(
|
||||
session: AsyncSession, event_type: str, obj: dict[str, Any],
|
||||
) -> None:
|
||||
"""customer.subscription.created / .updated — flip to paid if the
|
||||
Stripe-side status says the subscription is active/trialing; drop
|
||||
to free if it's an end-state."""
|
||||
user = await _find_user(session, customer_id=obj.get("customer"))
|
||||
if user is None:
|
||||
log.warning("stripe.user_not_found", event=event_type,
|
||||
customer_id=obj.get("customer"))
|
||||
return
|
||||
status = obj.get("status")
|
||||
# Stripe statuses: trialing, active, past_due, canceled, unpaid,
|
||||
# incomplete, incomplete_expired, paused. Treat trialing/active as
|
||||
# paid; everything else holds tier the same until we get an explicit
|
||||
# subscription.deleted (which fires after the final state lands).
|
||||
if status in ("trialing", "active"):
|
||||
await _grant_paid(
|
||||
user,
|
||||
customer_id=obj.get("customer"),
|
||||
subscription_id=obj.get("id"),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_subscription_deleted(
|
||||
session: AsyncSession, event_type: str, obj: dict[str, Any],
|
||||
) -> None:
|
||||
user = await _find_user(session, customer_id=obj.get("customer"))
|
||||
if user is None:
|
||||
log.warning("stripe.user_not_found", event=event_type,
|
||||
customer_id=obj.get("customer"))
|
||||
return
|
||||
await _revoke_paid(user)
|
||||
|
||||
|
||||
async def _handle_audit_only(
|
||||
session: AsyncSession, event_type: str, obj: dict[str, Any],
|
||||
) -> None:
|
||||
"""invoice.paid / invoice.payment_failed / charge.refunded — we
|
||||
record these in stripe_events for the audit log but the tier doesn't
|
||||
move until subscription.deleted fires."""
|
||||
return None
|
||||
|
||||
|
||||
_HANDLERS = {
|
||||
"checkout.session.completed": _handle_checkout_completed,
|
||||
"customer.subscription.created": _handle_subscription_event,
|
||||
"customer.subscription.updated": _handle_subscription_event,
|
||||
"customer.subscription.deleted": _handle_subscription_deleted,
|
||||
"invoice.paid": _handle_audit_only,
|
||||
"invoice.payment_failed": _handle_audit_only,
|
||||
"charge.refunded": _handle_audit_only,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/api/stripe/webhook")
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
s = get_settings()
|
||||
if not s.STRIPE_WEBHOOK_SECRET:
|
||||
raise HTTPException(status_code=503, detail="stripe webhook not configured")
|
||||
|
||||
sig = request.headers.get("stripe-signature", "")
|
||||
if not sig:
|
||||
raise HTTPException(status_code=400, detail="missing stripe-signature header")
|
||||
|
||||
body = await request.body()
|
||||
# construct_event handles HMAC verification + timestamp tolerance.
|
||||
# We then re-parse the body as plain JSON for handler dispatch —
|
||||
# the Stripe SDK's StripeObject doesn't expose dict.get(), and
|
||||
# round-tripping through json gives us simple, typed-dict access.
|
||||
try:
|
||||
stripe.Webhook.construct_event(
|
||||
payload=body, sig_header=sig, secret=s.STRIPE_WEBHOOK_SECRET,
|
||||
)
|
||||
except stripe.SignatureVerificationError:
|
||||
raise HTTPException(status_code=401, detail="bad signature")
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="invalid payload")
|
||||
|
||||
envelope = json.loads(body)
|
||||
event_id = envelope.get("id") or ""
|
||||
event_type = envelope.get("type") or "unknown"
|
||||
obj = (envelope.get("data") or {}).get("object") or {}
|
||||
|
||||
if not event_id:
|
||||
raise HTTPException(status_code=400, detail="event missing id")
|
||||
|
||||
# Idempotency: insert audit row first. UNIQUE on event_id makes a
|
||||
# replay of the same Stripe event id a no-op (Stripe retries on
|
||||
# non-2xx, so always 2xx after first successful processing).
|
||||
audit = StripeEvent(
|
||||
event_id=event_id,
|
||||
event_type=event_type,
|
||||
received_at=utcnow(),
|
||||
payload=body.decode("utf-8", errors="replace")[:_PAYLOAD_STORE_MAX],
|
||||
)
|
||||
session.add(audit)
|
||||
try:
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
log.info("stripe.duplicate_delivery", event_id=event_id, type=event_type)
|
||||
return {"status": "duplicate"}
|
||||
|
||||
handler = _HANDLERS.get(event_type)
|
||||
if handler is None:
|
||||
audit.processed_at = utcnow()
|
||||
await session.commit()
|
||||
log.info("stripe.event_unhandled", type=event_type, id=event_id)
|
||||
return {"status": "ignored"}
|
||||
|
||||
try:
|
||||
await handler(session, event_type, obj)
|
||||
except Exception as e:
|
||||
audit.error = str(e)[:1024]
|
||||
await session.commit()
|
||||
log.exception("stripe.handler_error", type=event_type, id=event_id)
|
||||
# Ack 200 — we don't want Stripe retrying a handler that broke
|
||||
# the same way on every delivery. An operator triages from the
|
||||
# `error` column.
|
||||
return {"status": "handler_error"}
|
||||
|
||||
audit.processed_at = utcnow()
|
||||
await session.commit()
|
||||
log.info("stripe.processed", type=event_type, id=event_id)
|
||||
return {"status": "ok"}
|
||||
|
|
@ -50,9 +50,8 @@
|
|||
<div class="tier-card__tagline">Full-day news feed, hourly strategic log, follow-up chat, and AI portfolio analysis.</div>
|
||||
<div class="tier-card__price">£7<span class="tier-card__price-unit"> / month</span></div>
|
||||
<div class="tier-card__price-hint">
|
||||
Or <strong>£70 / year</strong> — two months free. Prices
|
||||
in GBP, VAT where applicable. Checkout opens with the payments
|
||||
rollout.
|
||||
Or <strong>£70 / year</strong> — two months free.
|
||||
Prices in GBP, VAT where applicable.
|
||||
</div>
|
||||
<div class="tier-card__divider"></div>
|
||||
<div class="tier-card__list-head">Everything in Free, plus</div>
|
||||
|
|
@ -72,16 +71,56 @@
|
|||
or a personal recommendation under FSMA / FCA COBS.
|
||||
</p>
|
||||
<div class="tier-card__cta">
|
||||
{% if cu and (cu.user or cu.is_admin) %}
|
||||
<a class="btn-secondary btn-block" href="/settings">Manage account</a>
|
||||
{% if paid %}
|
||||
<a class="btn-secondary btn-block" href="/settings">Manage subscription</a>
|
||||
{% elif cu and cu.user %}
|
||||
<button class="btn-primary btn-block" type="button"
|
||||
data-stripe-checkout="monthly">Subscribe — £7/month</button>
|
||||
<button class="btn-secondary btn-block" type="button"
|
||||
data-stripe-checkout="annual"
|
||||
style="margin-top:10px;">or £70/year (two months free)</button>
|
||||
{% else %}
|
||||
<a class="btn-primary btn-block" href="/login">Sign up — paid unlocks soon</a>
|
||||
<a class="btn-primary btn-block" href="/login?next=/pricing">Sign in to subscribe</a>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
<script>
|
||||
(function () {
|
||||
// Wire the two upgrade buttons to /api/stripe/checkout. Stripe returns
|
||||
// a hosted-checkout URL; we just redirect there. No Stripe.js needed.
|
||||
document.querySelectorAll('[data-stripe-checkout]').forEach(function (btn) {
|
||||
btn.addEventListener('click', async function () {
|
||||
var cadence = btn.getAttribute('data-stripe-checkout');
|
||||
btn.disabled = true;
|
||||
var prev = btn.textContent;
|
||||
btn.textContent = 'Opening checkout…';
|
||||
try {
|
||||
var r = await fetch('/api/stripe/checkout', {
|
||||
method: 'POST',
|
||||
headers: {'content-type': 'application/json'},
|
||||
body: JSON.stringify({cadence: cadence}),
|
||||
credentials: 'same-origin',
|
||||
});
|
||||
if (!r.ok) {
|
||||
var detail = '';
|
||||
try { detail = (await r.json()).detail || ''; } catch (e) {}
|
||||
throw new Error('Checkout failed: ' + (detail || r.status));
|
||||
}
|
||||
var data = await r.json();
|
||||
window.location.href = data.url;
|
||||
} catch (e) {
|
||||
alert(e.message || 'Could not start checkout. Please try again.');
|
||||
btn.disabled = false;
|
||||
btn.textContent = prev;
|
||||
}
|
||||
});
|
||||
});
|
||||
})();
|
||||
</script>
|
||||
|
||||
<section class="public-section">
|
||||
<h2 class="public-section__head">Free vs Paid at a glance</h2>
|
||||
<table class="compare-table">
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ dependencies = [
|
|||
"email-validator>=2.2",
|
||||
"aiosmtplib>=3.0",
|
||||
"redis[hiredis]>=5.2",
|
||||
"stripe>=11.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
322
tests/test_stripe_billing.py
Normal file
322
tests/test_stripe_billing.py
Normal file
|
|
@ -0,0 +1,322 @@
|
|||
"""Stripe billing endpoints: signature verification, idempotency, tier
|
||||
flips, and checkout creation.
|
||||
|
||||
Same integration-style scaffold as test_polar_webhook.py — real router
|
||||
over in-memory aiosqlite. Stripe SDK calls (sessions.create, portal
|
||||
sessions.create) are mocked so the suite never makes a real HTTP call.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
_API_KEY = "sk_test_dummy_for_unit_tests"
|
||||
_WEBHOOK_SECRET = "whsec_dummy_test_secret_for_unit_tests"
|
||||
_PRICE_MONTHLY = "price_test_monthly_xxxxxxxxxxxxxxxxxxxx"
|
||||
_PRICE_ANNUAL = "price_test_annual_xxxxxxxxxxxxxxxxxxxxx"
|
||||
|
||||
|
||||
def _stripe_sig(body: bytes, secret: str, ts: int | None = None) -> str:
|
||||
"""Produce a Stripe-Signature header matching the bytes signed.
|
||||
Format: `t=<ts>,v1=<hex hmac sha256>` over `<ts>.<body>`."""
|
||||
ts = ts if ts is not None else int(time.time())
|
||||
signed = f"{ts}.{body.decode('utf-8')}"
|
||||
mac = hmac.new(secret.encode("utf-8"), signed.encode("utf-8"),
|
||||
hashlib.sha256).hexdigest()
|
||||
return f"t={ts},v1={mac}"
|
||||
|
||||
|
||||
def _build_app(tmp_path):
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from app import db as db_mod
|
||||
from app.auth import sign_session
|
||||
from app.config import get_settings
|
||||
from app.db import Base
|
||||
from app.models import User
|
||||
from app.routers import stripe_billing as stripe_router
|
||||
|
||||
s = get_settings()
|
||||
s.STRIPE_API_KEY = _API_KEY # type: ignore[misc]
|
||||
s.STRIPE_WEBHOOK_SECRET = _WEBHOOK_SECRET # type: ignore[misc]
|
||||
s.STRIPE_PRICE_MONTHLY = _PRICE_MONTHLY # type: ignore[misc]
|
||||
s.STRIPE_PRICE_ANNUAL = _PRICE_ANNUAL # type: ignore[misc]
|
||||
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/stripe.db")
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
db_mod._engine = engine
|
||||
db_mod._session_factory = factory
|
||||
|
||||
async def _seed():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
async with factory() as session:
|
||||
session.add(User(id=1, email="buyer@x", tier="free"))
|
||||
await session.commit()
|
||||
|
||||
asyncio.run(_seed())
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(stripe_router.router)
|
||||
return TestClient(app), factory, sign_session(1)
|
||||
|
||||
|
||||
def _post_webhook(client, *, body: dict, secret: str = _WEBHOOK_SECRET,
|
||||
sig: str | None = None):
|
||||
# Stripe's SDK requires a top-level `object: "event"` field to know
|
||||
# this is a v1 webhook envelope — tests that omit it fail in
|
||||
# construct_event before the signature check matters. We inject the
|
||||
# default here so individual tests can stay terse.
|
||||
body.setdefault("object", "event")
|
||||
raw = json.dumps(body).encode("utf-8")
|
||||
sig = sig if sig is not None else _stripe_sig(raw, secret)
|
||||
return client.post(
|
||||
"/api/stripe/webhook",
|
||||
content=raw,
|
||||
headers={"stripe-signature": sig, "content-type": "application/json"},
|
||||
)
|
||||
|
||||
|
||||
# --- signature gate --------------------------------------------------------
|
||||
|
||||
|
||||
def test_webhook_rejects_bad_signature(tmp_path):
|
||||
client, _, _ = _build_app(tmp_path)
|
||||
raw = json.dumps({"id": "evt_x", "type": "invoice.paid",
|
||||
"data": {"object": {}}}).encode("utf-8")
|
||||
r = client.post(
|
||||
"/api/stripe/webhook",
|
||||
content=raw,
|
||||
headers={
|
||||
"stripe-signature": "t=0,v1=deadbeef",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
)
|
||||
assert r.status_code == 401, r.text
|
||||
|
||||
|
||||
def test_webhook_rejects_missing_signature(tmp_path):
|
||||
client, _, _ = _build_app(tmp_path)
|
||||
r = client.post(
|
||||
"/api/stripe/webhook",
|
||||
content=b"{}",
|
||||
headers={"content-type": "application/json"},
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
|
||||
|
||||
# --- happy paths -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_checkout_session_completed_flips_tier_to_paid(tmp_path):
|
||||
client, factory, _ = _build_app(tmp_path)
|
||||
body = {
|
||||
"id": "evt_checkout_1",
|
||||
"type": "checkout.session.completed",
|
||||
"data": {
|
||||
"object": {
|
||||
"client_reference_id": "1",
|
||||
"customer": "cus_abc",
|
||||
"subscription": "sub_xyz",
|
||||
}
|
||||
},
|
||||
}
|
||||
r = _post_webhook(client, body=body)
|
||||
assert r.status_code == 200, r.text
|
||||
assert r.json()["status"] == "ok"
|
||||
|
||||
async def _check():
|
||||
from sqlalchemy import select
|
||||
from app.models import User
|
||||
async with factory() as session:
|
||||
u = (await session.execute(
|
||||
select(User).where(User.id == 1)
|
||||
)).scalar_one()
|
||||
return u.tier, u.stripe_customer_id, u.stripe_subscription_id
|
||||
|
||||
tier, cid, sid = asyncio.run(_check())
|
||||
assert tier == "paid"
|
||||
assert cid == "cus_abc"
|
||||
assert sid == "sub_xyz"
|
||||
|
||||
|
||||
def test_subscription_deleted_drops_tier_to_free(tmp_path):
|
||||
client, factory, _ = _build_app(tmp_path)
|
||||
# First, activate.
|
||||
_post_webhook(client, body={
|
||||
"id": "evt_act",
|
||||
"type": "checkout.session.completed",
|
||||
"data": {"object": {
|
||||
"client_reference_id": "1",
|
||||
"customer": "cus_abc",
|
||||
"subscription": "sub_xyz",
|
||||
}},
|
||||
})
|
||||
# Then, delete the subscription.
|
||||
r = _post_webhook(client, body={
|
||||
"id": "evt_del",
|
||||
"type": "customer.subscription.deleted",
|
||||
"data": {"object": {
|
||||
"id": "sub_xyz",
|
||||
"customer": "cus_abc",
|
||||
"status": "canceled",
|
||||
}},
|
||||
})
|
||||
assert r.status_code == 200, r.text
|
||||
|
||||
async def _check():
|
||||
from sqlalchemy import select
|
||||
from app.models import User
|
||||
async with factory() as session:
|
||||
u = (await session.execute(
|
||||
select(User).where(User.id == 1)
|
||||
)).scalar_one()
|
||||
return u.tier, u.stripe_customer_id, u.stripe_subscription_id
|
||||
|
||||
tier, cid, sid = asyncio.run(_check())
|
||||
assert tier == "free"
|
||||
# Customer linkage preserved so a future resub matches this row.
|
||||
assert cid == "cus_abc"
|
||||
assert sid is None
|
||||
|
||||
|
||||
def test_subscription_active_grants_paid(tmp_path):
|
||||
"""customer.subscription.updated with status=active should also
|
||||
grant paid — covers the case where checkout.session.completed
|
||||
arrives after subscription.created and we want either to work."""
|
||||
client, factory, _ = _build_app(tmp_path)
|
||||
# Seed the linkage first via checkout (so customer_id is known).
|
||||
_post_webhook(client, body={
|
||||
"id": "evt_ck",
|
||||
"type": "checkout.session.completed",
|
||||
"data": {"object": {
|
||||
"client_reference_id": "1",
|
||||
"customer": "cus_abc",
|
||||
"subscription": "sub_xyz",
|
||||
}},
|
||||
})
|
||||
# Drop to free manually so we can prove the updated event re-grants.
|
||||
async def _reset():
|
||||
from sqlalchemy import update
|
||||
from app.models import User
|
||||
async with factory() as session:
|
||||
await session.execute(
|
||||
update(User).where(User.id == 1).values(tier="free")
|
||||
)
|
||||
await session.commit()
|
||||
asyncio.run(_reset())
|
||||
|
||||
r = _post_webhook(client, body={
|
||||
"id": "evt_upd",
|
||||
"type": "customer.subscription.updated",
|
||||
"data": {"object": {
|
||||
"id": "sub_xyz",
|
||||
"customer": "cus_abc",
|
||||
"status": "active",
|
||||
}},
|
||||
})
|
||||
assert r.status_code == 200
|
||||
|
||||
async def _check_tier():
|
||||
from sqlalchemy import select
|
||||
from app.models import User
|
||||
async with factory() as session:
|
||||
return (await session.execute(
|
||||
select(User.tier).where(User.id == 1)
|
||||
)).scalar_one()
|
||||
assert asyncio.run(_check_tier()) == "paid"
|
||||
|
||||
|
||||
# --- idempotency + unknown ------------------------------------------------
|
||||
|
||||
|
||||
def test_replayed_event_id_is_a_noop(tmp_path):
|
||||
client, factory, _ = _build_app(tmp_path)
|
||||
body = {
|
||||
"id": "evt_dup",
|
||||
"type": "checkout.session.completed",
|
||||
"data": {"object": {
|
||||
"client_reference_id": "1",
|
||||
"customer": "cus_abc",
|
||||
"subscription": "sub_xyz",
|
||||
}},
|
||||
}
|
||||
r1 = _post_webhook(client, body=body)
|
||||
r2 = _post_webhook(client, body=body)
|
||||
assert r1.json()["status"] == "ok"
|
||||
assert r2.json()["status"] == "duplicate"
|
||||
|
||||
async def _count_rows():
|
||||
from sqlalchemy import select, func
|
||||
from app.models import StripeEvent
|
||||
async with factory() as session:
|
||||
n = (await session.execute(
|
||||
select(func.count(StripeEvent.id))
|
||||
.where(StripeEvent.event_id == "evt_dup")
|
||||
)).scalar_one()
|
||||
return n
|
||||
assert asyncio.run(_count_rows()) == 1
|
||||
|
||||
|
||||
def test_unknown_event_is_acked(tmp_path):
|
||||
client, _, _ = _build_app(tmp_path)
|
||||
r = _post_webhook(client, body={
|
||||
"id": "evt_unknown",
|
||||
"type": "product.something.new",
|
||||
"data": {"object": {}},
|
||||
})
|
||||
assert r.status_code == 200
|
||||
assert r.json()["status"] == "ignored"
|
||||
|
||||
|
||||
# --- /api/stripe/checkout (with Stripe SDK mocked) ------------------------
|
||||
|
||||
|
||||
def test_checkout_endpoint_creates_session_and_returns_url(tmp_path):
|
||||
client, _, session_cookie = _build_app(tmp_path)
|
||||
# Mock the Stripe SDK call so no real HTTP goes out.
|
||||
fake_session = SimpleNamespace(
|
||||
id="cs_test_123", url="https://checkout.stripe.com/test",
|
||||
)
|
||||
|
||||
class _FakeSessions:
|
||||
@staticmethod
|
||||
def create(params): # noqa: ANN001
|
||||
assert params["mode"] == "subscription"
|
||||
assert params["line_items"][0]["price"] == _PRICE_MONTHLY
|
||||
assert params["client_reference_id"] == "1"
|
||||
assert params["customer_email"] == "buyer@x"
|
||||
return fake_session
|
||||
|
||||
class _FakeCheckout:
|
||||
sessions = _FakeSessions()
|
||||
|
||||
class _FakeClient:
|
||||
checkout = _FakeCheckout()
|
||||
|
||||
with patch("app.routers.stripe_billing._stripe_client",
|
||||
return_value=_FakeClient()):
|
||||
r = client.post(
|
||||
"/api/stripe/checkout",
|
||||
json={"cadence": "monthly"},
|
||||
cookies={"cassandra_session": session_cookie},
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
assert r.json()["url"] == "https://checkout.stripe.com/test"
|
||||
|
||||
|
||||
def test_checkout_endpoint_requires_login(tmp_path):
|
||||
client, _, _ = _build_app(tmp_path)
|
||||
r = client.post("/api/stripe/checkout", json={"cadence": "monthly"})
|
||||
# No session cookie → require_auth bounces with 401.
|
||||
assert r.status_code == 401, r.text
|
||||
Loading…
Add table
Add a link
Reference in a new issue