cupido/scripts/barrier_picker_app/app.py
Giorgio Gilestro 2623df4172 Picker: identify the analyst (initials) per pick
Each annotation row now carries an `analyst` column. On first visit the
web picker shows a small login modal asking for initials, persists them
in localStorage, and shows the badge in the top-right. Click the badge
to change identities. Submissions without initials are rejected by the
backend (HTTP 400). Skip remains analyst-free.

Backfill: every existing barrier_opening.csv row is marked with analyst `GG`,
since all current picks were made by Giorgio.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-01 14:23:57 +01:00

349 lines
13 KiB
Python

"""FastAPI server for the web-based barrier-opening picker.
Browse to http://<host>:8000/ and you'll see a video player loaded with
the next un-annotated video from the queue. Use the arrow keys to
scrub (←/→ ±5 s, Shift+←/→ ±1 s, Ctrl+←/→ ±0.1 s), space to pause/play,
or click the seekbar. When the barrier opens, click one of:
[All barriers open] every ROI is usable post-opening
[Upper barriers open] only ROIs 1,3,5 are usable
[Lower barriers open] only ROIs 2,4,6 are usable
The current playhead time is recorded as the barrier-opening moment;
ROI inclusion is set accordingly. There is also a Skip and a Mark
unusable button.
The queue is built from the merged TSV plus the inventory: every
unique (machine, date, time) that has both a tracking DB and an mp4
on disk and is not yet in barrier_opening.csv. Submissions persist
to barrier_opening.csv after every click — refresh-safe.
Configuration (environment variables):
CUPIDO_DATA_VOLUME /mnt/data/projects/cupido (data volume)
CUPIDO_INVENTORY_CSV /cupido/data/metadata/video_inventory.csv
CUPIDO_OUTPUT_CSV /cupido/data/metadata/barrier_opening.csv
"""
from __future__ import annotations

import math
import os
import re
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import (
    FileResponse,
    HTMLResponse,
    JSONResponse,
    Response,
    StreamingResponse,
)
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# ─── Config ──────────────────────────────────────────────────────────────
def _env_path(var: str, default: str) -> Path:
    """Return environment variable *var* as a Path, or *default* if unset."""
    return Path(os.environ.get(var, default))


DATA_VOLUME = _env_path("CUPIDO_DATA_VOLUME", "/mnt/data/projects/cupido")
INVENTORY_CSV = _env_path(
    "CUPIDO_INVENTORY_CSV", "/cupido/data/metadata/video_inventory.csv"
)
OUTPUT_CSV = _env_path(
    "CUPIDO_OUTPUT_CSV", "/cupido/data/metadata/barrier_opening.csv"
)
# The merged per-ROI experiment table sits at the root of the data volume.
TSV_PATH = DATA_VOLUME / "all_video_info_merged.tsv"
# Reason: every tracking DB filename and every inventory mp4 filename begins
# with "<date>_<time>_<32-hex machine uuid>__". The three groups capture
# date, time and uuid, in that order.
DB_NAME_RE = re.compile(
    r"^(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_([0-9a-f]{32})__"
)
# Column order of barrier_opening.csv (one row per annotated video).
OUT_COLS = [
    "machine_name",
    "session_date",
    "session_time",
    "opening_s",
    "trim_first_s",
    "bad_rois",
    "analyst",
    "notes",
]
# ROI numbering in the HD mating arena (verified via tracking_geometry):
#   upper row → ROIs 1, 3, 5 (y ≈ 0.125)
#   lower row → ROIs 2, 4, 6 (y ≈ 0.795)
# Kept as comma-joined strings because that is the exact text written to
# the `bad_rois` column.
ROIS_UPPER = ",".join(str(roi) for roi in (1, 3, 5))
ROIS_LOWER = ",".join(str(roi) for roi in (2, 4, 6))
@dataclass(frozen=True)
class QueueItem:
    """One pickable video in the annotation queue, as built by _build_queue."""

    idx: int  # position of this item in the queue list
    machine_name: str
    session_date: str  # date part of the DB filename (matched by DB_NAME_RE)
    session_time: str  # time part of the DB filename (matched by DB_NAME_RE)
    mp4_path: str  # video file on disk, taken from the inventory CSV
    duration_s: float | None  # inventory duration; None when the cell is empty
    done: bool  # True if OUTPUT_CSV already has a row for this video's key
    metadata: dict  # experimental fields aggregated from the merged TSV
# ─── Queue building ─────────────────────────────────────────────────────
# Per-video experimental fields copied from the TSV into each queue item's
# metadata. They are expected to be uniform across the ROIs of one video.
_META_FIELDS = (
    "species", "training_length_hr", "consolidation_length_hr",
    "memory", "age", "training_date_time", "testing_date_time",
)


def _json_scalar(value):
    """Convert one TSV cell into a JSON-serialisable Python value.

    Missing values (NaN/None) become None. numpy scalars, which pandas can
    return from a row lookup, are unwrapped with ``.item()`` because
    ``json.dumps`` (used by JSONResponse) rejects e.g. ``numpy.int64``.
    Integral floats become ints, since pandas reads an integer column that
    has gaps as float. Strings are returned unchanged.
    """
    if pd.isna(value):
        return None
    if hasattr(value, "item") and not isinstance(value, str):
        value = value.item()
    if isinstance(value, float) and value.is_integer():
        return int(value)
    return value


def _references(rows: pd.DataFrame, column: str, db_filename: str) -> bool:
    """Return True if any path in *column* ends with *db_filename*.

    Returns False when *column* is absent, rather than raising KeyError.
    """
    if column not in rows.columns:
        return False
    return bool(rows[column].astype(str).str.endswith(db_filename).any())


def _aggregate_metadata(rows: pd.DataFrame, db_filename: str) -> dict:
    """Pull the experimental metadata for one video from its TSV rows.

    Most fields are uniform across the 6 ROIs of a video, so the first
    (de-duplicated) row's value is representative. `male` is a per-fly
    label, so its values are summarised as counts. `session_role` says
    whether this video was the training session, the testing session, both,
    or neither ("?") for the flies in it.

    Args:
        rows: TSV rows whose training/testing DB path refers to this video.
        db_filename: basename of the video's tracking DB.

    Returns:
        A JSON-serialisable dict: the _META_FIELDS (None when missing),
        ``n_trained``/``n_naive`` when a ``male`` column exists, and
        ``session_role``. An empty dict when *rows* is empty.
    """
    if rows.empty:
        return {}
    # Reason: the role is an "any row references this DB" test, so evaluate
    # it on all related rows *before* de-duplication. Dropping a row that
    # merely shares (date, machine, ROI) must not hide the video's other role.
    is_training = _references(rows, "training_db_path", db_filename)
    is_testing = _references(rows, "testing_db_path", db_filename)
    # Reason: the merged xlsx/TSV currently has duplicate rows per
    # (date, machine, ROI). De-dup on those keys so the male counts and
    # any per-ROI fields aren't doubled.
    if {"date", "machine_name", "roi"}.issubset(rows.columns):
        rows = rows.drop_duplicates(subset=["date", "machine_name", "roi"])
    first = rows.iloc[0]
    meta = {field: _json_scalar(first.get(field)) for field in _META_FIELDS}
    # Per-ROI tally.
    if "male" in rows.columns:
        males = rows["male"].dropna()
        meta["n_trained"] = int((males == "trained").sum())
        meta["n_naive"] = int((males == "naive").sum())
    if is_training and is_testing:
        meta["session_role"] = "training+testing"
    elif is_training:
        meta["session_role"] = "training"
    elif is_testing:
        meta["session_role"] = "testing"
    else:
        meta["session_role"] = "?"
    return meta
def _load_inventory() -> dict[tuple[str, str, str], dict]:
    """Index INVENTORY_CSV rows by (machine_name, session_date, session_time).

    Each value holds ``mp4_path`` and ``duration_s`` (None when the cell is
    empty). If a key repeats, the later row wins.
    """
    inv = pd.read_csv(INVENTORY_CSV)
    return {
        (r.machine_name, r.session_date, r.session_time): {
            "mp4_path": r.mp4_path,
            "duration_s": float(r.duration_s) if pd.notna(r.duration_s) else None,
        }
        for r in inv.itertuples(index=False)
    }


def _load_done_keys() -> set[tuple[str, str, str]]:
    """Return the (machine, date, time) keys already present in OUTPUT_CSV."""
    if not OUTPUT_CSV.exists():
        return set()
    out = pd.read_csv(OUTPUT_CSV)
    return set(zip(out["machine_name"], out["session_date"], out["session_time"]))


def _rows_by_db_filename(tsv: pd.DataFrame) -> dict[str, list[int]]:
    """Map each DB basename to the positions of the TSV rows that reference it.

    A row is listed under the basename of its training DB and under the
    basename of its testing DB, once if both are the same file. Positions are
    ascending, so ``tsv.iloc[positions]`` keeps the TSV's row order.
    """
    index: dict[str, list[int]] = {}
    pairs = zip(tsv["training_db_path"], tsv["testing_db_path"])
    for pos, paths in enumerate(pairs):
        for name in {Path(p).name for p in paths if isinstance(p, str) and p}:
            index.setdefault(name, []).append(pos)
    return index


def _build_queue() -> list[QueueItem]:
    """Build the ordered queue of pickable videos.

    Every training DB reference is visited in TSV row order, then every
    testing DB reference. Each unique (machine, date, time) key is queued at
    most once, and only when all of the following hold: its tracking DB
    exists, the DB filename matches DB_NAME_RE, and the inventory lists an
    existing mp4 for the key. Items already in OUTPUT_CSV are kept but
    flagged ``done``.

    Raises:
        RuntimeError: if the merged TSV or the inventory CSV is missing.
    """
    if not TSV_PATH.exists():
        raise RuntimeError(f"merged TSV not found at {TSV_PATH}")
    if not INVENTORY_CSV.exists():
        raise RuntimeError(f"inventory not found at {INVENTORY_CSV}")
    tsv = pd.read_csv(TSV_PATH, sep="\t")
    inv_by_key = _load_inventory()
    done_keys = _load_done_keys()
    # Reason: index the TSV once instead of rescanning every row for every
    # video. This function runs on every API call, including each video
    # Range request.
    rows_by_db = _rows_by_db_filename(tsv)

    seen: set[tuple[str, str, str]] = set()
    items: list[QueueItem] = []
    for col in ("training_db_path", "testing_db_path"):
        for row in tsv.itertuples(index=False):
            db = getattr(row, col)
            if not isinstance(db, str) or not db:
                continue
            db_path = Path(db)
            m = DB_NAME_RE.match(db_path.name)
            if not m:
                continue
            session_date, session_time = m.group(1), m.group(2)
            key = (row.machine_name, session_date, session_time)
            # Check the in-memory set before stat()-ing the DB, so repeated
            # references to one video cost no filesystem calls.
            if key in seen or not db_path.exists():
                continue
            seen.add(key)
            inv_row = inv_by_key.get(key)
            if inv_row is None:
                continue
            mp4_path = inv_row["mp4_path"]
            # A blank inventory cell is read as NaN, and Path(NaN) would raise.
            if not isinstance(mp4_path, str) or not Path(mp4_path).exists():
                continue
            # All TSV rows that reference this video. There are typically
            # 6 ROI rows per session, sometimes also rows using it as both
            # training AND testing.
            related = tsv.iloc[rows_by_db.get(db_path.name, [])]
            items.append(QueueItem(
                idx=len(items),
                machine_name=row.machine_name,
                session_date=session_date,
                session_time=session_time,
                mp4_path=mp4_path,
                duration_s=inv_row["duration_s"],
                done=key in done_keys,
                metadata=_aggregate_metadata(related, db_path.name),
            ))
    return items
# ─── App ───────────────────────────────────────────────────────────────
app = FastAPI(title="Cupido barrier-opening picker")
# Front-end assets live in a `static/` directory beside this module and are
# served under /static.
STATIC_DIR = Path(__file__).with_name("static")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
@app.get("/")
async def index() -> FileResponse:
    """Serve the picker's single HTML page from STATIC_DIR."""
    page = STATIC_DIR / "index.html"
    return FileResponse(page)
@app.get("/api/queue")
async def get_queue() -> JSONResponse:
    """Return every queue item as a JSON list for the front end.

    ``mp4_path`` is not included; video bytes are fetched through
    /api/video/{idx}.
    """
    exposed = ("idx", "machine_name", "session_date", "session_time",
               "duration_s", "done", "metadata")
    return JSONResponse([
        {name: getattr(item, name) for name in exposed}
        for item in _build_queue()
    ])
def _stream_video(file_path: Path, request: Request) -> Response:
    """Serve *file_path* as ``video/mp4``, honouring a single HTTP byte range.

    Without a ``Range`` header the whole file is returned via FileResponse.
    The range forms ``bytes=START-END``, ``bytes=START-`` and the suffix form
    ``bytes=-N`` (the last N bytes) all produce a 206 Partial Content
    response. Its body is streamed from disk in 64 KiB chunks, so an
    open-ended request such as ``bytes=0-`` no longer loads the whole video
    into memory. Only the first range of a multi-range header is honoured.

    Raises:
        HTTPException(416): the header is malformed or the range cannot be
            satisfied. The error carries ``Content-Range: bytes */SIZE``.
    """
    file_size = file_path.stat().st_size
    range_header = request.headers.get("range")
    if range_header is None:
        return FileResponse(file_path, media_type="video/mp4",
                            headers={"Accept-Ranges": "bytes"})

    unsatisfiable = {"Content-Range": f"bytes */{file_size}"}
    # Parse "bytes=START-END" (END optional) or the suffix form "bytes=-N".
    m = re.match(r"bytes=(\d*)-(\d*)", range_header)
    if not m or not (m.group(1) or m.group(2)):
        raise HTTPException(status_code=416, detail="bad Range header",
                            headers=unsatisfiable)
    if m.group(1):
        start = int(m.group(1))
        end = int(m.group(2)) if m.group(2) else file_size - 1
    else:
        # Suffix range: the final N bytes. N == 0 ends up unsatisfiable below.
        start = max(file_size - int(m.group(2)), 0)
        end = file_size - 1
    end = min(end, file_size - 1)
    if start > end:
        raise HTTPException(status_code=416, detail="range not satisfiable",
                            headers=unsatisfiable)
    chunk_size = end - start + 1

    def iterfile():
        with open(file_path, "rb") as f:
            f.seek(start)
            remaining = chunk_size
            while remaining > 0:
                buf = f.read(min(64 * 1024, remaining))
                if not buf:
                    break
                yield buf
                remaining -= len(buf)

    # StreamingResponse pulls the generator lazily, so memory stays bounded
    # by the 64 KiB read size rather than by the requested range.
    return StreamingResponse(
        iterfile(),
        status_code=206,
        media_type="video/mp4",
        headers={
            "Content-Range": f"bytes {start}-{end}/{file_size}",
            "Accept-Ranges": "bytes",
            "Content-Length": str(chunk_size),
        },
    )
@app.get("/api/video/{idx}")
async def get_video(idx: int, request: Request) -> Response:
    """Stream the mp4 for queue position *idx* via _stream_video.

    Raises HTTPException(404) when *idx* is not a valid queue position.
    """
    queue = _build_queue()
    in_range = 0 <= idx < len(queue)
    if not in_range:
        raise HTTPException(status_code=404, detail="idx out of range")
    video = Path(queue[idx].mp4_path)
    return _stream_video(video, request)
class Submission(BaseModel):
    """Request body of POST /api/submit."""

    idx: int  # position of the video in the current queue
    time_s: float | None  # playhead time in seconds; None when marking unusable
    mode: str  # "all" | "upper" | "lower" | "unusable" | "skip"
    analyst: str = ""  # initials of the human picker (required unless "skip")
    notes: str = ""  # free text stored in the CSV "notes" column
def _write_row(row: dict) -> None:
    """Upsert *row* into OUTPUT_CSV, keyed on (machine_name, session_date, session_time).

    Missing OUT_COLS columns are added as empty strings, any existing row for
    the same key is dropped, and the table is written in OUT_COLS order. The
    CSV is first written to a temporary sibling and then swapped in with
    ``os.replace``, so a crash mid-write cannot truncate the annotations
    collected so far.
    """
    OUTPUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    if OUTPUT_CSV.exists():
        out = pd.read_csv(OUTPUT_CSV)
    else:
        out = pd.DataFrame(columns=OUT_COLS)
    for col in OUT_COLS:
        if col not in out.columns:
            out[col] = ""
    # Replace any existing row for this key.
    keep = ~((out["machine_name"] == row["machine_name"])
             & (out["session_date"] == row["session_date"])
             & (out["session_time"] == row["session_time"]))
    out = pd.concat([out[keep], pd.DataFrame([row])], ignore_index=True)
    tmp_path = OUTPUT_CSV.with_name(OUTPUT_CSV.name + ".tmp")
    out[OUT_COLS].to_csv(tmp_path, index=False)
    os.replace(tmp_path, OUTPUT_CSV)


@app.post("/api/submit")
async def submit(payload: Submission) -> dict:
    """Record a pick for queue item ``payload.idx`` in OUTPUT_CSV.

    What each ``mode`` writes:

    * ``"skip"``: nothing; returns ``{"status": "skipped"}``.
    * ``"unusable"``: an empty opening time, with notes defaulting to
      ``"unusable"``.
    * ``"all"``, ``"upper"``, ``"lower"``: ``time_s`` rounded to 0.1 s, with
      ``bad_rois`` listing the ROI row that did NOT open.

    Any earlier row for the same video is replaced.

    Raises:
        HTTPException(404): idx is outside the current queue.
        HTTPException(400): analyst initials missing (non-skip modes), unknown
            mode, or time_s missing / non-finite / negative for an opening mode.
    """
    queue = _build_queue()
    if not 0 <= payload.idx < len(queue):
        raise HTTPException(status_code=404, detail="idx out of range")
    item = queue[payload.idx]
    if payload.mode == "skip":
        return {"status": "skipped"}
    analyst = payload.analyst.strip().upper()
    if not analyst:
        raise HTTPException(status_code=400, detail="analyst initials required")
    if payload.mode == "unusable":
        opening_s = float("nan")
        bad_rois = ""
        notes = payload.notes or "unusable"
    else:
        if payload.time_s is None:
            raise HTTPException(status_code=400, detail="time_s required")
        bad_rois = {
            "all": "",
            "upper": ROIS_LOWER,  # upper-only opens → lower row is bad
            "lower": ROIS_UPPER,  # lower-only opens → upper row is bad
        }.get(payload.mode)
        if bad_rois is None:
            raise HTTPException(status_code=400, detail=f"unknown mode: {payload.mode}")
        # The validated float field can still carry NaN/inf (or a negative
        # value), none of which is a meaningful playhead time.
        if not math.isfinite(payload.time_s) or payload.time_s < 0:
            raise HTTPException(status_code=400,
                                detail="time_s must be a finite, non-negative number")
        opening_s = round(payload.time_s, 1)
        notes = payload.notes
    row = {
        "machine_name": item.machine_name,
        "session_date": item.session_date,
        "session_time": item.session_time,
        "opening_s": opening_s,
        "trim_first_s": 0,
        "bad_rois": bad_rois,
        "analyst": analyst,
        "notes": notes,
    }
    _write_row(row)
    # Reason: Starlette's JSONResponse renders with allow_nan=False. Depending
    # on the FastAPI/pydantic versions, an unusable pick's NaN can reach it and
    # turn a successful save into an HTTP 500. Report the empty opening as null.
    response_row = {
        key: None if isinstance(value, float) and math.isnan(value) else value
        for key, value in row.items()
    }
    return {"status": "saved", "row": response_row}
if __name__ == "__main__":
    import uvicorn

    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "8000"))
    # Reason: hand uvicorn the already-built `app` object instead of the
    # "app:app" import string. Without reload/workers the instance is
    # accepted directly, so this module is not imported a second time
    # (once as __main__, once as `app`).
    uvicorn.run(app, host=host, port=port, reload=False)