"""Polar (merchant-of-record) webhook endpoint.

Polar uses the Standard Webhooks spec (https://www.standardwebhooks.com).
Every delivery carries three headers:

    webhook-id         unique ID for THIS delivery (used for idempotency).
    webhook-timestamp  Unix seconds at send time (used for replay defence).
    webhook-signature  space-separated list of ``v1,<base64-signature>``
                       tokens (several may be present during key rotation).
                       Verifying any one of them means the payload is
                       authentic.

The signed content is the literal string ``{id}.{timestamp}.{body}``, signed
with HMAC-SHA256 using the raw secret bytes (the secret is base64-encoded
after the ``whsec_`` prefix). We compare signatures in constant time and
reject anything that doesn't match — including deliveries whose timestamp is
more than 5 minutes away from our clock in either direction — before parsing
JSON or touching the database.

Idempotency is keyed on ``webhook-id`` via a unique constraint on
``polar_events.event_id``. A second delivery of the same id finds the row
already there and returns 200 without re-running the handler. Polar retries
on any non-2xx response, so once a delivery has been processed, every later
delivery of the same id must be answered with a 2xx.

The router is mounted without the app's bearer-token dependency: webhook
authenticity is established via the HMAC, not the token.
"""

from __future__ import annotations

import base64
import hashlib
import hmac
import json
import time
from datetime import datetime, timezone
from typing import Any

from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import get_settings
from app.db import get_session, utcnow
from app.logging import get_logger
from app.models import PolarEvent, User

# Module logger; events below are emitted under "polar.*" names.
log = get_logger("polar_webhook")
# Mounted by the app WITHOUT the bearer-token dependency (see module docstring).
router = APIRouter()

# Max clock skew we'll tolerate on the `webhook-timestamp` header, in seconds,
# in either direction. Standard Webhooks recommends ±5 min; anything further
# off is almost certainly a replay.
_TIMESTAMP_TOLERANCE_S = 300

# Cap the stored audit payload at 16 * 1024 (16 KiB for ASCII bodies) so a
# hostile (or buggy) sender can't blow up a single row.
_PAYLOAD_STORE_MAX = 16 * 1024


def _decode_secret(secret: str) -> bytes:
    """Return the raw HMAC key for a Standard Webhooks secret.

    Secrets look like ``whsec_<base64>``; the ``whsec_`` prefix is optional
    here and the remainder is base64-decoded into the key bytes.

    NOTE(review): Polar's own SDK appears to base64-*encode* its secret
    before handing it to a Standard Webhooks verifier (i.e. the raw secret
    string is the key). Confirm POLAR_WEBHOOK_SECRET is configured in
    ``whsec_<base64>`` form, otherwise every delivery will fail with 401.

    Raises:
        ValueError: the secret is empty or not decodable base64
            (``binascii.Error`` is a ``ValueError`` subclass).
    """
    if not secret:
        raise ValueError("empty webhook secret")
    s = secret
    if s.startswith("whsec_"):
        s = s[len("whsec_"):]
    return base64.b64decode(s)


def _compute_signature(key: bytes, signed_payload: str | bytes) -> str:
    """Return one signature token: ``v1,<base64(HMAC-SHA256(key, payload))>``.

    *signed_payload* may be ``str`` (UTF-8 encoded before signing) or raw
    ``bytes`` (signed as-is, which is what the webhook path uses).
    """
    if isinstance(signed_payload, str):
        signed_payload = signed_payload.encode("utf-8")
    mac = hmac.new(key, signed_payload, hashlib.sha256).digest()
    return "v1," + base64.b64encode(mac).decode("ascii")


def _signature_matches(expected: str, msg_signature: str) -> bool:
    """True if any space-separated token of *msg_signature* equals *expected*.

    Each comparison is constant-time. Tokens are compared as bytes because
    ``hmac.compare_digest`` raises TypeError for non-ASCII ``str`` input,
    and a header value can carry arbitrary characters — that must end up
    as "no match" (401), not as an unhandled error (500).
    """
    want = expected.encode("ascii")
    return any(
        hmac.compare_digest(want, token.encode("utf-8", errors="replace"))
        for token in msg_signature.split()
    )


def verify_standard_webhook(
    *,
    secret: str,
    msg_id: str,
    msg_timestamp: str,
    msg_signature: str,
    body: bytes,
    now: float | None = None,
    tolerance_s: float | None = None,
) -> None:
    """Verify a Standard Webhooks delivery. Success is "did not raise".

    Args:
        secret: the webhook secret (see ``_decode_secret``).
        msg_id: value of the ``webhook-id`` header.
        msg_timestamp: value of the ``webhook-timestamp`` header (Unix secs).
        msg_signature: value of the ``webhook-signature`` header.
        body: the raw request body, exactly as received.
        now: current Unix time; defaults to ``time.time()`` (injectable
            for tests).
        tolerance_s: max allowed ``|now - timestamp|`` in seconds; defaults
            to the module's ``_TIMESTAMP_TOLERANCE_S``.

    Raises:
        HTTPException(500): the configured secret is malformed.
        HTTPException(401): invalid/stale timestamp, or no signature token
            matches.
    """
    try:
        key = _decode_secret(secret)
    except ValueError as e:  # includes binascii.Error
        raise HTTPException(status_code=500, detail=f"bad webhook secret: {e}") from e

    if tolerance_s is None:
        tolerance_s = _TIMESTAMP_TOLERANCE_S
    current = time.time() if now is None else now

    # Timestamp / replay window. int() rejects non-numeric input; an absurdly
    # long digit string overflows the float subtraction. Both are the
    # sender's fault, so both are 401 rather than an unhandled 500.
    try:
        ts = int(msg_timestamp)
        drift = abs(current - ts)
    except (ValueError, OverflowError):
        raise HTTPException(status_code=401, detail="invalid timestamp") from None
    if drift > tolerance_s:
        raise HTTPException(status_code=401, detail="stale timestamp")

    # Sign the raw body bytes: decoding first would turn a non-UTF-8 body
    # into a UnicodeDecodeError (500) instead of a signature check.
    signed_payload = f"{msg_id}.{msg_timestamp}.".encode("utf-8") + body
    expected = _compute_signature(key, signed_payload)

    # The header can carry several tokens (key rotation); any match wins.
    if not _signature_matches(expected, msg_signature):
        raise HTTPException(status_code=401, detail="bad signature")


# ---------------------------------------------------------------------------
# Event handlers
# ---------------------------------------------------------------------------

# Subscription statuses for which the user should hold the paid tier.
# `subscription.updated` fires for every change — including revocation
# (status "canceled") — so the grant handler must not trust the event type
# alone.
_ACTIVE_STATUSES = frozenset({"active", "trialing"})


def _customer_dict(payload_data: dict[str, Any]) -> dict[str, Any]:
    """Return ``payload_data["customer"]`` if it is a dict, else ``{}``."""
    cust = payload_data.get("customer")
    return cust if isinstance(cust, dict) else {}


def _customer_id_from_payload(payload_data: dict[str, Any]) -> str | None:
    """Polar nests the customer object under `customer`. Some events also
    surface `customer_id` at the top of `data` — accept either (the nested
    id wins)."""
    return _customer_dict(payload_data).get("id") or payload_data.get("customer_id")


def _customer_email_from_payload(payload_data: dict[str, Any]) -> str | None:
    """Return the nested customer's email, or None if absent."""
    return _customer_dict(payload_data).get("email")


async def _find_user(session: AsyncSession, data: dict[str, Any]) -> User | None:
    """Locate the User row that owns this event, or None.

    1. By stored Polar customer id — the stable link once we've seen a user.
    2. By email — the first time Polar fires an event for a brand-new
       customer we won't have the id yet, but our checkout call created the
       Polar customer with the user's email.

    The email fallback is refused when the payload carries a customer id and
    the email-matched user is already linked to a *different* customer id.
    Otherwise another Polar customer using the same email could re-link that
    account (and, through a revoke, downgrade it).
    """
    cid = _customer_id_from_payload(data)
    if cid:
        row = (await session.execute(
            select(User).where(User.polar_customer_id == cid)
        )).scalar_one_or_none()
        if row is not None:
            return row

    email = _customer_email_from_payload(data)
    if not email:
        return None
    row = (await session.execute(
        select(User).where(User.email == email)
    )).scalar_one_or_none()
    if row is not None and cid and row.polar_customer_id and row.polar_customer_id != cid:
        log.warning("polar.email_match_customer_mismatch",
                    customer_id=cid, linked_customer_id=row.polar_customer_id)
        return None
    return row


async def _grant_paid(
    session: AsyncSession,
    user: User,
    data: dict[str, Any],
) -> None:
    """Flip the user to the paid tier and persist the Polar IDs we now know.

    Safe to call repeatedly: the tier assignment is idempotent and the IDs
    are only rewritten when the payload carries different ones."""
    user.tier = "paid"
    cid = _customer_id_from_payload(data)
    if cid and user.polar_customer_id != cid:
        user.polar_customer_id = cid
    sub_id = data.get("id")  # subscription event payloads put sub id at top
    if sub_id and user.polar_subscription_id != sub_id:
        user.polar_subscription_id = sub_id


async def _revoke_paid(session: AsyncSession, user: User) -> None:
    """Drop the user back to the free tier and forget the subscription id.

    The polar_customer_id is deliberately kept so a re-subscription matches
    them back to the same row."""
    user.tier = "free"
    user.polar_subscription_id = None


async def _handle_subscription_active(
    session: AsyncSession,
    data: dict[str, Any],
    event_type: str,
) -> None:
    """Grant paid for subscription events whose status is active/trialing.

    Payloads with no ``status`` field are treated as active (previous
    behavior). Any other status (e.g. "canceled" on the `subscription.updated`
    that accompanies a revocation, "incomplete", "past_due") leaves the tier
    untouched.
    """
    status = data.get("status")
    if status is not None and status not in _ACTIVE_STATUSES:
        log.info("polar.subscription_not_active", event_type=event_type,
                 status=status, subscription_id=data.get("id"))
        return
    user = await _find_user(session, data)
    if user is None:
        log.warning("polar.user_not_found", event_type=event_type,
                    customer_id=_customer_id_from_payload(data))
        return
    await _grant_paid(session, user, data)


async def _handle_subscription_revoked(
    session: AsyncSession,
    data: dict[str, Any],
    event_type: str,
) -> None:
    """Revoke paid — unless the revoked subscription isn't the one we hold.

    If the user already has a different current subscription id (e.g. they
    switched plans and the old subscription is only now being revoked), the
    tier is left alone.
    """
    user = await _find_user(session, data)
    if user is None:
        log.warning("polar.user_not_found", event_type=event_type,
                    customer_id=_customer_id_from_payload(data))
        return
    sub_id = data.get("id")
    if sub_id and user.polar_subscription_id and user.polar_subscription_id != sub_id:
        log.info("polar.revoke_non_current_subscription", event_type=event_type,
                 subscription_id=sub_id, current_subscription_id=user.polar_subscription_id)
        return
    await _revoke_paid(session, user)


async def _handle_no_state_change(
    session: AsyncSession,
    data: dict[str, Any],
    event_type: str,
) -> None:
    """For events we want to record in the audit table but where the tier
    doesn't move — canceled (still active until period end), past_due, order
    events, refund created. The PolarEvent row is the record."""
    return None


# Map event type → handler. Anything not in this map is acknowledged
# (200) but ignored, on the principle that Polar may add new event types
# over time and we don't want to start 4xx-ing on unknown ones.
_HANDLERS = {
    "subscription.created": _handle_subscription_active,
    "subscription.active": _handle_subscription_active,
    "subscription.updated": _handle_subscription_active,
    "subscription.uncanceled": _handle_subscription_active,
    "subscription.canceled": _handle_no_state_change,
    "subscription.revoked": _handle_subscription_revoked,
    "subscription.past_due": _handle_no_state_change,
    "order.paid": _handle_no_state_change,
    "order.refunded": _handle_no_state_change,
    "refund.created": _handle_no_state_change,
}


# ---------------------------------------------------------------------------
# Endpoint
# ---------------------------------------------------------------------------


@router.post("/api/polar/webhook")
async def polar_webhook(
    request: Request,
    session: AsyncSession = Depends(get_session),
) -> dict[str, str]:
    """Authenticate, de-duplicate, record and dispatch one Polar delivery.

    Steps: config check → header presence → signature + timestamp check →
    JSON parse → audit-row insert (idempotency on ``webhook-id``) → handler
    → commit.

    Returns:
        ``{"status": ...}`` with one of "ok", "duplicate", "ignored"
        (signed but unmapped event type) or "handler_error" (recorded on the
        audit row and acknowledged so Polar stops retrying).

    Raises:
        HTTPException(503): no webhook secret configured.
        HTTPException(400): missing Standard Webhooks headers, or a body
            that is not a JSON object.
        HTTPException(401/500): from ``verify_standard_webhook``.
    """
    s = get_settings()
    if not s.POLAR_WEBHOOK_SECRET:
        # Loud failure rather than accepting an unsigned event.
        raise HTTPException(status_code=503, detail="webhook not configured")

    msg_id = request.headers.get("webhook-id", "")
    msg_ts = request.headers.get("webhook-timestamp", "")
    msg_sig = request.headers.get("webhook-signature", "")
    if not (msg_id and msg_ts and msg_sig):
        raise HTTPException(status_code=400, detail="missing standard-webhooks headers")

    body = await request.body()
    verify_standard_webhook(
        secret=s.POLAR_WEBHOOK_SECRET,
        msg_id=msg_id,
        msg_timestamp=msg_ts,
        msg_signature=msg_sig,
        body=body,
    )

    # ValueError covers JSONDecodeError AND the UnicodeDecodeError json.loads
    # raises for bytes that are not valid UTF-8/16/32.
    try:
        envelope = json.loads(body)
    except ValueError:
        raise HTTPException(status_code=400, detail="invalid JSON") from None
    if not isinstance(envelope, dict):
        # e.g. a top-level array: `.get` below would raise AttributeError (500).
        raise HTTPException(status_code=400, detail="invalid JSON")

    # A non-string `type` (possibly unhashable) would break the handler lookup.
    event_type = envelope.get("type")
    if not isinstance(event_type, str) or not event_type:
        event_type = "unknown"
    data = envelope.get("data") or {}

    # Idempotency: insert the audit row first. If the webhook-id was
    # already delivered, the UNIQUE constraint short-circuits with a
    # 200 (Polar will stop retrying).
    received_at = utcnow()
    # The cap counts characters of the decoded text, not bytes.
    body_text = body.decode("utf-8", errors="replace")[:_PAYLOAD_STORE_MAX]
    audit = PolarEvent(
        event_id=msg_id,
        event_type=event_type,
        received_at=received_at,
        payload=body_text,
    )
    session.add(audit)
    try:
        await session.flush()
    except IntegrityError:
        # Already processed — return 200 so Polar doesn't keep retrying.
        await session.rollback()
        log.info("polar.duplicate_delivery", event_id=msg_id, type=event_type)
        return {"status": "duplicate"}

    handler = _HANDLERS.get(event_type)
    if handler is None:
        # Unknown but well-signed event — record it, ack 200.
        audit.processed_at = utcnow()
        await session.commit()
        log.info("polar.event_unhandled", type=event_type, id=msg_id)
        return {"status": "ignored"}

    try:
        await handler(session, data, event_type)
        # Flush inside the try so DB errors caused by the handler's changes
        # (e.g. a unique-constraint clash) land in the error path below
        # instead of surfacing at the final commit as a 500.
        await session.flush()
    except Exception as e:
        # The handler may have failed half-way (tier already flipped, IDs
        # half-written) or left the transaction unusable after a DB error.
        # Roll ALL of it back, then record the failure on a fresh audit row
        # and still ack 200: a retry would run the same code and break the
        # same way. Operators find stuck events via the error column.
        log.exception("polar.handler_error", type=event_type, id=msg_id)
        await session.rollback()
        failed = PolarEvent(
            event_id=msg_id,
            event_type=event_type,
            received_at=received_at,
            payload=body_text,
        )
        failed.error = str(e)[:1024]
        session.add(failed)
        try:
            await session.commit()
        except IntegrityError:
            # A concurrent redelivery claimed this webhook-id after our
            # rollback; its row is the record now.
            await session.rollback()
        return {"status": "handler_error"}

    audit.processed_at = utcnow()
    await session.commit()
    log.info("polar.processed", type=event_type, id=msg_id)
    return {"status": "ok"}