From e8c7f23d4d8f7d830bbd998ea8a248da982a179c Mon Sep 17 00:00:00 2001 From: Giorgio Gilestro Date: Fri, 1 May 2026 12:01:34 +0100 Subject: [PATCH] Replace pick_barrier.py with thumbnail-grid UX MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Old version showed inter-fly distance plots and asked the analyst to click a timeline. The new version reads frames directly from the .mp4 and shows a 10×6 grid of timestamped thumbnails — the analyst just clicks the frame where the barrier opens. Two-stage refinement: - Coarse grid: 60 thumbs spanning the 5-min search window at ~5 s spacing. Pick the rough moment. - Fine grid: 60 thumbs at 0.2 s spacing centred on the coarse pick. Pick the exact frame. Auto-detector still feeds the starting position. Sequential video decode (one cv2 pass through the relevant range) instead of seek-per- frame, so each grid loads in a few seconds. Co-Authored-By: Claude Opus 4.7 --- scripts/pick_barrier.py | 402 ++++++++++++++++++++++++---------------- 1 file changed, 241 insertions(+), 161 deletions(-) diff --git a/scripts/pick_barrier.py b/scripts/pick_barrier.py index cb24f47..06b4431 100644 --- a/scripts/pick_barrier.py +++ b/scripts/pick_barrier.py @@ -1,47 +1,48 @@ -"""Interactive picker for barrier-opening time per tracked video. +"""Interactive picker for barrier-opening time, frame-by-frame thumbnail style. -Loops through tracked DBs that don't yet have a barrier-opening -annotation. For each, plots the windowed mean inter-fly distance for -all 6 ROIs over the first 5 minutes and lets the analyst click the -moment the barrier opens (when most flies start coming close together). -The auto-detector's best-effort guess is shown as a starting position. +For each video that doesn't yet have a barrier-opening annotation, show a +10x6 grid of timestamped thumbnails extracted directly from the .mp4. +The analyst clicks the thumbnail at (or just after) the moment the +barrier opens; the picker then refines with a second tighter grid for +sub-second precision. + +Two-stage flow per video: + 1. Coarse grid: 60 thumbs spanning the 5-min search window (5 s spacing). + Click → pick that 5 s slot. + 2. Fine grid: 60 thumbs spanning ±6 s of the coarse pick (0.2 s spacing). + Click → final answer with 0.2 s precision. Output: data/metadata/barrier_opening.csv with columns machine_name, session_date, session_time, opening_s, trim_first_s, notes -`opening_s` is the moment of barrier opening, measured from the start -of the recording (NOT the start of any trimmed copy). `trim_first_s` -is an optional annotation for videos with a misframed start that -should be ignored by analysis (defaults to 0). - Window keys: - click place the opening cursor at that time - ENTER save and advance - [, ] shift cursor by 1 s left / right - {, } shift cursor by 5 s left / right - n skip this video for THIS run (no row written) - u mark this video unusable (writes opening_s = NaN, notes = "unusable") - r reset cursor to the auto-detected position + click select thumbnail at that timestamp + n skip this video for THIS run + u mark unusable (opening_s = NaN) + b back to coarse grid (after seeing fine grid) q / ESC save+quit Usage: python pick_barrier.py - python pick_barrier.py --redo # re-pick videos that already have a row + python pick_barrier.py --redo python pick_barrier.py --limit 10 + python pick_barrier.py --db /path/to/specific_tracking.db """ from __future__ import annotations import argparse +import re import sqlite3 import sys from pathlib import Path +import cv2 import matplotlib.pyplot as plt import numpy as np import pandas as pd -from config import DATA_METADATA, VIDEO_INFO_TSV +from config import DATA_METADATA, INVENTORY_CSV, VIDEO_INFO_TSV from detect_barrier_opening import ( SEARCH_END_S, STEP_S, WINDOW_S, per_frame_distance, sliding_mean, @@ -51,43 +52,30 @@ OUT_CSV = DATA_METADATA / "barrier_opening.csv" OUT_COLS = ["machine_name", "session_date", "session_time", "opening_s", "trim_first_s", "notes"] +DB_NAME_RE = re.compile( + r"^(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_([0-9a-f]{32})__" +) -def parse_db_filename(db_path: Path) -> tuple[str, str, str] | None: - """Pull (date, time, machine_uuid) out of a tracking DB filename.""" - import re - m = re.match( - r"^(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_([0-9a-f]{32})__", - db_path.name, - ) - if not m: +GRID_ROWS, GRID_COLS = 6, 10 +N_THUMBS = GRID_ROWS * GRID_COLS # 60 +COARSE_SPAN_S = SEARCH_END_S # 0..300s, ~5s spacing +FINE_SPAN_S = 12.0 # ±6s around coarse pick → ~0.2s spacing + + +def auto_suggest(db_path: Path) -> float | None: + """Median of per-ROI biggest-drop times. None if too noisy.""" + try: + conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True) + except sqlite3.Error: return None - return m.group(1), m.group(2), m.group(3) - - -def load_distance_traces(db_path: Path) -> dict[int, pd.DataFrame]: - """For each ROI 1..6, return windowed-mean DF; empty if ROI missing.""" - out: dict[int, pd.DataFrame] = {} - with sqlite3.connect(f"file:{db_path}?mode=ro", uri=True) as conn: - for roi in range(1, 7): - try: - df = pd.read_sql_query( - f"SELECT t, x, y, id FROM ROI_{roi}", conn - ) - except Exception: - out[roi] = pd.DataFrame() - continue - dist = per_frame_distance(df) - out[roi] = sliding_mean(dist, WINDOW_S, STEP_S, SEARCH_END_S) - return out - - -def auto_suggest(traces: dict[int, pd.DataFrame]) -> float | None: - """Median of per-ROI biggest-drop times. Returns None if too noisy.""" candidates = [] - for roi, smean in traces.items(): - if len(smean) < 30: + for roi in range(1, 7): + try: + df = pd.read_sql_query(f"SELECT t, x, y, id FROM ROI_{roi}", conn) + except Exception: continue - # Find the time of the largest decrease in median(pre)–median(post). + dist = per_frame_distance(df) + smean = sliding_mean(dist, WINDOW_S, STEP_S, SEARCH_END_S) pad = max(1, int(WINDOW_S / STEP_S)) if len(smean) < 2 * pad + 1: continue @@ -102,103 +90,200 @@ def auto_suggest(traces: dict[int, pd.DataFrame]) -> float | None: best_t = float(smean["mid_t"].iloc[i]) if best_drop > 30 and best_t is not None: candidates.append(best_t) + conn.close() if not candidates: return None return float(np.median(candidates)) -def show_picker( - db_path: Path, - machine_name: str, - session_date: str, - session_time: str, - auto_t: float | None, - initial_t: float, -) -> dict | None: - """Open the picker window. Returns a dict ready for OUT_CSV, or None to skip.""" - traces = load_distance_traces(db_path) - if all(s.empty for s in traces.values()): - print(f" ! no usable ROI traces in {db_path.name}; skipping") - return None +def grab_thumbnails(video_path: Path, target_times_s: np.ndarray, + thumb_w: int = 320) -> list[np.ndarray | None]: + """Read thumbnails at the requested timestamps via a single sequential pass. - fig, axes = plt.subplots(6, 1, figsize=(13, 12), sharex=True) - fig.suptitle( - f"{machine_name} {session_date} {session_time}\n" - f"click ↦ set opening · ENTER save · " - f"[/] ±1s · {{/}} ±5s · n skip · u unusable · r reset · q quit", - fontsize=10, + Linear-decode is much faster than seeking per-frame on H.264. We read + frames sequentially from the earliest target onward, keeping only the + ones at requested target frames. + """ + cap = cv2.VideoCapture(str(video_path)) + fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + src_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + src_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + if total_frames <= 0: + cap.release() + return [None] * len(target_times_s) + + target_frames = np.clip( + (target_times_s * fps).round().astype(int), 0, total_frames - 1 ) + sort_idx = np.argsort(target_frames) + sorted_targets = target_frames[sort_idx] - state: dict = {"t": float(initial_t), "auto": auto_t, "result": None} + out: list[np.ndarray | None] = [None] * len(target_times_s) + if sorted_targets.size == 0: + cap.release() + return out - def redraw(): - for ax, (roi, smean) in zip(axes, sorted(traces.items())): - ax.cla() - if smean.empty: - ax.text(0.5, 0.5, f"ROI {roi}: no data", - transform=ax.transAxes, ha="center", va="center", color="grey") - else: - ax.plot(smean["mid_t"], smean["mean_dist"], color="steelblue", lw=1.0) - ax.set_ylabel(f"ROI {roi}") - if state["auto"] is not None: - ax.axvline(state["auto"], color="orange", ls=":", lw=0.8, alpha=0.8) - ax.axvline(state["t"], color="red", lw=1.5) - ax.set_xlim(0, SEARCH_END_S) - ax.grid(True, alpha=0.3) - axes[-1].set_xlabel("time (s)") - axes[0].set_title(f"orange dotted = auto-suggested · red = current pick: {state['t']:.1f} s", - fontsize=9) - fig.canvas.draw_idle() + cap.set(cv2.CAP_PROP_POS_FRAMES, int(sorted_targets[0])) + cur_frame = int(sorted_targets[0]) + last_frame_data: np.ndarray | None = None + + scale = thumb_w / src_w if src_w > 0 else 1.0 + thumb_h = max(1, int(round(src_h * scale))) + + for ord_i, target in zip(sort_idx, sorted_targets): + while cur_frame <= target: + ret, frame = cap.read() + if not ret: + last_frame_data = None + break + last_frame_data = frame + cur_frame += 1 + if last_frame_data is not None: + small = cv2.resize(last_frame_data, (thumb_w, thumb_h), + interpolation=cv2.INTER_AREA) + out[ord_i] = cv2.cvtColor(small, cv2.COLOR_BGR2RGB) + + cap.release() + return out + + +def show_thumbnail_grid( + video_path: Path, + center_t: float, + span_s: float, + title: str, +) -> tuple[float | None, str]: + """Show a 10×6 thumbnail grid; return (clicked_time, action). + + `action` is one of: 'pick', 'skip', 'unusable', 'back', 'quit'. + `clicked_time` is None unless action == 'pick'. + """ + half = span_s / 2.0 + times = np.linspace(max(0.0, center_t - half), center_t + half, N_THUMBS) + print(f" loading {N_THUMBS} thumbnails ({times[0]:.1f}–{times[-1]:.1f}s)...", flush=True) + thumbs = grab_thumbnails(video_path, times) + + fig, axes = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(20, 11)) + fig.suptitle( + f"{title}\nclick a thumbnail · n=skip · u=unusable · b=back · q=quit", + fontsize=11, + ) + state = {"time": None, "action": None} + + for ax, t, thumb in zip(axes.flat, times, thumbs): + if thumb is not None: + ax.imshow(thumb) + else: + ax.set_facecolor("black") + ax.text(0.5, 0.5, "no frame", + transform=ax.transAxes, ha="center", va="center", color="white") + # Format time as M:SS.s for readability + m, s = divmod(t, 60) + ax.set_title(f"{int(m):d}:{s:05.2f}", fontsize=8, pad=1) + ax.set_xticks([]); ax.set_yticks([]) + fig.subplots_adjust(left=0.01, right=0.99, top=0.93, bottom=0.01, + wspace=0.03, hspace=0.18) def on_click(event): - if event.inaxes in axes and event.xdata is not None: - state["t"] = max(0.0, min(SEARCH_END_S, float(event.xdata))) - redraw() + if event.inaxes is None: + return + for i, ax in enumerate(axes.flat): + if ax is event.inaxes: + state["time"] = float(times[i]) + state["action"] = "pick" + plt.close(fig) + return def on_key(event): k = event.key - if k == "enter": - state["result"] = { - "machine_name": machine_name, - "session_date": session_date, - "session_time": session_time, - "opening_s": round(state["t"], 1), - "trim_first_s": 0, - "notes": "", - } - plt.close(fig) - elif k == "n": - state["result"] = "skip" - plt.close(fig) + if k == "n": + state["action"] = "skip"; plt.close(fig) elif k == "u": - state["result"] = { - "machine_name": machine_name, - "session_date": session_date, - "session_time": session_time, - "opening_s": np.nan, - "trim_first_s": 0, - "notes": "unusable", - } - plt.close(fig) - elif k == "r" and state["auto"] is not None: - state["t"] = state["auto"] - redraw() - elif k in ("[", "]"): - state["t"] = max(0.0, min(SEARCH_END_S, state["t"] + (-1 if k == "[" else 1))) - redraw() - elif k in ("{", "}"): - state["t"] = max(0.0, min(SEARCH_END_S, state["t"] + (-5 if k == "{" else 5))) - redraw() + state["action"] = "unusable"; plt.close(fig) + elif k == "b": + state["action"] = "back"; plt.close(fig) elif k in ("q", "escape"): - state["result"] = "quit" - plt.close(fig) + state["action"] = "quit"; plt.close(fig) fig.canvas.mpl_connect("button_press_event", on_click) fig.canvas.mpl_connect("key_press_event", on_key) - redraw() plt.show() + return state["time"], state["action"] or "skip" - return state["result"] + +def pick_for_video( + video_path: Path, + db_path: Path | None, + machine_name: str, + session_date: str, + session_time: str, +) -> dict | str | None: + """Run the two-stage thumbnail picker. Return dict, 'skip', or 'quit'.""" + auto_t = auto_suggest(db_path) if db_path else None + print(f" auto-suggest: {f'{auto_t:.1f}s' if auto_t else '(none)'}") + + # Stage 1: coarse grid centred on auto-suggest (or 150 s default). + coarse_center = auto_t if auto_t is not None else COARSE_SPAN_S / 2 + title_coarse = f"COARSE {machine_name} {session_date} {session_time} · spanning 5 min" + while True: + coarse_t, action = show_thumbnail_grid( + video_path, coarse_center, COARSE_SPAN_S, title_coarse + ) + if action == "skip": + return "skip" + if action == "unusable": + return { + "machine_name": machine_name, "session_date": session_date, + "session_time": session_time, "opening_s": np.nan, + "trim_first_s": 0, "notes": "unusable", + } + if action == "quit": + return "quit" + if action == "back": + continue # already at the top stage; redraw + if action == "pick" and coarse_t is not None: + break + + # Stage 2: fine grid around the coarse pick. + title_fine = (f"FINE {machine_name} {session_date} {session_time} " + f"· ±{FINE_SPAN_S/2:.0f} s around {coarse_t:.1f} s") + while True: + fine_t, action = show_thumbnail_grid( + video_path, coarse_t, FINE_SPAN_S, title_fine + ) + if action == "back": + return pick_for_video(video_path, db_path, machine_name, + session_date, session_time) + if action == "skip": + return "skip" + if action == "unusable": + return { + "machine_name": machine_name, "session_date": session_date, + "session_time": session_time, "opening_s": np.nan, + "trim_first_s": 0, "notes": "unusable", + } + if action == "quit": + return "quit" + if action == "pick" and fine_t is not None: + return { + "machine_name": machine_name, "session_date": session_date, + "session_time": session_time, "opening_s": round(fine_t, 1), + "trim_first_s": 0, "notes": "", + } + + +def lookup_video_path(machine_name: str, session_date: str, + session_time: str, inv: pd.DataFrame) -> Path | None: + """Find the mp4 path for (machine, date, time) in the inventory.""" + match = inv[ + (inv["machine_name"] == machine_name) + & (inv["session_date"] == session_date) + & (inv["session_time"] == session_time) + ] + if match.empty: + return None + return Path(match.iloc[0]["mp4_path"]) def main() -> None: @@ -208,7 +293,7 @@ def main() -> None: parser.add_argument("--limit", type=int, default=None, help="only process the first N videos") parser.add_argument("--db", type=Path, default=None, - help="annotate this specific DB only") + help="annotate this specific tracking DB only") args = parser.parse_args() OUT_CSV.parent.mkdir(parents=True, exist_ok=True) @@ -218,10 +303,15 @@ def main() -> None: out = pd.DataFrame(columns=OUT_COLS) done = set(zip(out["machine_name"], out["session_date"], out["session_time"])) - # Build the queue: every tracked DB referenced by the merged TSV that - # hasn't been picked yet. + if not INVENTORY_CSV.exists(): + sys.exit(f"Inventory not found at {INVENTORY_CSV}. Run build_video_inventory.py first.") + inv = pd.read_csv(INVENTORY_CSV) + + # Build the queue: every (machine, date, time) referenced by the merged + # TSV that has a tracking DB on disk and isn't yet annotated. tsv = pd.read_csv(VIDEO_INFO_TSV, sep="\t") - queue = [] + queue: list[tuple[Path, Path, str, str, str]] = [] + seen: set[tuple[str, str, str]] = set() for col in ("training_db_path", "testing_db_path"): for _, row in tsv.iterrows(): db = row[col] @@ -230,28 +320,25 @@ def main() -> None: db_path = Path(db) if not db_path.exists(): continue - parsed = parse_db_filename(db_path) - if parsed is None: + m = DB_NAME_RE.match(db_path.name) + if not m: continue - session_date, session_time, _ = parsed + session_date, session_time = m.group(1), m.group(2) key = (row["machine_name"], session_date, session_time) + if key in seen: + continue + seen.add(key) if key in done and not args.redo: continue - queue.append((db_path, row["machine_name"], session_date, session_time)) - - # Dedup (a fly may reference the same DB for both training & testing). - seen = set() - deduped = [] - for item in queue: - k = (item[1], item[2], item[3]) - if k not in seen: - seen.add(k) - deduped.append(item) - queue = deduped + video = lookup_video_path(*key, inv) + if video is None or not video.exists(): + print(f" ! no video for {key}; skipping") + continue + queue.append((db_path, video, *key)) if args.db: target = Path(args.db).resolve() - queue = [q for q in queue if Path(q[0]).resolve() == target] + queue = [q for q in queue if q[0].resolve() == target] if not queue: sys.exit(f"DB not found in queue: {args.db}") @@ -259,34 +346,26 @@ def main() -> None: queue = queue[: args.limit] if not queue: - print("Nothing to pick. All eligible DBs already have a barrier_opening row.") + print("Nothing to pick. All eligible videos already have a barrier_opening row.") return print(f"Picking barrier-opening for {len(queue)} videos.") - print("Window keys: click=set ENTER=save [/]=±1s {/}=±5s n=skip u=unusable r=reset q=quit") + print("Window keys: click=pick · n=skip · u=unusable · b=back · q=quit") saved = skipped = unusable = 0 - for i, (db, machine_name, session_date, session_time) in enumerate(queue, 1): + for i, (db, video, machine_name, session_date, session_time) in enumerate(queue, 1): prefix = f"[{i}/{len(queue)}] {machine_name} {session_date} {session_time}" print(f"\n{prefix}") - traces = load_distance_traces(db) - auto_t = auto_suggest(traces) - initial = auto_t if auto_t is not None else 60.0 - print(f" auto-suggest: " - f"{f'{auto_t:.1f}s' if auto_t is not None else '(none)'}") - - result = show_picker(db, machine_name, session_date, session_time, - auto_t, initial) + result = pick_for_video(video, db, machine_name, session_date, session_time) if result is None or result == "skip": skipped += 1 continue if result == "quit": - print(" quit requested — saving what we have and exiting") + print(" quit requested — saving and exiting") break - # Append + dedup on key + persist after each save (crash-safe). new_row = pd.DataFrame([result]) out = pd.concat([ out[~((out.machine_name == result["machine_name"]) & @@ -297,6 +376,7 @@ def main() -> None: out[OUT_COLS].to_csv(OUT_CSV, index=False) if pd.isna(result["opening_s"]): unusable += 1 + print(" saved as unusable") else: saved += 1 print(f" saved opening_s = {result['opening_s']} s")