Replace pick_barrier.py with thumbnail-grid UX

Old version showed inter-fly distance plots and asked the analyst to
click a timeline. The new version reads frames directly from the .mp4
and shows a 10×6 grid of timestamped thumbnails — the analyst just
clicks the frame where the barrier opens.

Two-stage refinement:
  - Coarse grid: 60 thumbs spanning the 5-min search window at ~5 s
    spacing. Pick the rough moment.
  - Fine grid: 60 thumbs at 0.2 s spacing centred on the coarse pick.
    Pick the exact frame.

Auto-detector still feeds the starting position. Sequential video
decode (one cv2 pass through the relevant range) instead of seek-per-
frame, so each grid loads in a few seconds.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Giorgio Gilestro 2026-05-01 12:01:34 +01:00
parent b46c4ac1ba
commit e8c7f23d4d

View file

@ -1,47 +1,48 @@
"""Interactive picker for barrier-opening time per tracked video.
"""Interactive picker for barrier-opening time, frame-by-frame thumbnail style.
Loops through tracked DBs that don't yet have a barrier-opening
annotation. For each, plots the windowed mean inter-fly distance for
all 6 ROIs over the first 5 minutes and lets the analyst click the
moment the barrier opens (when most flies start coming close together).
The auto-detector's best-effort guess is shown as a starting position.
For each video that doesn't yet have a barrier-opening annotation, show a
10x6 grid of timestamped thumbnails extracted directly from the .mp4.
The analyst clicks the thumbnail at (or just after) the moment the
barrier opens; the picker then refines with a second tighter grid for
sub-second precision.
Two-stage flow per video:
1. Coarse grid: 60 thumbs spanning the 5-min search window (5 s spacing).
Click pick that 5 s slot.
2. Fine grid: 60 thumbs spanning ±6 s of the coarse pick (0.2 s spacing).
Click final answer with 0.2 s precision.
Output: data/metadata/barrier_opening.csv with columns
machine_name, session_date, session_time, opening_s, trim_first_s, notes
`opening_s` is the moment of barrier opening, measured from the start
of the recording (NOT the start of any trimmed copy). `trim_first_s`
is an optional annotation for videos with a misframed start that
should be ignored by analysis (defaults to 0).
Window keys:
click place the opening cursor at that time
ENTER save and advance
[, ] shift cursor by 1 s left / right
{, } shift cursor by 5 s left / right
n skip this video for THIS run (no row written)
u mark this video unusable (writes opening_s = NaN, notes = "unusable")
r reset cursor to the auto-detected position
click select thumbnail at that timestamp
n skip this video for THIS run
u mark unusable (opening_s = NaN)
b back to coarse grid (after seeing fine grid)
q / ESC save+quit
Usage:
python pick_barrier.py
python pick_barrier.py --redo # re-pick videos that already have a row
python pick_barrier.py --redo
python pick_barrier.py --limit 10
python pick_barrier.py --db /path/to/specific_tracking.db
"""
from __future__ import annotations
import argparse
import re
import sqlite3
import sys
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from config import DATA_METADATA, VIDEO_INFO_TSV
from config import DATA_METADATA, INVENTORY_CSV, VIDEO_INFO_TSV
from detect_barrier_opening import (
SEARCH_END_S, STEP_S, WINDOW_S,
per_frame_distance, sliding_mean,
@ -51,43 +52,30 @@ OUT_CSV = DATA_METADATA / "barrier_opening.csv"
OUT_COLS = ["machine_name", "session_date", "session_time",
"opening_s", "trim_first_s", "notes"]
DB_NAME_RE = re.compile(
r"^(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_([0-9a-f]{32})__"
)
def parse_db_filename(db_path: Path) -> tuple[str, str, str] | None:
"""Pull (date, time, machine_uuid) out of a tracking DB filename."""
import re
m = re.match(
r"^(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_([0-9a-f]{32})__",
db_path.name,
)
if not m:
GRID_ROWS, GRID_COLS = 6, 10
N_THUMBS = GRID_ROWS * GRID_COLS # 60
COARSE_SPAN_S = SEARCH_END_S # 0..300s, ~5s spacing
FINE_SPAN_S = 12.0 # ±6s around coarse pick → ~0.2s spacing
def auto_suggest(db_path: Path) -> float | None:
"""Median of per-ROI biggest-drop times. None if too noisy."""
try:
conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
except sqlite3.Error:
return None
return m.group(1), m.group(2), m.group(3)
def load_distance_traces(db_path: Path) -> dict[int, pd.DataFrame]:
"""For each ROI 1..6, return windowed-mean DF; empty if ROI missing."""
out: dict[int, pd.DataFrame] = {}
with sqlite3.connect(f"file:{db_path}?mode=ro", uri=True) as conn:
candidates = []
for roi in range(1, 7):
try:
df = pd.read_sql_query(
f"SELECT t, x, y, id FROM ROI_{roi}", conn
)
df = pd.read_sql_query(f"SELECT t, x, y, id FROM ROI_{roi}", conn)
except Exception:
out[roi] = pd.DataFrame()
continue
dist = per_frame_distance(df)
out[roi] = sliding_mean(dist, WINDOW_S, STEP_S, SEARCH_END_S)
return out
def auto_suggest(traces: dict[int, pd.DataFrame]) -> float | None:
"""Median of per-ROI biggest-drop times. Returns None if too noisy."""
candidates = []
for roi, smean in traces.items():
if len(smean) < 30:
continue
# Find the time of the largest decrease in median(pre)median(post).
smean = sliding_mean(dist, WINDOW_S, STEP_S, SEARCH_END_S)
pad = max(1, int(WINDOW_S / STEP_S))
if len(smean) < 2 * pad + 1:
continue
@ -102,103 +90,200 @@ def auto_suggest(traces: dict[int, pd.DataFrame]) -> float | None:
best_t = float(smean["mid_t"].iloc[i])
if best_drop > 30 and best_t is not None:
candidates.append(best_t)
conn.close()
if not candidates:
return None
return float(np.median(candidates))
def show_picker(
db_path: Path,
machine_name: str,
session_date: str,
session_time: str,
auto_t: float | None,
initial_t: float,
) -> dict | None:
"""Open the picker window. Returns a dict ready for OUT_CSV, or None to skip."""
traces = load_distance_traces(db_path)
if all(s.empty for s in traces.values()):
print(f" ! no usable ROI traces in {db_path.name}; skipping")
return None
def grab_thumbnails(video_path: Path, target_times_s: np.ndarray,
thumb_w: int = 320) -> list[np.ndarray | None]:
"""Read thumbnails at the requested timestamps via a single sequential pass.
fig, axes = plt.subplots(6, 1, figsize=(13, 12), sharex=True)
fig.suptitle(
f"{machine_name} {session_date} {session_time}\n"
f"click ↦ set opening · ENTER save · "
f"[/] ±1s · {{/}} ±5s · n skip · u unusable · r reset · q quit",
fontsize=10,
Linear-decode is much faster than seeking per-frame on H.264. We read
frames sequentially from the earliest target onward, keeping only the
ones at requested target frames.
"""
cap = cv2.VideoCapture(str(video_path))
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
src_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
src_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
if total_frames <= 0:
cap.release()
return [None] * len(target_times_s)
target_frames = np.clip(
(target_times_s * fps).round().astype(int), 0, total_frames - 1
)
sort_idx = np.argsort(target_frames)
sorted_targets = target_frames[sort_idx]
state: dict = {"t": float(initial_t), "auto": auto_t, "result": None}
out: list[np.ndarray | None] = [None] * len(target_times_s)
if sorted_targets.size == 0:
cap.release()
return out
def redraw():
for ax, (roi, smean) in zip(axes, sorted(traces.items())):
ax.cla()
if smean.empty:
ax.text(0.5, 0.5, f"ROI {roi}: no data",
transform=ax.transAxes, ha="center", va="center", color="grey")
cap.set(cv2.CAP_PROP_POS_FRAMES, int(sorted_targets[0]))
cur_frame = int(sorted_targets[0])
last_frame_data: np.ndarray | None = None
scale = thumb_w / src_w if src_w > 0 else 1.0
thumb_h = max(1, int(round(src_h * scale)))
for ord_i, target in zip(sort_idx, sorted_targets):
while cur_frame <= target:
ret, frame = cap.read()
if not ret:
last_frame_data = None
break
last_frame_data = frame
cur_frame += 1
if last_frame_data is not None:
small = cv2.resize(last_frame_data, (thumb_w, thumb_h),
interpolation=cv2.INTER_AREA)
out[ord_i] = cv2.cvtColor(small, cv2.COLOR_BGR2RGB)
cap.release()
return out
def show_thumbnail_grid(
video_path: Path,
center_t: float,
span_s: float,
title: str,
) -> tuple[float | None, str]:
"""Show a 10×6 thumbnail grid; return (clicked_time, action).
`action` is one of: 'pick', 'skip', 'unusable', 'back', 'quit'.
`clicked_time` is None unless action == 'pick'.
"""
half = span_s / 2.0
times = np.linspace(max(0.0, center_t - half), center_t + half, N_THUMBS)
print(f" loading {N_THUMBS} thumbnails ({times[0]:.1f}{times[-1]:.1f}s)...", flush=True)
thumbs = grab_thumbnails(video_path, times)
fig, axes = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(20, 11))
fig.suptitle(
f"{title}\nclick a thumbnail · n=skip · u=unusable · b=back · q=quit",
fontsize=11,
)
state = {"time": None, "action": None}
for ax, t, thumb in zip(axes.flat, times, thumbs):
if thumb is not None:
ax.imshow(thumb)
else:
ax.plot(smean["mid_t"], smean["mean_dist"], color="steelblue", lw=1.0)
ax.set_ylabel(f"ROI {roi}")
if state["auto"] is not None:
ax.axvline(state["auto"], color="orange", ls=":", lw=0.8, alpha=0.8)
ax.axvline(state["t"], color="red", lw=1.5)
ax.set_xlim(0, SEARCH_END_S)
ax.grid(True, alpha=0.3)
axes[-1].set_xlabel("time (s)")
axes[0].set_title(f"orange dotted = auto-suggested · red = current pick: {state['t']:.1f} s",
fontsize=9)
fig.canvas.draw_idle()
ax.set_facecolor("black")
ax.text(0.5, 0.5, "no frame",
transform=ax.transAxes, ha="center", va="center", color="white")
# Format time as M:SS.s for readability
m, s = divmod(t, 60)
ax.set_title(f"{int(m):d}:{s:05.2f}", fontsize=8, pad=1)
ax.set_xticks([]); ax.set_yticks([])
fig.subplots_adjust(left=0.01, right=0.99, top=0.93, bottom=0.01,
wspace=0.03, hspace=0.18)
def on_click(event):
if event.inaxes in axes and event.xdata is not None:
state["t"] = max(0.0, min(SEARCH_END_S, float(event.xdata)))
redraw()
if event.inaxes is None:
return
for i, ax in enumerate(axes.flat):
if ax is event.inaxes:
state["time"] = float(times[i])
state["action"] = "pick"
plt.close(fig)
return
def on_key(event):
k = event.key
if k == "enter":
state["result"] = {
"machine_name": machine_name,
"session_date": session_date,
"session_time": session_time,
"opening_s": round(state["t"], 1),
"trim_first_s": 0,
"notes": "",
}
plt.close(fig)
elif k == "n":
state["result"] = "skip"
plt.close(fig)
if k == "n":
state["action"] = "skip"; plt.close(fig)
elif k == "u":
state["result"] = {
"machine_name": machine_name,
"session_date": session_date,
"session_time": session_time,
"opening_s": np.nan,
"trim_first_s": 0,
"notes": "unusable",
}
plt.close(fig)
elif k == "r" and state["auto"] is not None:
state["t"] = state["auto"]
redraw()
elif k in ("[", "]"):
state["t"] = max(0.0, min(SEARCH_END_S, state["t"] + (-1 if k == "[" else 1)))
redraw()
elif k in ("{", "}"):
state["t"] = max(0.0, min(SEARCH_END_S, state["t"] + (-5 if k == "{" else 5)))
redraw()
state["action"] = "unusable"; plt.close(fig)
elif k == "b":
state["action"] = "back"; plt.close(fig)
elif k in ("q", "escape"):
state["result"] = "quit"
plt.close(fig)
state["action"] = "quit"; plt.close(fig)
fig.canvas.mpl_connect("button_press_event", on_click)
fig.canvas.mpl_connect("key_press_event", on_key)
redraw()
plt.show()
return state["time"], state["action"] or "skip"
return state["result"]
def pick_for_video(
video_path: Path,
db_path: Path | None,
machine_name: str,
session_date: str,
session_time: str,
) -> dict | str | None:
"""Run the two-stage thumbnail picker. Return dict, 'skip', or 'quit'."""
auto_t = auto_suggest(db_path) if db_path else None
print(f" auto-suggest: {f'{auto_t:.1f}s' if auto_t else '(none)'}")
# Stage 1: coarse grid centred on auto-suggest (or 150 s default).
coarse_center = auto_t if auto_t is not None else COARSE_SPAN_S / 2
title_coarse = f"COARSE {machine_name} {session_date} {session_time} · spanning 5 min"
while True:
coarse_t, action = show_thumbnail_grid(
video_path, coarse_center, COARSE_SPAN_S, title_coarse
)
if action == "skip":
return "skip"
if action == "unusable":
return {
"machine_name": machine_name, "session_date": session_date,
"session_time": session_time, "opening_s": np.nan,
"trim_first_s": 0, "notes": "unusable",
}
if action == "quit":
return "quit"
if action == "back":
continue # already at the top stage; redraw
if action == "pick" and coarse_t is not None:
break
# Stage 2: fine grid around the coarse pick.
title_fine = (f"FINE {machine_name} {session_date} {session_time} "
f"· ±{FINE_SPAN_S/2:.0f} s around {coarse_t:.1f} s")
while True:
fine_t, action = show_thumbnail_grid(
video_path, coarse_t, FINE_SPAN_S, title_fine
)
if action == "back":
return pick_for_video(video_path, db_path, machine_name,
session_date, session_time)
if action == "skip":
return "skip"
if action == "unusable":
return {
"machine_name": machine_name, "session_date": session_date,
"session_time": session_time, "opening_s": np.nan,
"trim_first_s": 0, "notes": "unusable",
}
if action == "quit":
return "quit"
if action == "pick" and fine_t is not None:
return {
"machine_name": machine_name, "session_date": session_date,
"session_time": session_time, "opening_s": round(fine_t, 1),
"trim_first_s": 0, "notes": "",
}
def lookup_video_path(machine_name: str, session_date: str,
session_time: str, inv: pd.DataFrame) -> Path | None:
"""Find the mp4 path for (machine, date, time) in the inventory."""
match = inv[
(inv["machine_name"] == machine_name)
& (inv["session_date"] == session_date)
& (inv["session_time"] == session_time)
]
if match.empty:
return None
return Path(match.iloc[0]["mp4_path"])
def main() -> None:
@ -208,7 +293,7 @@ def main() -> None:
parser.add_argument("--limit", type=int, default=None,
help="only process the first N videos")
parser.add_argument("--db", type=Path, default=None,
help="annotate this specific DB only")
help="annotate this specific tracking DB only")
args = parser.parse_args()
OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
@ -218,10 +303,15 @@ def main() -> None:
out = pd.DataFrame(columns=OUT_COLS)
done = set(zip(out["machine_name"], out["session_date"], out["session_time"]))
# Build the queue: every tracked DB referenced by the merged TSV that
# hasn't been picked yet.
if not INVENTORY_CSV.exists():
sys.exit(f"Inventory not found at {INVENTORY_CSV}. Run build_video_inventory.py first.")
inv = pd.read_csv(INVENTORY_CSV)
# Build the queue: every (machine, date, time) referenced by the merged
# TSV that has a tracking DB on disk and isn't yet annotated.
tsv = pd.read_csv(VIDEO_INFO_TSV, sep="\t")
queue = []
queue: list[tuple[Path, Path, str, str, str]] = []
seen: set[tuple[str, str, str]] = set()
for col in ("training_db_path", "testing_db_path"):
for _, row in tsv.iterrows():
db = row[col]
@ -230,28 +320,25 @@ def main() -> None:
db_path = Path(db)
if not db_path.exists():
continue
parsed = parse_db_filename(db_path)
if parsed is None:
m = DB_NAME_RE.match(db_path.name)
if not m:
continue
session_date, session_time, _ = parsed
session_date, session_time = m.group(1), m.group(2)
key = (row["machine_name"], session_date, session_time)
if key in seen:
continue
seen.add(key)
if key in done and not args.redo:
continue
queue.append((db_path, row["machine_name"], session_date, session_time))
# Dedup (a fly may reference the same DB for both training & testing).
seen = set()
deduped = []
for item in queue:
k = (item[1], item[2], item[3])
if k not in seen:
seen.add(k)
deduped.append(item)
queue = deduped
video = lookup_video_path(*key, inv)
if video is None or not video.exists():
print(f" ! no video for {key}; skipping")
continue
queue.append((db_path, video, *key))
if args.db:
target = Path(args.db).resolve()
queue = [q for q in queue if Path(q[0]).resolve() == target]
queue = [q for q in queue if q[0].resolve() == target]
if not queue:
sys.exit(f"DB not found in queue: {args.db}")
@ -259,34 +346,26 @@ def main() -> None:
queue = queue[: args.limit]
if not queue:
print("Nothing to pick. All eligible DBs already have a barrier_opening row.")
print("Nothing to pick. All eligible videos already have a barrier_opening row.")
return
print(f"Picking barrier-opening for {len(queue)} videos.")
print("Window keys: click=set ENTER=save [/]=±1s {/}=±5s n=skip u=unusable r=reset q=quit")
print("Window keys: click=pick · n=skip · u=unusable · b=back · q=quit")
saved = skipped = unusable = 0
for i, (db, machine_name, session_date, session_time) in enumerate(queue, 1):
for i, (db, video, machine_name, session_date, session_time) in enumerate(queue, 1):
prefix = f"[{i}/{len(queue)}] {machine_name} {session_date} {session_time}"
print(f"\n{prefix}")
traces = load_distance_traces(db)
auto_t = auto_suggest(traces)
initial = auto_t if auto_t is not None else 60.0
print(f" auto-suggest: "
f"{f'{auto_t:.1f}s' if auto_t is not None else '(none)'}")
result = show_picker(db, machine_name, session_date, session_time,
auto_t, initial)
result = pick_for_video(video, db, machine_name, session_date, session_time)
if result is None or result == "skip":
skipped += 1
continue
if result == "quit":
print(" quit requested — saving what we have and exiting")
print(" quit requested — saving and exiting")
break
# Append + dedup on key + persist after each save (crash-safe).
new_row = pd.DataFrame([result])
out = pd.concat([
out[~((out.machine_name == result["machine_name"]) &
@ -297,6 +376,7 @@ def main() -> None:
out[OUT_COLS].to_csv(OUT_CSV, index=False)
if pd.isna(result["opening_s"]):
unusable += 1
print(" saved as unusable")
else:
saved += 1
print(f" saved opening_s = {result['opening_s']} s")