cupido/scripts/explore_barrier_signal.py
Giorgio Gilestro 847d2cbd1b Merge 2025-07-15 batch into the xlsx; tools to detect & re-track
- merge_2025_07_15_into_xlsx.py: pivot the legacy 2025_07_15_metadata_fixed.csv
  into the unified xlsx schema (one row per fly, training_date_time +
  testing_date_time). Backs up the xlsx before writing. 24 new rows
  across machines 076 / 139 / 145 / 268.
- pick_targets.py: --video flag to bypass the inventory's in_xlsx filter,
  so a specific mp4 can be picked outside the normal flow.
- explore_barrier_signal.py: visualises raw y(t), per-frame inter-fly
  distance, and sliding min/mean distance against a known
  barrier-opening time. Used for prototyping the detector.
- detect_barrier_opening.py: per-ROI sliding-window mean-distance
  change-point estimator (median across ROIs). Currently noisy on a
  one-video calibration set; will be re-tuned once the 4 missing
  2025-07-15 videos are re-tracked.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-01 10:28:25 +01:00

143 lines
5.2 KiB
Python

"""Look at the tracking signal around the known barrier-opening time.
Loads one tracking DB whose opening time we know (from
2025_07_15_barrier_opening.csv) and plots a few candidate signals against
time, with a vertical line at the ground-truth opening:
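1. Y position of each detection (raw scatter)
2. Per-frame inter-fly distance (frames with exactly two detections)
3. Sliding-window min / mean inter-fly distance
The sliding-window Y range (max - min over a window) and the mean
|y - roi_midline| are also computed per window but are not plotted.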
The hope is one of these has a clean step-change at t = opening_time
that's robustly detectable across ROIs.
Run:
python explore_barrier_signal.py
Outputs:
figures/barrier_signal_<machine>_<time>.png
"""
from __future__ import annotations

import sqlite3
from contextlib import closing
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from config import FIGURES, TRACKING_OUTPUT_DIR
# Ground-truth case: machine 076, session 16-03-10 → opening = 52 s.
# File name of the tracking DB; main() looks it up under TRACKING_OUTPUT_DIR.
DB_NAME = "2025-07-15_16-03-10_076e2825a7274661bd0697c42d6fa4c0__1920x1088@25fps-28q_merged_tracking.db"
# Known barrier-opening time in seconds; drawn as the dashed reference line
# on every panel.
KNOWN_OPENING_S = 52.0
WINDOW_S = 10.0  # sliding-window length (seconds) for the derived signals
def load_roi(db_path: Path, roi: int) -> pd.DataFrame:
    """Read one ROI table from a tracking DB and add a time column in seconds.

    Args:
        db_path: Path to the SQLite tracking database. It is opened
            read-only via a ``file:...?mode=ro`` URI.
        roi: ROI number; rows are read from the table named ``ROI_<roi>``.

    Returns:
        DataFrame with the columns ``t, x, y, w, h, id`` as stored in the
        table, plus ``t_s`` computed as ``t / 1000`` (``t`` is presumably in
        milliseconds -- confirm against the tracker that wrote the DB).
    """
    # Table names cannot be bound as SQL parameters, so force the ROI number
    # to an integer before interpolating it into the statement.
    query = f"SELECT t, x, y, w, h, id FROM ROI_{int(roi)}"
    # sqlite3's own context manager only commits / rolls back the transaction;
    # closing() is what actually releases the connection when we are done.
    with closing(sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)) as conn:
        df = pd.read_sql_query(query, conn)
    df["t_s"] = df["t"] / 1000.0
    return df
def per_frame_distance(df: pd.DataFrame) -> pd.DataFrame:
    """Compute the inter-fly distance for each frame with exactly two detections.

    A frame is the set of rows sharing one ``t`` value. Frames with any other
    number of detections are dropped. Within a kept frame the rows are ordered
    by ``id``; the first supplies ``(x1, y1)`` and the last ``(x2, y2)``.

    Args:
        df: Detections with at least the columns ``t``, ``id``, ``x``, ``y``
            and ``t_s``.

    Returns:
        One row per kept frame, ordered by ``t``, with the columns
        ``x1, y1, x2, y2, t_s, dist_px`` and a fresh 0..n-1 index.
        ``dist_px`` is the Euclidean distance between the two detections.
    """
    detections_per_frame = df["t"].value_counts()
    paired_frames = detections_per_frame.index[detections_per_frame == 2]
    ordered = df.loc[df["t"].isin(paired_frames)].sort_values(by=["t", "id"])

    # Output column -> (source column, aggregation) for each frame.
    endpoints = {
        "x1": ("x", "first"),
        "y1": ("y", "first"),
        "x2": ("x", "last"),
        "y2": ("y", "last"),
        "t_s": ("t_s", "first"),
    }
    pairs = ordered.groupby("t").agg(**endpoints)

    delta_x = pairs["x1"] - pairs["x2"]
    delta_y = pairs["y1"] - pairs["y2"]
    pairs["dist_px"] = np.hypot(delta_x, delta_y)
    return pairs.reset_index(drop=True)
def sliding_signals(df: pd.DataFrame, dist: pd.DataFrame,
                    window_s: float, step_s: float = 1.0) -> pd.DataFrame:
    """Summarise detections and inter-fly distances over sliding time windows.

    Windows are half-open, ``[start, start + window_s)``. Starts run from the
    earliest ``t_s`` in ``df`` in increments of ``step_s`` and stop before
    ``max(t_s) - window_s``. Windows containing no detections are skipped.

    Args:
        df: Detections with at least the columns ``t_s`` and ``y``.
        dist: Per-frame distances with the columns ``t_s`` and ``dist_px``.
        window_s: Window length, in the same units as ``t_s``.
        step_s: Spacing between consecutive window starts.

    Returns:
        DataFrame with one row per non-empty window and the columns
        ``mid_t`` (window centre), ``y_range`` (max - min of y),
        ``y_mid_dist`` (mean absolute offset of y from the median y of the
        whole of ``df``), ``min_dist`` and ``mean_dist`` (NaN when the window
        holds no distance samples). The columns are always present, even when
        there are zero rows, so callers can index them unconditionally.
    """
    columns = ["mid_t", "y_range", "y_mid_dist", "min_dist", "mean_dist"]
    if df.empty:
        return pd.DataFrame(columns=columns, dtype=float)

    # The median y over the whole table stands in for the ROI midline.
    midline = df["y"].median()
    times = df["t_s"]
    t0, t1 = times.min(), times.max()

    rows = []
    for start in np.arange(t0, t1 - window_s, step_s):
        stop = start + window_s
        y_win = df.loc[(times >= start) & (times < stop), "y"]
        if y_win.empty:
            continue
        d_win = dist.loc[(dist["t_s"] >= start) & (dist["t_s"] < stop), "dist_px"]
        has_dist = not d_win.empty
        rows.append({
            "mid_t": start + window_s / 2,
            "y_range": y_win.max() - y_win.min(),
            "y_mid_dist": (y_win - midline).abs().mean(),
            "min_dist": d_win.min() if has_dist else np.nan,
            "mean_dist": d_win.mean() if has_dist else np.nan,
        })

    if not rows:
        # Recording shorter than one window (or no populated window): keep the
        # schema so downstream column lookups do not raise KeyError.
        return pd.DataFrame(columns=columns, dtype=float)
    return pd.DataFrame(rows, columns=columns)
def main() -> None:
    """Render the exploration figure for the ground-truth recording.

    Draws one row of three panels for each of ROIs 1-6: raw y(t), per-frame
    inter-fly distance, and sliding-window min/mean inter-fly distance. Every
    panel carries a dashed red line at ``KNOWN_OPENING_S``. The figure is
    saved as a PNG inside ``FIGURES`` and the output path is printed.

    Raises:
        FileNotFoundError: if ``DB_NAME`` is not found in
            ``TRACKING_OUTPUT_DIR``.
    """
    db_path = TRACKING_OUTPUT_DIR / DB_NAME
    if not db_path.exists():
        raise FileNotFoundError(db_path)

    fig, axes = plt.subplots(6, 3, figsize=(16, 22), sharex=True)
    # Zoom: only plot first 200 s — opening is < 90s in all known cases.
    x_limits = (0, 200)
    opening_style = {"color": "red", "lw": 1, "ls": "--"}

    for roi, (ax_raw, ax_dist, ax_min) in enumerate(axes, start=1):
        detections = load_roi(db_path, roi)
        distances = per_frame_distance(detections)
        summary = sliding_signals(detections, distances, WINDOW_S)

        # Panel 1: raw y-positions, zoomed on the early window.
        ax_raw.scatter(detections["t_s"], detections["y"],
                       s=0.5, alpha=0.4, c="steelblue")
        ax_raw.axvline(KNOWN_OPENING_S,
                       label=f"opening = {KNOWN_OPENING_S}s", **opening_style)
        ax_raw.set_ylabel(f"ROI {roi}\ny (px)")
        ax_raw.set_xlim(*x_limits)

        # Panel 2: raw inter-fly distance, one sample per two-fly frame.
        ax_dist.plot(distances["t_s"], distances["dist_px"],
                     lw=0.4, alpha=0.6, color="steelblue")
        ax_dist.axvline(KNOWN_OPENING_S, **opening_style)
        ax_dist.set_ylabel("dist (px)")
        ax_dist.set_xlim(*x_limits)

        # Panel 3: sliding-window min and mean inter-fly distance.
        ax_min.plot(summary["mid_t"], summary["min_dist"],
                    color="darkgreen", label="min")
        ax_min.plot(summary["mid_t"], summary["mean_dist"],
                    color="purple", label="mean", lw=0.8)
        ax_min.axvline(KNOWN_OPENING_S, **opening_style)
        ax_min.set_ylabel("dist (px)")
        ax_min.set_xlim(*x_limits)

        # Column titles and legends only go on the top row.
        if roi == 1:
            ax_raw.set_title("Raw y(t)")
            ax_raw.legend(loc="upper right", fontsize=8)
            ax_dist.set_title("Per-frame inter-fly distance")
            ax_min.set_title(f"min/mean inter-fly distance over {WINDOW_S}s window")
            ax_min.legend(loc="upper right", fontsize=8)

    for bottom_ax in axes[-1]:
        bottom_ax.set_xlabel("time (s)")

    title = (
        f"Barrier-opening signal exploration\n"
        f"machine 076, session 16-03-10 · ground truth: {KNOWN_OPENING_S}s"
    )
    fig.suptitle(title, fontsize=14)
    fig.tight_layout()

    FIGURES.mkdir(parents=True, exist_ok=True)
    out = FIGURES / "barrier_signal_076_16-03-10.png"
    fig.savefig(out, dpi=120, bbox_inches="tight")
    print(f"saved {out}")


if __name__ == "__main__":
    main()