"""Polar (Standard Webhooks) endpoint: signature verification, idempotency, and the subscription.active -> tier=paid handler. Integration-style: real router + in-memory aiosqlite. Same scaffold as test_news_window.py / test_chat_and_log_gates.py.""" from __future__ import annotations import asyncio import base64 import hashlib import hmac import json import time import pytest _SECRET_RAW = b"this-is-a-deterministic-test-secret-32b!" _SECRET = "whsec_" + base64.b64encode(_SECRET_RAW).decode("ascii") def _sign(msg_id: str, ts: str, body: bytes) -> str: """Produce the `v1,` token Polar would send.""" signed = f"{msg_id}.{ts}.{body.decode('utf-8')}" mac = hmac.new(_SECRET_RAW, signed.encode("utf-8"), hashlib.sha256).digest() return "v1," + base64.b64encode(mac).decode("ascii") def _build_app(tmp_path): from fastapi import FastAPI from fastapi.testclient import TestClient from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from app import db as db_mod from app.config import get_settings from app.db import Base from app.models import User from app.routers import polar_webhook as polar_router # Inject the secret into the cached Settings. We override the # field rather than monkeypatching env because the secret is read # via get_settings() at request time. s = get_settings() s.POLAR_WEBHOOK_SECRET = _SECRET # type: ignore[misc] engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/polar.db") factory = async_sessionmaker(engine, expire_on_commit=False) db_mod._engine = engine db_mod._session_factory = factory async def _seed(): async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) async with factory() as session: session.add(User(id=1, email="paying@x", tier="free")) await session.commit() asyncio.run(_seed()) app = FastAPI() app.include_router(polar_router.router) return TestClient(app), factory def _post(client, *, body: dict, msg_id="msg_001", ts: str | None = None, sig: str | None = None): raw = json.dumps(body).encode("utf-8") ts = ts or str(int(time.time())) sig = sig or _sign(msg_id, ts, raw) return client.post( "/api/polar/webhook", content=raw, headers={ "webhook-id": msg_id, "webhook-timestamp": ts, "webhook-signature": sig, "content-type": "application/json", }, ) # --- signature gate -------------------------------------------------------- def test_rejects_bad_signature(tmp_path): client, _ = _build_app(tmp_path) raw = json.dumps({"type": "subscription.active", "data": {}}).encode("utf-8") ts = str(int(time.time())) r = client.post( "/api/polar/webhook", content=raw, headers={ "webhook-id": "msg_bad", "webhook-timestamp": ts, "webhook-signature": "v1,AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "content-type": "application/json", }, ) assert r.status_code == 401, r.text def test_rejects_stale_timestamp(tmp_path): client, _ = _build_app(tmp_path) body = {"type": "subscription.active", "data": {}} # 10 minutes in the past — beyond the 5-minute tolerance window. stale = str(int(time.time()) - 600) r = _post(client, body=body, ts=stale, msg_id="msg_stale") assert r.status_code == 401, r.text def test_rejects_missing_headers(tmp_path): client, _ = _build_app(tmp_path) r = client.post("/api/polar/webhook", content=b"{}", headers={"content-type": "application/json"}) assert r.status_code == 400, r.text # --- happy paths ----------------------------------------------------------- def test_subscription_active_flips_tier_to_paid(tmp_path): client, factory = _build_app(tmp_path) body = { "type": "subscription.active", "data": { "id": "sub_abc", "customer": {"id": "cust_xyz", "email": "paying@x"}, }, } r = _post(client, body=body, msg_id="msg_active") assert r.status_code == 200, r.text assert r.json()["status"] == "ok" async def _check(): from sqlalchemy import select from app.models import User async with factory() as session: u = (await session.execute( select(User).where(User.id == 1) )).scalar_one() return u.tier, u.polar_customer_id, u.polar_subscription_id tier, cid, sid = asyncio.run(_check()) assert tier == "paid" assert cid == "cust_xyz" assert sid == "sub_abc" def test_subscription_revoked_drops_to_free(tmp_path): client, factory = _build_app(tmp_path) # First, activate. _post(client, body={ "type": "subscription.active", "data": {"id": "sub_abc", "customer": {"id": "cust_xyz", "email": "paying@x"}}, }, msg_id="msg_act") # Then, revoke. r = _post(client, body={ "type": "subscription.revoked", "data": {"id": "sub_abc", "customer": {"id": "cust_xyz", "email": "paying@x"}}, }, msg_id="msg_rev") assert r.status_code == 200, r.text async def _check(): from sqlalchemy import select from app.models import User async with factory() as session: u = (await session.execute( select(User).where(User.id == 1) )).scalar_one() return u.tier, u.polar_customer_id, u.polar_subscription_id tier, cid, sid = asyncio.run(_check()) assert tier == "free" # Customer linkage preserved so a future resub matches the same row. assert cid == "cust_xyz" assert sid is None # --- idempotency ----------------------------------------------------------- def test_replayed_event_id_is_a_noop(tmp_path): client, factory = _build_app(tmp_path) body = { "type": "subscription.active", "data": {"id": "sub_abc", "customer": {"id": "cust_xyz", "email": "paying@x"}}, } # Two POSTs with the same msg_id and body — second should be deduped. r1 = _post(client, body=body, msg_id="msg_dup") r2 = _post(client, body=body, msg_id="msg_dup") assert r1.status_code == 200 and r1.json()["status"] == "ok" assert r2.status_code == 200 and r2.json()["status"] == "duplicate" async def _count(): from sqlalchemy import select, func from app.models import PolarEvent async with factory() as session: n = (await session.execute( select(func.count(PolarEvent.id)) .where(PolarEvent.event_id == "msg_dup") )).scalar_one() return n assert asyncio.run(_count()) == 1 def test_unknown_event_type_is_acked(tmp_path): client, _ = _build_app(tmp_path) body = {"type": "benefit_grant.cycled", "data": {}} r = _post(client, body=body, msg_id="msg_unknown") assert r.status_code == 200 assert r.json()["status"] == "ignored"