"""Email-OTP generation & verification.

A code is a 6-digit numeric string (000000–999999). We store an argon2 hash so
leaking the DB alone doesn't reveal active codes. Each code has a 15-minute TTL
and 5 attempts before it gets marked dead. Generating a new code for an email
invalidates any earlier unused ones (one valid code at a time per email).

Rate limit: at most one new code per 60 seconds per email, measured from the
most recently issued code whether or not it has since been used or killed.
This prevents an attacker from spamming the user's inbox via the /resend
endpoint (e.g. by burning a code's attempts just to free up the slot).

Every public entry point normalizes the email address (strip + lower-case),
so callers may pass raw user input.
"""
from __future__ import annotations

import math
import secrets
from datetime import datetime, timedelta, timezone

from argon2 import PasswordHasher
from argon2.exceptions import (
    InvalidHashError,
    VerificationError,
    VerifyMismatchError,
)
from sqlalchemy import desc, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import utcnow
from app.models import EmailOTP


def _as_utc(d: datetime) -> datetime:
    """MariaDB returns naive datetimes — tag them UTC so arithmetic with
    tz-aware utcnow() doesn't blow up."""
    return d if d.tzinfo is not None else d.replace(tzinfo=timezone.utc)


_HASHER = PasswordHasher()

OTP_LENGTH = 6
OTP_TTL_MINUTES = 15
MAX_ATTEMPTS = 5
RESEND_COOLDOWN_SECONDS = 60


class OTPError(Exception):
    """User-safe error message for OTP failures."""


def _normalize_email(email: str) -> str:
    """Return the canonical form of *email* used as the OTP lookup key."""
    return email.strip().lower()


def _generate_code() -> str:
    """Return a uniformly random, zero-padded OTP_LENGTH-digit string."""
    return f"{secrets.randbelow(10 ** OTP_LENGTH):0{OTP_LENGTH}d}"


def _hash_code(code: str) -> str:
    """Return the argon2 hash that is stored instead of the plaintext code."""
    return _HASHER.hash(code)


def _check_code(code: str, hashed: str) -> bool:
    """Return True iff *code* matches the stored argon2 hash *hashed*.

    A stored hash that argon2 cannot verify (corrupt / not an argon2 hash) is
    treated as a mismatch so the user gets "Incorrect code" rather than a 500.
    Any other exception is a programming error and is allowed to propagate.
    """
    try:
        _HASHER.verify(hashed, code)
    except VerifyMismatchError:
        # The normal "wrong code" case.
        return False
    except (VerificationError, InvalidHashError):
        # Stored hash is unusable; fail closed without crashing the request.
        return False
    return True


async def _latest_unused(
    session: AsyncSession,
    email: str,
    *,
    for_update: bool = False,
) -> EmailOTP | None:
    """Return the newest not-yet-used OTP row for *email*, or None.

    With ``for_update=True`` the row is locked (SELECT ... FOR UPDATE) until
    the session's transaction ends. It is also reloaded from the database even
    if it is already in the identity map. Concurrent verify() calls therefore
    serialize on the attempt counter instead of each acting on a stale value.
    """
    stmt = (
        select(EmailOTP)
        .where(EmailOTP.email == email)
        .where(EmailOTP.used_at.is_(None))
        .order_by(desc(EmailOTP.created_at))
        .limit(1)
    )
    if for_update:
        stmt = stmt.with_for_update().execution_options(populate_existing=True)
    return (await session.execute(stmt)).scalar_one_or_none()


async def _latest_issued(session: AsyncSession, email: str) -> EmailOTP | None:
    """Return the newest OTP row for *email* regardless of used status, or None."""
    return (await session.execute(
        select(EmailOTP)
        .where(EmailOTP.email == email)
        .order_by(desc(EmailOTP.created_at))
        .limit(1)
    )).scalar_one_or_none()


async def can_request_new(session: AsyncSession, email: str) -> tuple[bool, int]:
    """Returns (allowed, seconds_until_allowed).

    The cooldown is measured from the most recently *issued* code, used or
    not. Otherwise burning a code's attempts (which marks it used) would
    reopen the resend slot immediately. seconds_until_allowed is 0 when
    allowed. Otherwise it is rounded up, so a blocked caller never sees 0.
    """
    email = _normalize_email(email)
    latest = await _latest_issued(session, email)
    if latest is None:
        return True, 0
    age = (utcnow() - _as_utc(latest.created_at)).total_seconds()
    if age >= RESEND_COOLDOWN_SECONDS:
        return True, 0
    return False, math.ceil(RESEND_COOLDOWN_SECONDS - age)


async def issue(
    session: AsyncSession,
    email: str,
    *,
    purpose: str = "signup",
) -> str:
    """Generate a fresh code, persist its hash, invalidate any prior unused
    codes for this email. Returns the plaintext code so the caller can mail it.

    Caller is responsible for rate-limit check via can_request_new().
    Commits the session.
    """
    email = _normalize_email(email)
    now = utcnow()

    # Invalidate prior unused codes for this email so only one is valid.
    await session.execute(
        update(EmailOTP)
        .where(EmailOTP.email == email)
        .where(EmailOTP.used_at.is_(None))
        .values(used_at=now)
    )

    code = _generate_code()
    session.add(EmailOTP(
        email=email,
        code_hash=_hash_code(code),
        created_at=now,
        expires_at=now + timedelta(minutes=OTP_TTL_MINUTES),
        attempts=0,
        purpose=purpose,
    ))
    await session.commit()
    return code


async def verify(
    session: AsyncSession,
    email: str,
    code: str,
) -> bool:
    """Validate the user-submitted code against the latest unused OTP for
    this email. On success, mark the OTP used and return True.

    Raises OTPError on user-facing failures: malformed code, no code
    outstanding, expired, wrong code, too many attempts. The code is marked
    dead (used) as soon as its attempts are exhausted.
    """
    email = _normalize_email(email)
    code = code.strip()
    # isascii() rejects non-ASCII digits (e.g. Arabic-Indic, superscripts)
    # that str.isdigit() accepts but could never match a generated code.
    if not (code.isascii() and code.isdigit() and len(code) == OTP_LENGTH):
        raise OTPError(f"Code must be a {OTP_LENGTH}-digit number")

    # Lock the row so parallel guesses can't all read the same attempt count.
    latest = await _latest_unused(session, email, for_update=True)
    if latest is None:
        raise OTPError("No verification code outstanding for this email")

    now = utcnow()
    if _as_utc(latest.expires_at) < now:
        latest.used_at = now
        await session.commit()
        raise OTPError("This code has expired — request a new one")

    if latest.attempts >= MAX_ATTEMPTS:
        latest.used_at = now
        await session.commit()
        raise OTPError("Too many attempts — request a new code")

    if not _check_code(code, latest.code_hash):
        latest.attempts += 1
        # Computed before commit: reading attributes after commit may hit an
        # expired instance (lazy load is not allowed on AsyncSession).
        remaining = MAX_ATTEMPTS - latest.attempts
        if remaining <= 0:
            latest.used_at = now  # exhausted: mark dead right away
        await session.commit()
        if remaining <= 0:
            raise OTPError("Too many attempts — request a new code")
        raise OTPError(f"Incorrect code ({remaining} attempts left)")

    latest.used_at = now
    await session.commit()
    return True