diff --git a/data/metadata/barrier_opening.csv b/data/metadata/barrier_opening.csv new file mode 100644 index 0000000..f1e2123 --- /dev/null +++ b/data/metadata/barrier_opening.csv @@ -0,0 +1,6 @@ +machine_name,session_date,session_time,opening_s,trim_first_s,notes +ETHOSCOPE_076,2025-07-15,16-03-10,52,0,hand-annotated 2025-07-15 batch +ETHOSCOPE_076,2025-07-15,16-31-34,94,69,first ~66s misframed (arena partly out of frame) +ETHOSCOPE_145,2025-07-15,16-03-27,42,0,hand-annotated 2025-07-15 batch +ETHOSCOPE_145,2025-07-15,16-31-41,89,69,first ~60s misframed (arena partly out of frame) +ETHOSCOPE_268,2025-07-15,16-32-05,75,0,hand-annotated 2025-07-15 batch diff --git a/scripts/pick_barrier.py b/scripts/pick_barrier.py new file mode 100644 index 0000000..cb24f47 --- /dev/null +++ b/scripts/pick_barrier.py @@ -0,0 +1,309 @@ +"""Interactive picker for barrier-opening time per tracked video. + +Loops through tracked DBs that don't yet have a barrier-opening +annotation. For each, plots the windowed mean inter-fly distance for +all 6 ROIs over the first 5 minutes and lets the analyst click the +moment the barrier opens (when most flies start coming close together). +The auto-detector's best-effort guess is shown as a starting position. + +Output: data/metadata/barrier_opening.csv with columns + machine_name, session_date, session_time, opening_s, trim_first_s, notes + +`opening_s` is the moment of barrier opening, measured from the start +of the recording (NOT the start of any trimmed copy). `trim_first_s` +is an optional annotation for videos with a misframed start that +should be ignored by analysis (defaults to 0). + +Window keys: + click place the opening cursor at that time + ENTER save and advance + [, ] shift cursor by 1 s left / right + {, } shift cursor by 5 s left / right + n skip this video for THIS run (no row written) + u mark this video unusable (writes opening_s = NaN, notes = "unusable") + r reset cursor to the auto-detected position + q / ESC save+quit + +Usage: + python pick_barrier.py + python pick_barrier.py --redo # re-pick videos that already have a row + python pick_barrier.py --limit 10 +""" + +from __future__ import annotations + +import argparse +import sqlite3 +import sys +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from config import DATA_METADATA, VIDEO_INFO_TSV +from detect_barrier_opening import ( + SEARCH_END_S, STEP_S, WINDOW_S, + per_frame_distance, sliding_mean, +) + +OUT_CSV = DATA_METADATA / "barrier_opening.csv" +OUT_COLS = ["machine_name", "session_date", "session_time", + "opening_s", "trim_first_s", "notes"] + + +def parse_db_filename(db_path: Path) -> tuple[str, str, str] | None: + """Pull (date, time, machine_uuid) out of a tracking DB filename.""" + import re + m = re.match( + r"^(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_([0-9a-f]{32})__", + db_path.name, + ) + if not m: + return None + return m.group(1), m.group(2), m.group(3) + + +def load_distance_traces(db_path: Path) -> dict[int, pd.DataFrame]: + """For each ROI 1..6, return windowed-mean DF; empty if ROI missing.""" + out: dict[int, pd.DataFrame] = {} + with sqlite3.connect(f"file:{db_path}?mode=ro", uri=True) as conn: + for roi in range(1, 7): + try: + df = pd.read_sql_query( + f"SELECT t, x, y, id FROM ROI_{roi}", conn + ) + except Exception: + out[roi] = pd.DataFrame() + continue + dist = per_frame_distance(df) + out[roi] = sliding_mean(dist, WINDOW_S, STEP_S, SEARCH_END_S) + return out + + +def auto_suggest(traces: dict[int, pd.DataFrame]) -> float | None: + """Median of per-ROI biggest-drop times. Returns None if too noisy.""" + candidates = [] + for roi, smean in traces.items(): + if len(smean) < 30: + continue + # Find the time of the largest decrease in median(pre)–median(post). + pad = max(1, int(WINDOW_S / STEP_S)) + if len(smean) < 2 * pad + 1: + continue + best_drop = -np.inf + best_t = None + for i in range(pad, len(smean) - pad): + pre = smean["mean_dist"].iloc[:i].median() + post = smean["mean_dist"].iloc[i:].median() + drop = pre - post + if drop > best_drop: + best_drop = drop + best_t = float(smean["mid_t"].iloc[i]) + if best_drop > 30 and best_t is not None: + candidates.append(best_t) + if not candidates: + return None + return float(np.median(candidates)) + + +def show_picker( + db_path: Path, + machine_name: str, + session_date: str, + session_time: str, + auto_t: float | None, + initial_t: float, +) -> dict | None: + """Open the picker window. Returns a dict ready for OUT_CSV, or None to skip.""" + traces = load_distance_traces(db_path) + if all(s.empty for s in traces.values()): + print(f" ! no usable ROI traces in {db_path.name}; skipping") + return None + + fig, axes = plt.subplots(6, 1, figsize=(13, 12), sharex=True) + fig.suptitle( + f"{machine_name} {session_date} {session_time}\n" + f"click ↦ set opening · ENTER save · " + f"[/] ±1s · {{/}} ±5s · n skip · u unusable · r reset · q quit", + fontsize=10, + ) + + state: dict = {"t": float(initial_t), "auto": auto_t, "result": None} + + def redraw(): + for ax, (roi, smean) in zip(axes, sorted(traces.items())): + ax.cla() + if smean.empty: + ax.text(0.5, 0.5, f"ROI {roi}: no data", + transform=ax.transAxes, ha="center", va="center", color="grey") + else: + ax.plot(smean["mid_t"], smean["mean_dist"], color="steelblue", lw=1.0) + ax.set_ylabel(f"ROI {roi}") + if state["auto"] is not None: + ax.axvline(state["auto"], color="orange", ls=":", lw=0.8, alpha=0.8) + ax.axvline(state["t"], color="red", lw=1.5) + ax.set_xlim(0, SEARCH_END_S) + ax.grid(True, alpha=0.3) + axes[-1].set_xlabel("time (s)") + axes[0].set_title(f"orange dotted = auto-suggested · red = current pick: {state['t']:.1f} s", + fontsize=9) + fig.canvas.draw_idle() + + def on_click(event): + if event.inaxes in axes and event.xdata is not None: + state["t"] = max(0.0, min(SEARCH_END_S, float(event.xdata))) + redraw() + + def on_key(event): + k = event.key + if k == "enter": + state["result"] = { + "machine_name": machine_name, + "session_date": session_date, + "session_time": session_time, + "opening_s": round(state["t"], 1), + "trim_first_s": 0, + "notes": "", + } + plt.close(fig) + elif k == "n": + state["result"] = "skip" + plt.close(fig) + elif k == "u": + state["result"] = { + "machine_name": machine_name, + "session_date": session_date, + "session_time": session_time, + "opening_s": np.nan, + "trim_first_s": 0, + "notes": "unusable", + } + plt.close(fig) + elif k == "r" and state["auto"] is not None: + state["t"] = state["auto"] + redraw() + elif k in ("[", "]"): + state["t"] = max(0.0, min(SEARCH_END_S, state["t"] + (-1 if k == "[" else 1))) + redraw() + elif k in ("{", "}"): + state["t"] = max(0.0, min(SEARCH_END_S, state["t"] + (-5 if k == "{" else 5))) + redraw() + elif k in ("q", "escape"): + state["result"] = "quit" + plt.close(fig) + + fig.canvas.mpl_connect("button_press_event", on_click) + fig.canvas.mpl_connect("key_press_event", on_key) + redraw() + plt.show() + + return state["result"] + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--redo", action="store_true", + help="re-pick videos that already have a row in the output CSV") + parser.add_argument("--limit", type=int, default=None, + help="only process the first N videos") + parser.add_argument("--db", type=Path, default=None, + help="annotate this specific DB only") + args = parser.parse_args() + + OUT_CSV.parent.mkdir(parents=True, exist_ok=True) + if OUT_CSV.exists(): + out = pd.read_csv(OUT_CSV) + else: + out = pd.DataFrame(columns=OUT_COLS) + done = set(zip(out["machine_name"], out["session_date"], out["session_time"])) + + # Build the queue: every tracked DB referenced by the merged TSV that + # hasn't been picked yet. + tsv = pd.read_csv(VIDEO_INFO_TSV, sep="\t") + queue = [] + for col in ("training_db_path", "testing_db_path"): + for _, row in tsv.iterrows(): + db = row[col] + if not isinstance(db, str) or not db: + continue + db_path = Path(db) + if not db_path.exists(): + continue + parsed = parse_db_filename(db_path) + if parsed is None: + continue + session_date, session_time, _ = parsed + key = (row["machine_name"], session_date, session_time) + if key in done and not args.redo: + continue + queue.append((db_path, row["machine_name"], session_date, session_time)) + + # Dedup (a fly may reference the same DB for both training & testing). + seen = set() + deduped = [] + for item in queue: + k = (item[1], item[2], item[3]) + if k not in seen: + seen.add(k) + deduped.append(item) + queue = deduped + + if args.db: + target = Path(args.db).resolve() + queue = [q for q in queue if Path(q[0]).resolve() == target] + if not queue: + sys.exit(f"DB not found in queue: {args.db}") + + if args.limit: + queue = queue[: args.limit] + + if not queue: + print("Nothing to pick. All eligible DBs already have a barrier_opening row.") + return + + print(f"Picking barrier-opening for {len(queue)} videos.") + print("Window keys: click=set ENTER=save [/]=±1s {/}=±5s n=skip u=unusable r=reset q=quit") + + saved = skipped = unusable = 0 + for i, (db, machine_name, session_date, session_time) in enumerate(queue, 1): + prefix = f"[{i}/{len(queue)}] {machine_name} {session_date} {session_time}" + print(f"\n{prefix}") + + traces = load_distance_traces(db) + auto_t = auto_suggest(traces) + initial = auto_t if auto_t is not None else 60.0 + print(f" auto-suggest: " + f"{f'{auto_t:.1f}s' if auto_t is not None else '(none)'}") + + result = show_picker(db, machine_name, session_date, session_time, + auto_t, initial) + + if result is None or result == "skip": + skipped += 1 + continue + if result == "quit": + print(" quit requested — saving what we have and exiting") + break + + # Append + dedup on key + persist after each save (crash-safe). + new_row = pd.DataFrame([result]) + out = pd.concat([ + out[~((out.machine_name == result["machine_name"]) & + (out.session_date == result["session_date"]) & + (out.session_time == result["session_time"]))], + new_row, + ], ignore_index=True) + out[OUT_COLS].to_csv(OUT_CSV, index=False) + if pd.isna(result["opening_s"]): + unusable += 1 + else: + saved += 1 + print(f" saved opening_s = {result['opening_s']} s") + + print(f"\nDone: {saved} saved, {unusable} unusable, {skipped} skipped.") + print(f" → {OUT_CSV}") + + +if __name__ == "__main__": + main()