cupido/scripts/pick_barrier.py
Giorgio Gilestro b46c4ac1ba Add pick_barrier.py interactive annotator + seed CSV with 2025-07-15
pick_barrier.py loops over every tracked DB referenced by the merged
TSV, plots windowed mean inter-fly distance for all 6 ROIs in a single
figure, and lets the analyst click the moment the barrier opens. Saves
to data/metadata/barrier_opening.csv after each pick (crash-safe).
Auto-detector best-effort guess shown as orange dotted line — the
analyst always has the final say.

Output schema:
    machine_name, session_date, session_time, opening_s, trim_first_s, notes

`trim_first_s` lets us record misframed starts so downstream code can
ignore the affected window. The 5 2025-07-15 entries are seeded from
the original legacy CSV so they're not re-picked.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-01 11:58:54 +01:00

309 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Interactive picker for barrier-opening time per tracked video.
Loops through tracked DBs that don't yet have a barrier-opening
annotation. For each, plots the windowed mean inter-fly distance for
all 6 ROIs over the first 5 minutes and lets the analyst click the
moment the barrier opens (when most flies start coming close together).
The auto-detector's best-effort guess is shown as a starting position.
Output: data/metadata/barrier_opening.csv with columns
machine_name, session_date, session_time, opening_s, trim_first_s, notes
`opening_s` is the moment of barrier opening, measured from the start
of the recording (NOT the start of any trimmed copy). `trim_first_s`
is an optional annotation for videos with a misframed start that
should be ignored by analysis (defaults to 0).
Window keys:
click place the opening cursor at that time
ENTER save and advance
[, ] shift cursor by 1 s left / right
{, } shift cursor by 5 s left / right
n skip this video for THIS run (no row written)
u mark this video unusable (writes opening_s = NaN, notes = "unusable")
r reset cursor to the auto-detected position
q / ESC save+quit
Usage:
python pick_barrier.py
python pick_barrier.py --redo # re-pick videos that already have a row
python pick_barrier.py --limit 10
"""
from __future__ import annotations
import argparse
import sqlite3
import sys
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from config import DATA_METADATA, VIDEO_INFO_TSV
from detect_barrier_opening import (
SEARCH_END_S, STEP_S, WINDOW_S,
per_frame_distance, sliding_mean,
)
# Destination CSV holding one barrier-opening annotation per tracked video;
# main() rewrites it after every accepted pick.
OUT_CSV = DATA_METADATA / "barrier_opening.csv"
# Column order used when the annotations are written to OUT_CSV.
OUT_COLS = ["machine_name", "session_date", "session_time",
            "opening_s", "trim_first_s", "notes"]
def parse_db_filename(db_path: Path) -> tuple[str, str, str] | None:
"""Pull (date, time, machine_uuid) out of a tracking DB filename."""
import re
m = re.match(
r"^(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_([0-9a-f]{32})__",
db_path.name,
)
if not m:
return None
return m.group(1), m.group(2), m.group(3)
def load_distance_traces(db_path: Path) -> dict[int, pd.DataFrame]:
    """Load the windowed mean-distance trace of every ROI in one tracking DB.

    The database is opened read-only. For each ROI 1..6 the table
    ``ROI_<n>`` is read (columns ``t, x, y, id``), reduced with
    ``per_frame_distance`` and smoothed with ``sliding_mean`` using the
    ``WINDOW_S`` / ``STEP_S`` / ``SEARCH_END_S`` settings.

    Args:
        db_path: Path to the SQLite tracking database.

    Returns:
        A dict keyed by ROI number. A ROI whose table cannot be read maps
        to an empty DataFrame, so callers always receive all six keys.
    """
    # Local import keeps this block self-contained.
    from contextlib import closing

    out: dict[int, pd.DataFrame] = {}
    # sqlite3's own context manager only commits/rolls back the transaction;
    # it never closes the connection. closing() releases the file handle so
    # the picker does not leak one connection per DB it loops over.
    with closing(sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)) as conn:
        for roi in range(1, 7):
            try:
                df = pd.read_sql_query(
                    f"SELECT t, x, y, id FROM ROI_{roi}", conn
                )
            except Exception:
                # Deliberate best-effort: a missing/unreadable ROI table is
                # reported to the caller as an empty trace, not an error.
                out[roi] = pd.DataFrame()
                continue
            dist = per_frame_distance(df)
            out[roi] = sliding_mean(dist, WINDOW_S, STEP_S, SEARCH_END_S)
    return out
def auto_suggest(traces: dict[int, pd.DataFrame]) -> float | None:
"""Median of per-ROI biggest-drop times. Returns None if too noisy."""
candidates = []
for roi, smean in traces.items():
if len(smean) < 30:
continue
# Find the time of the largest decrease in median(pre)median(post).
pad = max(1, int(WINDOW_S / STEP_S))
if len(smean) < 2 * pad + 1:
continue
best_drop = -np.inf
best_t = None
for i in range(pad, len(smean) - pad):
pre = smean["mean_dist"].iloc[:i].median()
post = smean["mean_dist"].iloc[i:].median()
drop = pre - post
if drop > best_drop:
best_drop = drop
best_t = float(smean["mid_t"].iloc[i])
if best_drop > 30 and best_t is not None:
candidates.append(best_t)
if not candidates:
return None
return float(np.median(candidates))
def show_picker(
    db_path: Path,
    machine_name: str,
    session_date: str,
    session_time: str,
    auto_t: float | None,
    initial_t: float,
) -> dict | None:
    """Show the interactive picker for one video and return the decision.

    Draws one subplot per ROI (windowed mean distance vs. time), with the
    auto-suggested time as an orange dotted line and the current pick as a
    red line. Blocks in ``plt.show()`` until the window is closed.

    Args:
        db_path: Tracking DB whose traces are plotted.
        machine_name, session_date, session_time: Copied verbatim into the
            returned row and the figure title.
        auto_t: Auto-detected opening time, or ``None`` if there is none.
        initial_t: Where the red cursor starts.

    Returns:
        * a row dict with the OUT_COLS keys (ENTER, or ``u`` which stores
          ``opening_s = NaN`` and ``notes = "unusable"``);
        * the string ``"skip"`` (``n``) or ``"quit"`` (``q`` / ESC);
        * ``None`` when no ROI has data, or the window was closed without
          any of the keys above being pressed.
    """
    traces = load_distance_traces(db_path)
    if all(trace.empty for trace in traces.values()):
        print(f" ! no usable ROI traces in {db_path.name}; skipping")
        return None

    fig, axes = plt.subplots(6, 1, figsize=(13, 12), sharex=True)
    fig.suptitle(
        f"{machine_name} {session_date} {session_time}\n"
        f"click ↦ set opening · ENTER save · "
        f"[/] ±1s · {{/}} ±5s · n skip · u unusable · r reset · q quit",
        fontsize=10,
    )

    cursor = float(initial_t)   # current pick, in seconds
    outcome = None              # what the function finally returns
    nudges = {"[": -1, "]": 1, "{": -5, "}": 5}

    def clamp(t):
        """Keep a time inside the plotted search window."""
        return max(0.0, min(SEARCH_END_S, t))

    def make_row(opening_s, notes):
        """Build an output row for this video."""
        return {
            "machine_name": machine_name,
            "session_date": session_date,
            "session_time": session_time,
            "opening_s": opening_s,
            "trim_first_s": 0,
            "notes": notes,
        }

    def finish(value):
        """Record the decision and close the window, ending plt.show()."""
        nonlocal outcome
        outcome = value
        plt.close(fig)

    def redraw():
        """Repaint every ROI panel with the current cursor position."""
        for ax, (roi, smean) in zip(axes, sorted(traces.items())):
            ax.cla()
            if not smean.empty:
                ax.plot(smean["mid_t"], smean["mean_dist"], color="steelblue", lw=1.0)
            else:
                ax.text(0.5, 0.5, f"ROI {roi}: no data",
                        transform=ax.transAxes, ha="center", va="center", color="grey")
            ax.set_ylabel(f"ROI {roi}")
            if auto_t is not None:
                ax.axvline(auto_t, color="orange", ls=":", lw=0.8, alpha=0.8)
            ax.axvline(cursor, color="red", lw=1.5)
            ax.set_xlim(0, SEARCH_END_S)
            ax.grid(True, alpha=0.3)
        axes[-1].set_xlabel("time (s)")
        axes[0].set_title(f"orange dotted = auto-suggested · red = current pick: {cursor:.1f} s",
                          fontsize=9)
        fig.canvas.draw_idle()

    def on_click(event):
        """Move the cursor to the clicked time (clicks outside panels are ignored)."""
        nonlocal cursor
        if event.inaxes not in axes or event.xdata is None:
            return
        cursor = clamp(float(event.xdata))
        redraw()

    def on_key(event):
        """Dispatch the keyboard shortcuts listed in the figure title."""
        nonlocal cursor
        key = event.key
        if key == "enter":
            finish(make_row(round(cursor, 1), ""))
        elif key == "n":
            finish("skip")
        elif key == "u":
            finish(make_row(np.nan, "unusable"))
        elif key == "r" and auto_t is not None:
            cursor = auto_t
            redraw()
        elif key in nudges:
            cursor = clamp(cursor + nudges[key])
            redraw()
        elif key in ("q", "escape"):
            finish("quit")

    fig.canvas.mpl_connect("button_press_event", on_click)
    fig.canvas.mpl_connect("key_press_event", on_key)
    redraw()
    plt.show()
    return outcome
def main() -> None:
    """Command-line entry point: queue un-annotated DBs and pick each one.

    Steps:
      1. Load the existing annotations from OUT_CSV (if any).
      2. Queue every existing tracking DB referenced by the
         ``training_db_path`` / ``testing_db_path`` columns of
         VIDEO_INFO_TSV, one entry per (machine, date, time); already
         annotated entries are dropped unless ``--redo`` is given.
      3. Optionally narrow the queue with ``--db`` and ``--limit``.
      4. Show the picker for each entry and persist OUT_CSV after every
         accepted pick.

    Exits via ``sys.exit`` with a message when ``--db`` matches nothing in
    the queue.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--redo", action="store_true",
                        help="re-pick videos that already have a row in the output CSV")
    parser.add_argument("--limit", type=int, default=None,
                        help="only process the first N videos")
    parser.add_argument("--db", type=Path, default=None,
                        help="annotate this specific DB only")
    args = parser.parse_args()

    OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    if OUT_CSV.exists():
        out = pd.read_csv(OUT_CSV)
    else:
        out = pd.DataFrame(columns=OUT_COLS)
    done = set(zip(out["machine_name"], out["session_date"], out["session_time"]))

    # Build the queue: every tracked DB referenced by the merged TSV that
    # hasn't been picked yet. `seen` dedups on the session key as we go
    # (a fly may reference the same DB for both training & testing); the
    # first occurrence wins, so queue order is unchanged.
    tsv = pd.read_csv(VIDEO_INFO_TSV, sep="\t")
    queue = []
    seen = set()
    for col in ("training_db_path", "testing_db_path"):
        for _, row in tsv.iterrows():
            db = row[col]
            # Empty cells come back from pandas as NaN (a float), not str.
            if not isinstance(db, str) or not db:
                continue
            db_path = Path(db)
            if not db_path.exists():
                continue
            parsed = parse_db_filename(db_path)
            if parsed is None:
                continue
            session_date, session_time, _ = parsed
            key = (row["machine_name"], session_date, session_time)
            if key in done and not args.redo:
                continue
            if key in seen:
                continue
            seen.add(key)
            queue.append((db_path, row["machine_name"], session_date, session_time))

    if args.db:
        target = Path(args.db).resolve()
        queue = [q for q in queue if Path(q[0]).resolve() == target]
        if not queue:
            # The queue only holds DBs that are listed in the TSV, exist on
            # disk, parse correctly and (without --redo) are un-annotated,
            # so spell those causes out instead of a bare "not found".
            sys.exit(
                f"DB not found in queue: {args.db} "
                "(not referenced by the TSV, missing on disk, or already "
                "annotated; pass --redo to re-pick)"
            )
    if args.limit:
        queue = queue[: args.limit]
    if not queue:
        print("Nothing to pick. All eligible DBs already have a barrier_opening row.")
        return

    print(f"Picking barrier-opening for {len(queue)} videos.")
    print("Window keys: click=set ENTER=save [/]=±1s {/}=±5s n=skip u=unusable r=reset q=quit")

    # Sibling temp file: same directory, so the final rename stays on one
    # filesystem and can replace OUT_CSV in a single step.
    tmp_csv = OUT_CSV.with_suffix(".csv.tmp")

    saved = skipped = unusable = 0
    for i, (db, machine_name, session_date, session_time) in enumerate(queue, 1):
        prefix = f"[{i}/{len(queue)}] {machine_name} {session_date} {session_time}"
        print(f"\n{prefix}")
        traces = load_distance_traces(db)
        auto_t = auto_suggest(traces)
        # Without an auto-suggestion the cursor starts at 60 s.
        initial = auto_t if auto_t is not None else 60.0
        print(f" auto-suggest: "
              f"{f'{auto_t:.1f}s' if auto_t is not None else '(none)'}")
        result = show_picker(db, machine_name, session_date, session_time,
                             auto_t, initial)
        # None = no usable traces or window closed without a decision.
        if result is None or result == "skip":
            skipped += 1
            continue
        if result == "quit":
            print(" quit requested — saving what we have and exiting")
            break

        # Append + dedup on key + persist after each save (crash-safe).
        new_row = pd.DataFrame([result])
        same_session = (
            (out.machine_name == result["machine_name"]) &
            (out.session_date == result["session_date"]) &
            (out.session_time == result["session_time"])
        )
        out = pd.concat([out[~same_session], new_row], ignore_index=True)
        # Write to a temp file and swap it in: an interrupted write can no
        # longer truncate the annotations that were already saved.
        out[OUT_COLS].to_csv(tmp_csv, index=False)
        tmp_csv.replace(OUT_CSV)

        if pd.isna(result["opening_s"]):
            unusable += 1
        else:
            saved += 1
            print(f" saved opening_s = {result['opening_s']} s")

    print(f"\nDone: {saved} saved, {unusable} unusable, {skipped} skipped.")
    print(f"{OUT_CSV}")
if __name__ == "__main__":
main()