Replace pick_barrier.py with thumbnail-grid UX

Old version showed inter-fly distance plots and asked the analyst to
click a timeline. The new version reads frames directly from the .mp4
and shows a 10×6 grid of timestamped thumbnails — the analyst just
clicks the frame where the barrier opens.

Two-stage refinement:
  - Coarse grid: 60 thumbs spanning the 5-min search window at ~5 s
    spacing. Pick the rough moment.
  - Fine grid: 60 thumbs at 0.2 s spacing centred on the coarse pick.
    Pick the exact frame.

Auto-detector still feeds the starting position. Sequential video
decode (one cv2 pass through the relevant range) instead of seek-per-
frame, so each grid loads in a few seconds.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Giorgio Gilestro 2026-05-01 12:01:34 +01:00
parent b46c4ac1ba
commit e8c7f23d4d

View file

@ -1,47 +1,48 @@
"""Interactive picker for barrier-opening time per tracked video. """Interactive picker for barrier-opening time, frame-by-frame thumbnail style.
Loops through tracked DBs that don't yet have a barrier-opening For each video that doesn't yet have a barrier-opening annotation, show a
annotation. For each, plots the windowed mean inter-fly distance for 10x6 grid of timestamped thumbnails extracted directly from the .mp4.
all 6 ROIs over the first 5 minutes and lets the analyst click the The analyst clicks the thumbnail at (or just after) the moment the
moment the barrier opens (when most flies start coming close together). barrier opens; the picker then refines with a second tighter grid for
The auto-detector's best-effort guess is shown as a starting position. sub-second precision.
Two-stage flow per video:
1. Coarse grid: 60 thumbs spanning the 5-min search window (5 s spacing).
Click pick that 5 s slot.
2. Fine grid: 60 thumbs spanning ±6 s of the coarse pick (0.2 s spacing).
Click final answer with 0.2 s precision.
Output: data/metadata/barrier_opening.csv with columns Output: data/metadata/barrier_opening.csv with columns
machine_name, session_date, session_time, opening_s, trim_first_s, notes machine_name, session_date, session_time, opening_s, trim_first_s, notes
`opening_s` is the moment of barrier opening, measured from the start
of the recording (NOT the start of any trimmed copy). `trim_first_s`
is an optional annotation for videos with a misframed start that
should be ignored by analysis (defaults to 0).
Window keys: Window keys:
click place the opening cursor at that time click select thumbnail at that timestamp
ENTER save and advance n skip this video for THIS run
[, ] shift cursor by 1 s left / right u mark unusable (opening_s = NaN)
{, } shift cursor by 5 s left / right b back to coarse grid (after seeing fine grid)
n skip this video for THIS run (no row written)
u mark this video unusable (writes opening_s = NaN, notes = "unusable")
r reset cursor to the auto-detected position
q / ESC save+quit q / ESC save+quit
Usage: Usage:
python pick_barrier.py python pick_barrier.py
python pick_barrier.py --redo # re-pick videos that already have a row python pick_barrier.py --redo
python pick_barrier.py --limit 10 python pick_barrier.py --limit 10
python pick_barrier.py --db /path/to/specific_tracking.db
""" """
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import re
import sqlite3 import sqlite3
import sys import sys
from pathlib import Path from pathlib import Path
import cv2
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from config import DATA_METADATA, VIDEO_INFO_TSV from config import DATA_METADATA, INVENTORY_CSV, VIDEO_INFO_TSV
from detect_barrier_opening import ( from detect_barrier_opening import (
SEARCH_END_S, STEP_S, WINDOW_S, SEARCH_END_S, STEP_S, WINDOW_S,
per_frame_distance, sliding_mean, per_frame_distance, sliding_mean,
@ -51,43 +52,30 @@ OUT_CSV = DATA_METADATA / "barrier_opening.csv"
OUT_COLS = ["machine_name", "session_date", "session_time", OUT_COLS = ["machine_name", "session_date", "session_time",
"opening_s", "trim_first_s", "notes"] "opening_s", "trim_first_s", "notes"]
DB_NAME_RE = re.compile(
r"^(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_([0-9a-f]{32})__"
)
def parse_db_filename(db_path: Path) -> tuple[str, str, str] | None: GRID_ROWS, GRID_COLS = 6, 10
"""Pull (date, time, machine_uuid) out of a tracking DB filename.""" N_THUMBS = GRID_ROWS * GRID_COLS # 60
import re COARSE_SPAN_S = SEARCH_END_S # 0..300s, ~5s spacing
m = re.match( FINE_SPAN_S = 12.0 # ±6s around coarse pick → ~0.2s spacing
r"^(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_([0-9a-f]{32})__",
db_path.name,
) def auto_suggest(db_path: Path) -> float | None:
if not m: """Median of per-ROI biggest-drop times. None if too noisy."""
try:
conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
except sqlite3.Error:
return None return None
return m.group(1), m.group(2), m.group(3)
def load_distance_traces(db_path: Path) -> dict[int, pd.DataFrame]:
"""For each ROI 1..6, return windowed-mean DF; empty if ROI missing."""
out: dict[int, pd.DataFrame] = {}
with sqlite3.connect(f"file:{db_path}?mode=ro", uri=True) as conn:
for roi in range(1, 7):
try:
df = pd.read_sql_query(
f"SELECT t, x, y, id FROM ROI_{roi}", conn
)
except Exception:
out[roi] = pd.DataFrame()
continue
dist = per_frame_distance(df)
out[roi] = sliding_mean(dist, WINDOW_S, STEP_S, SEARCH_END_S)
return out
def auto_suggest(traces: dict[int, pd.DataFrame]) -> float | None:
"""Median of per-ROI biggest-drop times. Returns None if too noisy."""
candidates = [] candidates = []
for roi, smean in traces.items(): for roi in range(1, 7):
if len(smean) < 30: try:
df = pd.read_sql_query(f"SELECT t, x, y, id FROM ROI_{roi}", conn)
except Exception:
continue continue
# Find the time of the largest decrease in median(pre)median(post). dist = per_frame_distance(df)
smean = sliding_mean(dist, WINDOW_S, STEP_S, SEARCH_END_S)
pad = max(1, int(WINDOW_S / STEP_S)) pad = max(1, int(WINDOW_S / STEP_S))
if len(smean) < 2 * pad + 1: if len(smean) < 2 * pad + 1:
continue continue
@ -102,103 +90,200 @@ def auto_suggest(traces: dict[int, pd.DataFrame]) -> float | None:
best_t = float(smean["mid_t"].iloc[i]) best_t = float(smean["mid_t"].iloc[i])
if best_drop > 30 and best_t is not None: if best_drop > 30 and best_t is not None:
candidates.append(best_t) candidates.append(best_t)
conn.close()
if not candidates: if not candidates:
return None return None
return float(np.median(candidates)) return float(np.median(candidates))
def show_picker( def grab_thumbnails(video_path: Path, target_times_s: np.ndarray,
db_path: Path, thumb_w: int = 320) -> list[np.ndarray | None]:
machine_name: str, """Read thumbnails at the requested timestamps via a single sequential pass.
session_date: str,
session_time: str,
auto_t: float | None,
initial_t: float,
) -> dict | None:
"""Open the picker window. Returns a dict ready for OUT_CSV, or None to skip."""
traces = load_distance_traces(db_path)
if all(s.empty for s in traces.values()):
print(f" ! no usable ROI traces in {db_path.name}; skipping")
return None
fig, axes = plt.subplots(6, 1, figsize=(13, 12), sharex=True) Linear-decode is much faster than seeking per-frame on H.264. We read
fig.suptitle( frames sequentially from the earliest target onward, keeping only the
f"{machine_name} {session_date} {session_time}\n" ones at requested target frames.
f"click ↦ set opening · ENTER save · " """
f"[/] ±1s · {{/}} ±5s · n skip · u unusable · r reset · q quit", cap = cv2.VideoCapture(str(video_path))
fontsize=10, fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
src_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
src_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
if total_frames <= 0:
cap.release()
return [None] * len(target_times_s)
target_frames = np.clip(
(target_times_s * fps).round().astype(int), 0, total_frames - 1
) )
sort_idx = np.argsort(target_frames)
sorted_targets = target_frames[sort_idx]
state: dict = {"t": float(initial_t), "auto": auto_t, "result": None} out: list[np.ndarray | None] = [None] * len(target_times_s)
if sorted_targets.size == 0:
cap.release()
return out
def redraw(): cap.set(cv2.CAP_PROP_POS_FRAMES, int(sorted_targets[0]))
for ax, (roi, smean) in zip(axes, sorted(traces.items())): cur_frame = int(sorted_targets[0])
ax.cla() last_frame_data: np.ndarray | None = None
if smean.empty:
ax.text(0.5, 0.5, f"ROI {roi}: no data", scale = thumb_w / src_w if src_w > 0 else 1.0
transform=ax.transAxes, ha="center", va="center", color="grey") thumb_h = max(1, int(round(src_h * scale)))
else:
ax.plot(smean["mid_t"], smean["mean_dist"], color="steelblue", lw=1.0) for ord_i, target in zip(sort_idx, sorted_targets):
ax.set_ylabel(f"ROI {roi}") while cur_frame <= target:
if state["auto"] is not None: ret, frame = cap.read()
ax.axvline(state["auto"], color="orange", ls=":", lw=0.8, alpha=0.8) if not ret:
ax.axvline(state["t"], color="red", lw=1.5) last_frame_data = None
ax.set_xlim(0, SEARCH_END_S) break
ax.grid(True, alpha=0.3) last_frame_data = frame
axes[-1].set_xlabel("time (s)") cur_frame += 1
axes[0].set_title(f"orange dotted = auto-suggested · red = current pick: {state['t']:.1f} s", if last_frame_data is not None:
fontsize=9) small = cv2.resize(last_frame_data, (thumb_w, thumb_h),
fig.canvas.draw_idle() interpolation=cv2.INTER_AREA)
out[ord_i] = cv2.cvtColor(small, cv2.COLOR_BGR2RGB)
cap.release()
return out
def show_thumbnail_grid(
video_path: Path,
center_t: float,
span_s: float,
title: str,
) -> tuple[float | None, str]:
"""Show a 10×6 thumbnail grid; return (clicked_time, action).
`action` is one of: 'pick', 'skip', 'unusable', 'back', 'quit'.
`clicked_time` is None unless action == 'pick'.
"""
half = span_s / 2.0
times = np.linspace(max(0.0, center_t - half), center_t + half, N_THUMBS)
print(f" loading {N_THUMBS} thumbnails ({times[0]:.1f}{times[-1]:.1f}s)...", flush=True)
thumbs = grab_thumbnails(video_path, times)
fig, axes = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(20, 11))
fig.suptitle(
f"{title}\nclick a thumbnail · n=skip · u=unusable · b=back · q=quit",
fontsize=11,
)
state = {"time": None, "action": None}
for ax, t, thumb in zip(axes.flat, times, thumbs):
if thumb is not None:
ax.imshow(thumb)
else:
ax.set_facecolor("black")
ax.text(0.5, 0.5, "no frame",
transform=ax.transAxes, ha="center", va="center", color="white")
# Format time as M:SS.s for readability
m, s = divmod(t, 60)
ax.set_title(f"{int(m):d}:{s:05.2f}", fontsize=8, pad=1)
ax.set_xticks([]); ax.set_yticks([])
fig.subplots_adjust(left=0.01, right=0.99, top=0.93, bottom=0.01,
wspace=0.03, hspace=0.18)
def on_click(event): def on_click(event):
if event.inaxes in axes and event.xdata is not None: if event.inaxes is None:
state["t"] = max(0.0, min(SEARCH_END_S, float(event.xdata))) return
redraw() for i, ax in enumerate(axes.flat):
if ax is event.inaxes:
state["time"] = float(times[i])
state["action"] = "pick"
plt.close(fig)
return
def on_key(event): def on_key(event):
k = event.key k = event.key
if k == "enter": if k == "n":
state["result"] = { state["action"] = "skip"; plt.close(fig)
"machine_name": machine_name,
"session_date": session_date,
"session_time": session_time,
"opening_s": round(state["t"], 1),
"trim_first_s": 0,
"notes": "",
}
plt.close(fig)
elif k == "n":
state["result"] = "skip"
plt.close(fig)
elif k == "u": elif k == "u":
state["result"] = { state["action"] = "unusable"; plt.close(fig)
"machine_name": machine_name, elif k == "b":
"session_date": session_date, state["action"] = "back"; plt.close(fig)
"session_time": session_time,
"opening_s": np.nan,
"trim_first_s": 0,
"notes": "unusable",
}
plt.close(fig)
elif k == "r" and state["auto"] is not None:
state["t"] = state["auto"]
redraw()
elif k in ("[", "]"):
state["t"] = max(0.0, min(SEARCH_END_S, state["t"] + (-1 if k == "[" else 1)))
redraw()
elif k in ("{", "}"):
state["t"] = max(0.0, min(SEARCH_END_S, state["t"] + (-5 if k == "{" else 5)))
redraw()
elif k in ("q", "escape"): elif k in ("q", "escape"):
state["result"] = "quit" state["action"] = "quit"; plt.close(fig)
plt.close(fig)
fig.canvas.mpl_connect("button_press_event", on_click) fig.canvas.mpl_connect("button_press_event", on_click)
fig.canvas.mpl_connect("key_press_event", on_key) fig.canvas.mpl_connect("key_press_event", on_key)
redraw()
plt.show() plt.show()
return state["time"], state["action"] or "skip"
return state["result"]
def pick_for_video(
video_path: Path,
db_path: Path | None,
machine_name: str,
session_date: str,
session_time: str,
) -> dict | str | None:
"""Run the two-stage thumbnail picker. Return dict, 'skip', or 'quit'."""
auto_t = auto_suggest(db_path) if db_path else None
print(f" auto-suggest: {f'{auto_t:.1f}s' if auto_t else '(none)'}")
# Stage 1: coarse grid centred on auto-suggest (or 150 s default).
coarse_center = auto_t if auto_t is not None else COARSE_SPAN_S / 2
title_coarse = f"COARSE {machine_name} {session_date} {session_time} · spanning 5 min"
while True:
coarse_t, action = show_thumbnail_grid(
video_path, coarse_center, COARSE_SPAN_S, title_coarse
)
if action == "skip":
return "skip"
if action == "unusable":
return {
"machine_name": machine_name, "session_date": session_date,
"session_time": session_time, "opening_s": np.nan,
"trim_first_s": 0, "notes": "unusable",
}
if action == "quit":
return "quit"
if action == "back":
continue # already at the top stage; redraw
if action == "pick" and coarse_t is not None:
break
# Stage 2: fine grid around the coarse pick.
title_fine = (f"FINE {machine_name} {session_date} {session_time} "
f"· ±{FINE_SPAN_S/2:.0f} s around {coarse_t:.1f} s")
while True:
fine_t, action = show_thumbnail_grid(
video_path, coarse_t, FINE_SPAN_S, title_fine
)
if action == "back":
return pick_for_video(video_path, db_path, machine_name,
session_date, session_time)
if action == "skip":
return "skip"
if action == "unusable":
return {
"machine_name": machine_name, "session_date": session_date,
"session_time": session_time, "opening_s": np.nan,
"trim_first_s": 0, "notes": "unusable",
}
if action == "quit":
return "quit"
if action == "pick" and fine_t is not None:
return {
"machine_name": machine_name, "session_date": session_date,
"session_time": session_time, "opening_s": round(fine_t, 1),
"trim_first_s": 0, "notes": "",
}
def lookup_video_path(machine_name: str, session_date: str,
session_time: str, inv: pd.DataFrame) -> Path | None:
"""Find the mp4 path for (machine, date, time) in the inventory."""
match = inv[
(inv["machine_name"] == machine_name)
& (inv["session_date"] == session_date)
& (inv["session_time"] == session_time)
]
if match.empty:
return None
return Path(match.iloc[0]["mp4_path"])
def main() -> None: def main() -> None:
@ -208,7 +293,7 @@ def main() -> None:
parser.add_argument("--limit", type=int, default=None, parser.add_argument("--limit", type=int, default=None,
help="only process the first N videos") help="only process the first N videos")
parser.add_argument("--db", type=Path, default=None, parser.add_argument("--db", type=Path, default=None,
help="annotate this specific DB only") help="annotate this specific tracking DB only")
args = parser.parse_args() args = parser.parse_args()
OUT_CSV.parent.mkdir(parents=True, exist_ok=True) OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
@ -218,10 +303,15 @@ def main() -> None:
out = pd.DataFrame(columns=OUT_COLS) out = pd.DataFrame(columns=OUT_COLS)
done = set(zip(out["machine_name"], out["session_date"], out["session_time"])) done = set(zip(out["machine_name"], out["session_date"], out["session_time"]))
# Build the queue: every tracked DB referenced by the merged TSV that if not INVENTORY_CSV.exists():
# hasn't been picked yet. sys.exit(f"Inventory not found at {INVENTORY_CSV}. Run build_video_inventory.py first.")
inv = pd.read_csv(INVENTORY_CSV)
# Build the queue: every (machine, date, time) referenced by the merged
# TSV that has a tracking DB on disk and isn't yet annotated.
tsv = pd.read_csv(VIDEO_INFO_TSV, sep="\t") tsv = pd.read_csv(VIDEO_INFO_TSV, sep="\t")
queue = [] queue: list[tuple[Path, Path, str, str, str]] = []
seen: set[tuple[str, str, str]] = set()
for col in ("training_db_path", "testing_db_path"): for col in ("training_db_path", "testing_db_path"):
for _, row in tsv.iterrows(): for _, row in tsv.iterrows():
db = row[col] db = row[col]
@ -230,28 +320,25 @@ def main() -> None:
db_path = Path(db) db_path = Path(db)
if not db_path.exists(): if not db_path.exists():
continue continue
parsed = parse_db_filename(db_path) m = DB_NAME_RE.match(db_path.name)
if parsed is None: if not m:
continue continue
session_date, session_time, _ = parsed session_date, session_time = m.group(1), m.group(2)
key = (row["machine_name"], session_date, session_time) key = (row["machine_name"], session_date, session_time)
if key in seen:
continue
seen.add(key)
if key in done and not args.redo: if key in done and not args.redo:
continue continue
queue.append((db_path, row["machine_name"], session_date, session_time)) video = lookup_video_path(*key, inv)
if video is None or not video.exists():
# Dedup (a fly may reference the same DB for both training & testing). print(f" ! no video for {key}; skipping")
seen = set() continue
deduped = [] queue.append((db_path, video, *key))
for item in queue:
k = (item[1], item[2], item[3])
if k not in seen:
seen.add(k)
deduped.append(item)
queue = deduped
if args.db: if args.db:
target = Path(args.db).resolve() target = Path(args.db).resolve()
queue = [q for q in queue if Path(q[0]).resolve() == target] queue = [q for q in queue if q[0].resolve() == target]
if not queue: if not queue:
sys.exit(f"DB not found in queue: {args.db}") sys.exit(f"DB not found in queue: {args.db}")
@ -259,34 +346,26 @@ def main() -> None:
queue = queue[: args.limit] queue = queue[: args.limit]
if not queue: if not queue:
print("Nothing to pick. All eligible DBs already have a barrier_opening row.") print("Nothing to pick. All eligible videos already have a barrier_opening row.")
return return
print(f"Picking barrier-opening for {len(queue)} videos.") print(f"Picking barrier-opening for {len(queue)} videos.")
print("Window keys: click=set ENTER=save [/]=±1s {/}=±5s n=skip u=unusable r=reset q=quit") print("Window keys: click=pick · n=skip · u=unusable · b=back · q=quit")
saved = skipped = unusable = 0 saved = skipped = unusable = 0
for i, (db, machine_name, session_date, session_time) in enumerate(queue, 1): for i, (db, video, machine_name, session_date, session_time) in enumerate(queue, 1):
prefix = f"[{i}/{len(queue)}] {machine_name} {session_date} {session_time}" prefix = f"[{i}/{len(queue)}] {machine_name} {session_date} {session_time}"
print(f"\n{prefix}") print(f"\n{prefix}")
traces = load_distance_traces(db) result = pick_for_video(video, db, machine_name, session_date, session_time)
auto_t = auto_suggest(traces)
initial = auto_t if auto_t is not None else 60.0
print(f" auto-suggest: "
f"{f'{auto_t:.1f}s' if auto_t is not None else '(none)'}")
result = show_picker(db, machine_name, session_date, session_time,
auto_t, initial)
if result is None or result == "skip": if result is None or result == "skip":
skipped += 1 skipped += 1
continue continue
if result == "quit": if result == "quit":
print(" quit requested — saving what we have and exiting") print(" quit requested — saving and exiting")
break break
# Append + dedup on key + persist after each save (crash-safe).
new_row = pd.DataFrame([result]) new_row = pd.DataFrame([result])
out = pd.concat([ out = pd.concat([
out[~((out.machine_name == result["machine_name"]) & out[~((out.machine_name == result["machine_name"]) &
@ -297,6 +376,7 @@ def main() -> None:
out[OUT_COLS].to_csv(OUT_CSV, index=False) out[OUT_COLS].to_csv(OUT_CSV, index=False)
if pd.isna(result["opening_s"]): if pd.isna(result["opening_s"]):
unusable += 1 unusable += 1
print(" saved as unusable")
else: else:
saved += 1 saved += 1
print(f" saved opening_s = {result['opening_s']} s") print(f" saved opening_s = {result['opening_s']} s")