pick_barrier.py loops over every tracked DB referenced by the merged
TSV that does not yet have an annotation (unless --redo is given),
plots the windowed mean inter-fly distance for all 6 ROIs in a single
figure, and lets the analyst click the moment the barrier opens. The
output CSV (data/metadata/barrier_opening.csv) is rewritten after each
pick, so a crash loses at most the video currently on screen.
The auto-detector's best-effort guess is shown as an orange dotted
line — the analyst always has the final say.
Output schema:
machine_name, session_date, session_time, opening_s, trim_first_s, notes
`trim_first_s` lets us record misframed starts so downstream code can
ignore the affected window. The five entries dated 2025-07-15 are
seeded from the original legacy CSV so they are not re-picked.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
309 lines
11 KiB
Python
309 lines
11 KiB
Python
"""Interactive picker for barrier-opening time per tracked video.

Loops through tracked DBs that don't yet have a barrier-opening
annotation. For each, plots the windowed mean inter-fly distance for
all 6 ROIs over the detector's search window (SEARCH_END_S, nominally
the first 5 minutes) and lets the analyst click the moment the barrier
opens (when most flies start coming close together). The
auto-detector's best-effort guess is shown as a starting position.

Output: data/metadata/barrier_opening.csv with columns
    machine_name, session_date, session_time, opening_s, trim_first_s, notes

`opening_s` is the moment of barrier opening, measured from the start
of the recording (NOT the start of any trimmed copy). `trim_first_s`
is an optional annotation for videos with a misframed start that
should be ignored by analysis (the picker itself writes 0).

Window keys:
    click    place the opening cursor at that time
    ENTER    save and advance
    [ / ]    shift cursor by 1 s left / right
    { / }    shift cursor by 5 s left / right
    n        skip this video for THIS run (no row written); closing the
             window without a key does the same
    u        mark this video unusable (writes opening_s = NaN, notes = "unusable")
    r        reset cursor to the auto-detected position
    q / ESC  quit; rows saved for earlier videos are kept, the video
             currently on screen is NOT saved

Usage:
    python pick_barrier.py
    python pick_barrier.py --redo        # re-pick videos that already have a row
    python pick_barrier.py --limit 10    # only the first 10 queued videos
    python pick_barrier.py --db PATH     # one specific DB (must be referenced by the TSV)
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import sqlite3
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
from config import DATA_METADATA, VIDEO_INFO_TSV
|
||
from detect_barrier_opening import (
|
||
SEARCH_END_S, STEP_S, WINDOW_S,
|
||
per_frame_distance, sliding_mean,
|
||
)
|
||
|
||
# Destination of the per-video barrier-opening annotations.
OUT_CSV = DATA_METADATA / "barrier_opening.csv"

# Column order used whenever OUT_CSV is written.
OUT_COLS = [
    "machine_name",
    "session_date",
    "session_time",
    "opening_s",
    "trim_first_s",
    "notes",
]
|
||
|
||
|
||
def parse_db_filename(db_path: Path) -> tuple[str, str, str] | None:
|
||
"""Pull (date, time, machine_uuid) out of a tracking DB filename."""
|
||
import re
|
||
m = re.match(
|
||
r"^(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_([0-9a-f]{32})__",
|
||
db_path.name,
|
||
)
|
||
if not m:
|
||
return None
|
||
return m.group(1), m.group(2), m.group(3)
|
||
|
||
|
||
def load_distance_traces(db_path: Path) -> dict[int, pd.DataFrame]:
    """Load the windowed-mean inter-fly distance trace for ROIs 1..6.

    The DB is opened read-only. For each table ``ROI_<n>`` the columns
    ``t, x, y, id`` are read, turned into per-frame distances with
    ``per_frame_distance`` and smoothed with ``sliding_mean`` using the
    detector's ``WINDOW_S`` / ``STEP_S`` / ``SEARCH_END_S`` settings.

    Args:
        db_path: path to the tracking SQLite DB (must already exist).

    Returns:
        Mapping ROI number (1..6) -> windowed-mean DataFrame. An ROI whose
        table cannot be read (e.g. the table is missing) maps to an empty
        DataFrame, so all six keys are always present.
    """
    from contextlib import closing

    # Build a real file: URI so characters such as '#', '?', '%' or spaces
    # in the path are percent-encoded instead of being parsed as URI syntax
    # (a raw '#' would otherwise truncate the path and drop "mode=ro").
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
    out: dict[int, pd.DataFrame] = {}
    # A sqlite3 connection used as a context manager only commits/rolls
    # back on exit; closing() is what actually releases the file handle.
    with closing(sqlite3.connect(uri, uri=True)) as conn:
        for roi in range(1, 7):
            try:
                df = pd.read_sql_query(
                    f"SELECT t, x, y, id FROM ROI_{roi}", conn
                )
            except Exception:
                # Best-effort: an unreadable/missing ROI table just yields
                # no trace for that ROI.
                out[roi] = pd.DataFrame()
                continue
            dist = per_frame_distance(df)
            out[roi] = sliding_mean(dist, WINDOW_S, STEP_S, SEARCH_END_S)
    return out
|
||
|
||
|
||
def auto_suggest(traces: dict[int, pd.DataFrame]) -> float | None:
|
||
"""Median of per-ROI biggest-drop times. Returns None if too noisy."""
|
||
candidates = []
|
||
for roi, smean in traces.items():
|
||
if len(smean) < 30:
|
||
continue
|
||
# Find the time of the largest decrease in median(pre)–median(post).
|
||
pad = max(1, int(WINDOW_S / STEP_S))
|
||
if len(smean) < 2 * pad + 1:
|
||
continue
|
||
best_drop = -np.inf
|
||
best_t = None
|
||
for i in range(pad, len(smean) - pad):
|
||
pre = smean["mean_dist"].iloc[:i].median()
|
||
post = smean["mean_dist"].iloc[i:].median()
|
||
drop = pre - post
|
||
if drop > best_drop:
|
||
best_drop = drop
|
||
best_t = float(smean["mid_t"].iloc[i])
|
||
if best_drop > 30 and best_t is not None:
|
||
candidates.append(best_t)
|
||
if not candidates:
|
||
return None
|
||
return float(np.median(candidates))
|
||
|
||
|
||
def show_picker(
|
||
db_path: Path,
|
||
machine_name: str,
|
||
session_date: str,
|
||
session_time: str,
|
||
auto_t: float | None,
|
||
initial_t: float,
|
||
) -> dict | None:
|
||
"""Open the picker window. Returns a dict ready for OUT_CSV, or None to skip."""
|
||
traces = load_distance_traces(db_path)
|
||
if all(s.empty for s in traces.values()):
|
||
print(f" ! no usable ROI traces in {db_path.name}; skipping")
|
||
return None
|
||
|
||
fig, axes = plt.subplots(6, 1, figsize=(13, 12), sharex=True)
|
||
fig.suptitle(
|
||
f"{machine_name} {session_date} {session_time}\n"
|
||
f"click ↦ set opening · ENTER save · "
|
||
f"[/] ±1s · {{/}} ±5s · n skip · u unusable · r reset · q quit",
|
||
fontsize=10,
|
||
)
|
||
|
||
state: dict = {"t": float(initial_t), "auto": auto_t, "result": None}
|
||
|
||
def redraw():
|
||
for ax, (roi, smean) in zip(axes, sorted(traces.items())):
|
||
ax.cla()
|
||
if smean.empty:
|
||
ax.text(0.5, 0.5, f"ROI {roi}: no data",
|
||
transform=ax.transAxes, ha="center", va="center", color="grey")
|
||
else:
|
||
ax.plot(smean["mid_t"], smean["mean_dist"], color="steelblue", lw=1.0)
|
||
ax.set_ylabel(f"ROI {roi}")
|
||
if state["auto"] is not None:
|
||
ax.axvline(state["auto"], color="orange", ls=":", lw=0.8, alpha=0.8)
|
||
ax.axvline(state["t"], color="red", lw=1.5)
|
||
ax.set_xlim(0, SEARCH_END_S)
|
||
ax.grid(True, alpha=0.3)
|
||
axes[-1].set_xlabel("time (s)")
|
||
axes[0].set_title(f"orange dotted = auto-suggested · red = current pick: {state['t']:.1f} s",
|
||
fontsize=9)
|
||
fig.canvas.draw_idle()
|
||
|
||
def on_click(event):
|
||
if event.inaxes in axes and event.xdata is not None:
|
||
state["t"] = max(0.0, min(SEARCH_END_S, float(event.xdata)))
|
||
redraw()
|
||
|
||
def on_key(event):
|
||
k = event.key
|
||
if k == "enter":
|
||
state["result"] = {
|
||
"machine_name": machine_name,
|
||
"session_date": session_date,
|
||
"session_time": session_time,
|
||
"opening_s": round(state["t"], 1),
|
||
"trim_first_s": 0,
|
||
"notes": "",
|
||
}
|
||
plt.close(fig)
|
||
elif k == "n":
|
||
state["result"] = "skip"
|
||
plt.close(fig)
|
||
elif k == "u":
|
||
state["result"] = {
|
||
"machine_name": machine_name,
|
||
"session_date": session_date,
|
||
"session_time": session_time,
|
||
"opening_s": np.nan,
|
||
"trim_first_s": 0,
|
||
"notes": "unusable",
|
||
}
|
||
plt.close(fig)
|
||
elif k == "r" and state["auto"] is not None:
|
||
state["t"] = state["auto"]
|
||
redraw()
|
||
elif k in ("[", "]"):
|
||
state["t"] = max(0.0, min(SEARCH_END_S, state["t"] + (-1 if k == "[" else 1)))
|
||
redraw()
|
||
elif k in ("{", "}"):
|
||
state["t"] = max(0.0, min(SEARCH_END_S, state["t"] + (-5 if k == "{" else 5)))
|
||
redraw()
|
||
elif k in ("q", "escape"):
|
||
state["result"] = "quit"
|
||
plt.close(fig)
|
||
|
||
fig.canvas.mpl_connect("button_press_event", on_click)
|
||
fig.canvas.mpl_connect("key_press_event", on_key)
|
||
redraw()
|
||
plt.show()
|
||
|
||
return state["result"]
|
||
|
||
|
||
def _build_queue(
    tsv: pd.DataFrame, done: set, redo: bool
) -> list[tuple[Path, str, str, str]]:
    """Collect (db_path, machine_name, session_date, session_time) to pick.

    Walks ``training_db_path`` then ``testing_db_path`` of every TSV row,
    skipping blank entries, files that do not exist, filenames that
    ``parse_db_filename`` cannot parse, and (unless ``redo``) videos whose
    key is already in ``done``. Each video key is queued at most once,
    keeping the first occurrence (a fly may reference the same DB for both
    training and testing).
    """
    queue: list[tuple[Path, str, str, str]] = []
    seen: set = set()
    for col in ("training_db_path", "testing_db_path"):
        for _, row in tsv.iterrows():
            db = row[col]
            if not isinstance(db, str) or not db:
                continue
            db_path = Path(db)
            if not db_path.exists():
                continue
            parsed = parse_db_filename(db_path)
            if parsed is None:
                continue
            session_date, session_time, _ = parsed
            key = (row["machine_name"], session_date, session_time)
            if (key in done and not redo) or key in seen:
                continue
            seen.add(key)
            queue.append((db_path, *key))
    return queue


def _upsert_row(out: pd.DataFrame, row: dict) -> pd.DataFrame:
    """Return ``out`` with the row for ``row``'s video key replaced/appended.

    If a row for the same (machine_name, session_date, session_time) already
    exists, its ``trim_first_s`` is carried over: the picker always writes 0,
    so any other value was annotated outside this tool and must survive a
    ``--redo``. Old ``notes`` are also kept when the new row has none and
    the old notes are not the picker's own "unusable" marker.
    """
    same = (
        (out["machine_name"] == row["machine_name"])
        & (out["session_date"] == row["session_date"])
        & (out["session_time"] == row["session_time"])
    )
    merged = dict(row)
    previous = out.loc[same]
    if not previous.empty:
        old_trim = previous["trim_first_s"].iloc[-1]
        if pd.notna(old_trim):
            merged["trim_first_s"] = old_trim
        old_notes = previous["notes"].iloc[-1]
        if (not merged["notes"] and isinstance(old_notes, str)
                and old_notes and old_notes != "unusable"):
            merged["notes"] = old_notes
    new_row = pd.DataFrame([merged], columns=OUT_COLS)
    kept = out.loc[~same, OUT_COLS]
    if kept.empty:
        # Avoid concatenating an empty frame (pandas deprecation warning).
        return new_row
    return pd.concat([kept, new_row], ignore_index=True)


def _write_csv_atomic(df: pd.DataFrame, path: Path) -> None:
    """Write ``df`` to ``path`` through a sibling temp file and a rename.

    ``Path.replace`` swaps the file in one rename (atomic on POSIX), so an
    interruption mid-write leaves the previous CSV intact.
    """
    tmp = path.with_name(path.name + ".tmp")
    df.to_csv(tmp, index=False)
    tmp.replace(path)


def main() -> None:
    """Parse CLI args, build the work queue and run the picker on each video.

    The output CSV is rewritten after every saved or unusable pick, so a
    crash or ``q`` loses at most the video currently on screen.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--redo", action="store_true",
                        help="re-pick videos that already have a row in the output CSV")
    parser.add_argument("--limit", type=int, default=None,
                        help="only process the first N videos")
    parser.add_argument("--db", type=Path, default=None,
                        help="annotate this specific DB only")
    args = parser.parse_args()
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be >= 0")

    OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    if OUT_CSV.exists():
        # reindex: an older CSV lacking a column gets it filled with NaN
        # instead of failing later when selecting out[OUT_COLS].
        out = pd.read_csv(OUT_CSV).reindex(columns=OUT_COLS)
    else:
        out = pd.DataFrame(columns=OUT_COLS)
    done = set(zip(out["machine_name"], out["session_date"], out["session_time"]))

    # Build the queue: every tracked DB referenced by the merged TSV that
    # hasn't been picked yet (or all of them with --redo).
    tsv = pd.read_csv(VIDEO_INFO_TSV, sep="\t")
    queue = _build_queue(tsv, done, args.redo)

    if args.db:
        target = Path(args.db).resolve()
        queue = [q for q in queue if q[0].resolve() == target]
        if not queue:
            sys.exit(
                f"DB not found in queue: {args.db} "
                "(not referenced by the TSV, unparseable filename, "
                "or already picked — use --redo to re-pick)"
            )

    # `is not None` so that --limit 0 really processes nothing.
    if args.limit is not None:
        queue = queue[: args.limit]

    if not queue:
        print("Nothing to pick. All eligible DBs already have a barrier_opening row.")
        return

    print(f"Picking barrier-opening for {len(queue)} videos.")
    print("Window keys: click=set ENTER=save [/]=±1s {/}=±5s n=skip u=unusable r=reset q=quit")

    saved = skipped = unusable = 0
    for i, (db, machine_name, session_date, session_time) in enumerate(queue, 1):
        print(f"\n[{i}/{len(queue)}] {machine_name} {session_date} {session_time}")

        traces = load_distance_traces(db)
        auto_t = auto_suggest(traces)
        initial = auto_t if auto_t is not None else 60.0
        shown = f"{auto_t:.1f}s" if auto_t is not None else "(none)"
        print(f" auto-suggest: {shown}")

        result = show_picker(db, machine_name, session_date, session_time,
                             auto_t, initial)

        if result == "quit":
            print(" quit requested — saving what we have and exiting")
            break
        if not isinstance(result, dict):
            # None (no usable traces / window closed) or "skip".
            skipped += 1
            continue

        # Upsert on the video key and persist after each pick (crash-safe).
        out = _upsert_row(out, result)
        _write_csv_atomic(out[OUT_COLS], OUT_CSV)
        if pd.isna(result["opening_s"]):
            unusable += 1
        else:
            saved += 1
        print(f" saved opening_s = {result['opening_s']} s")

    print(f"\nDone: {saved} saved, {unusable} unusable, {skipped} skipped.")
    print(f" → {OUT_CSV}")
|
||
|
||
|
||
# Script entry point; importing this module does not launch the picker.
if __name__ == "__main__":
    sys.exit(main())
|