cupido/scripts/detect_barrier_opening.py
Giorgio Gilestro 847d2cbd1b Merge 2025-07-15 batch into the xlsx; tools to detect & re-track
- merge_2025_07_15_into_xlsx.py: pivot the legacy 2025_07_15_metadata_fixed.csv
  into the unified xlsx schema (one row per fly, training_date_time +
  testing_date_time). Backs up the xlsx before writing. 24 new rows
  across machines 076 / 139 / 145 / 268.
- pick_targets.py: --video flag to bypass the inventory's in_xlsx filter,
  so a specific mp4 can be picked outside the normal flow.
- explore_barrier_signal.py: visualises raw y(t), per-frame inter-fly
  distance, and sliding min/mean distance against a known
  barrier-opening time. Used for prototyping the detector.
- detect_barrier_opening.py: per-ROI sliding-window mean-distance
  change-point estimator (median across ROIs). Currently noisy on a
  one-video calibration set; will be re-tuned once the 4 missing
  2025-07-15 videos are re-tracked.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-01 10:28:25 +01:00

195 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Detect the barrier-opening time from tracking data.

Idea: before the barrier is removed, the two flies in a ROI are held on
opposite sides of a divider, so their inter-fly distance stays large
(bounded below by roughly the barrier width, typically 100 to 250 px).
After removal they can walk up to each other and spend time close
together, so the typical distance drops.

Per ROI, we compute the inter-fly distance on every frame with exactly
two detections and average it over sliding windows. We then pick the
split point that maximises the drop between the median windowed
distance before and after it. That split is reported as the opening
moment.

Per-ROI estimates are aggregated (median) across the 6 ROIs of one
video into a single video-level opening time. The spread (max - min)
of the per-ROI estimates is also reported, so that videos whose ROIs
disagree can be double-checked by eye.

This module exposes ``detect_opening_time(db_path)`` for callers. It also
runs as a CLI that prints a per-DB summary (video-level estimate plus
one line per ROI). Use::

    python detect_barrier_opening.py --db <one.db>   # single
    python detect_barrier_opening.py                 # all DBs in TRACKING_OUTPUT_DIR
"""
from __future__ import annotations

import argparse
import sqlite3
from contextlib import closing
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd

from config import TRACKING_OUTPUT_DIR
# Tunables, calibrated on machine 076 / 16-03-10 (ground truth: opening at 52 s).
#
# The detector works on the windowed MEAN inter-fly distance rather than the
# windowed minimum. A single tracking artifact in the first few seconds is
# enough to drag the minimum down. The mean only drops once the flies actually
# start spending time close to each other.
WINDOW_S: float = 30.0  # length of each sliding window, in seconds
STEP_S: float = 1.0  # spacing between consecutive windows, in seconds
SEARCH_END_S: float = 300.0  # the opening is assumed to fall in the first 5 minutes
@dataclass
class RoiEstimate:
roi: int
opening_s: float | None
n_pairs: int # how many 2-fly frames we had
pre_min: float # median min-dist in pre-opening window (sanity)
post_min: float # median min-dist in post-opening window (sanity)
def per_frame_distance(df: pd.DataFrame) -> pd.DataFrame:
    """Return the inter-fly distance for every frame with exactly 2 detections.

    Args:
        df: Tracking rows with columns ``t``, ``x``, ``y`` and ``id``. ``t``
            is divided by 1000 to obtain seconds (presumably milliseconds in
            the tracking DB; confirm against the tracker).

    Returns:
        DataFrame with exactly the columns ``t_s`` (seconds) and ``dist_px``
        (Euclidean distance between the two detections). It has one row per
        qualifying frame, in ascending time order. If the input is empty or
        has no two-detection frames, the result is an empty frame with the
        same two columns.
    """
    empty = pd.DataFrame({"t_s": pd.Series(dtype=float),
                          "dist_px": pd.Series(dtype=float)})
    if df.empty:
        return empty

    counts = df.groupby("t").size()
    two_fly_frames = counts[counts == 2].index
    sub = df[df["t"].isin(two_fly_frames)].sort_values(["t", "id"])
    if sub.empty:
        return empty

    # After sorting, each frame occupies two consecutive rows (lower id
    # first), so even rows are one fly and odd rows the other. Pairing by
    # position keeps each fly's x and y together. groupby().first()/last()
    # would skip NaNs column by column and could mix both flies' coordinates.
    a = sub.iloc[0::2]
    b = sub.iloc[1::2]
    dist_px = np.hypot(
        a["x"].to_numpy(dtype=float) - b["x"].to_numpy(dtype=float),
        a["y"].to_numpy(dtype=float) - b["y"].to_numpy(dtype=float),
    )
    return pd.DataFrame({
        "t_s": a["t"].to_numpy(dtype=float) / 1000.0,
        "dist_px": dist_px,
    })
def sliding_mean(dist: pd.DataFrame, window_s: float, step_s: float,
                 t_max: float) -> pd.DataFrame:
    """Return the mean inter-fly distance over sliding windows below ``t_max``.

    Windows are ``[start, start + window_s)`` for every ``start`` in
    ``np.arange(0, t_max - window_s, step_s)``. The arange end point is
    excluded. Windows that contain no samples are skipped.

    Args:
        dist: Frame distances with columns ``t_s`` (seconds) and ``dist_px``,
            as produced by ``per_frame_distance``.
        window_s: Window length in seconds.
        step_s: Spacing between consecutive window starts, in seconds.
        t_max: End of the analysed interval, in seconds.

    Returns:
        DataFrame with columns ``mid_t`` (window centre, s) and ``mean_dist``
        (NaN-skipping mean of ``dist_px`` inside the window). The columns are
        present even when no window contains any sample.
    """
    columns = ["mid_t", "mean_dist"]
    if dist.empty:
        return pd.DataFrame(columns=columns)

    # Every window lies inside [0, t_max). Drop the rest of the recording once
    # instead of rescanning it for each window. After sorting by time, each
    # window is a contiguous slice located by binary search.
    in_range = dist[dist["t_s"] < t_max].sort_values("t_s", kind="stable")
    times = in_range["t_s"].to_numpy()
    values = in_range["dist_px"]

    rows = []
    for start in np.arange(0, t_max - window_s, step_s):
        lo = np.searchsorted(times, start, side="left")
        hi = np.searchsorted(times, start + window_s, side="left")
        if hi == lo:
            continue
        rows.append({"mid_t": start + window_s / 2,
                     "mean_dist": values.iloc[lo:hi].mean()})
    return pd.DataFrame(rows, columns=columns)
def detect_one_roi(
    df_roi: pd.DataFrame,
    *,
    window_s: float | None = None,
    step_s: float | None = None,
    search_end_s: float | None = None,
    min_pairs: int = 100,
    min_drop_px: float = 30.0,
    max_post_ratio: float = 0.7,
) -> RoiEstimate:
    """Estimate the barrier-opening time for a single ROI.

    Strategy: compute the sliding-window mean inter-fly distance over the
    first ``search_end_s`` seconds. Then scan candidate split points and keep
    the one that maximises ``median(before) - median(after)`` of that signal.
    Medians (not means) stop a few artifact windows at either end from
    dominating.

    Args:
        df_roi: Tracking rows for one ROI (columns ``t``, ``x``, ``y``,
            ``id``, and optionally ``ROI``).
        window_s: Sliding-window length in seconds. None means the module's
            ``WINDOW_S``, read at call time.
        step_s: Window step in seconds. None means ``STEP_S``.
        search_end_s: End of the searched interval in seconds. None means
            ``SEARCH_END_S``.
        min_pairs: Minimum number of two-detection frames required.
        min_drop_px: Minimum accepted drop (px) between the before/after
            medians.
        max_post_ratio: The after-median must be at most this fraction of the
            before-median.

    Returns:
        RoiEstimate. ``opening_s`` is None when there are too few frames or
        windows, or when the best drop fails the thresholds above.
        ``pre_min``/``post_min`` hold the before/after medians of the best
        split (NaN if no split produced a valid drop).
    """
    # Resolve tunables at call time so module-level overrides still apply.
    window_s = WINDOW_S if window_s is None else window_s
    step_s = STEP_S if step_s is None else step_s
    search_end_s = SEARCH_END_S if search_end_s is None else search_end_s

    has_roi_col = "ROI" in df_roi.columns and not df_roi.empty
    roi_id = int(df_roi["ROI"].iloc[0]) if has_roi_col else -1

    dist = per_frame_distance(df_roi)
    n_pairs = len(dist)
    if n_pairs < min_pairs:
        return RoiEstimate(roi_id, None, n_pairs, np.nan, np.nan)

    smean = sliding_mean(dist, window_s, step_s, search_end_s)
    # Keep one window's worth of steps (``pad``) on each side of every
    # candidate split, so both medians are computed from enough windows.
    pad = max(1, int(window_s / step_s))
    if len(smean) < max(4, 2 * pad + 1):
        return RoiEstimate(roi_id, None, n_pairs, np.nan, np.nan)

    means = smean["mean_dist"]
    mids = smean["mid_t"]
    best_drop = -np.inf
    best_t = None
    best_pre = best_post = np.nan
    for i in range(pad, len(smean) - pad):
        pre = means.iloc[:i].median()
        post = means.iloc[i:].median()
        drop = pre - post
        # Strict '>' keeps the earliest split on ties; NaN drops never win.
        if drop > best_drop:
            best_drop = drop
            best_t = float(mids.iloc[i])
            best_pre, best_post = float(pre), float(post)

    # Require a substantive drop. Otherwise the signal is too flat: probably
    # the barrier was already open when recording started, or the session is
    # unusable.
    if best_drop < min_drop_px or best_post > max_post_ratio * best_pre:
        return RoiEstimate(roi_id, None, n_pairs, best_pre, best_post)

    # best_t is the centre of window i, the first window on the "after" side.
    # Subtracting half a window reports that window's start time.
    # NOTE(review): windows straddling the opening are part pre / part post,
    # so the chosen split may sit anywhere within about one window of the true
    # opening. Confirm this offset against ground-truth videos.
    opening_s = max(0.0, best_t - window_s / 2)
    return RoiEstimate(roi_id, opening_s, n_pairs, best_pre, best_post)
def detect_opening_time(db_path: Path) -> dict:
    """Estimate the barrier-opening time for one tracking DB.

    The DB is opened read-only. Tables ``ROI_1`` .. ``ROI_6`` (columns ``t, x,
    y, id``) are read, each ROI goes through ``detect_one_roi``, and the
    per-ROI results are aggregated.

    Args:
        db_path: Path to a tracking SQLite database.

    Returns:
        dict with:

        - ``opening_s``: float | None. Median across the ROIs that produced an
          estimate.
        - ``per_roi``: list[RoiEstimate]. One entry per ROI 1..6, in order. A
          missing or unreadable ROI table yields an entry with
          ``opening_s=None`` and ``n_pairs=0``.
        - ``spread_s``: float | None. Max - min of the per-ROI estimates
          (smaller = more agreement). None when no ROI produced an estimate.

    Raises:
        sqlite3.Error: if the database file itself cannot be opened.
    """
    estimates: list[RoiEstimate] = []
    # Path.as_uri() percent-encodes characters such as '?', '#' and '%' that
    # SQLite would otherwise parse as URI syntax.
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
    # sqlite3's own context manager only commits or rolls back; it does not
    # close the connection, so closing() is needed to avoid leaking one per DB.
    with closing(sqlite3.connect(uri, uri=True)) as conn:
        for roi in range(1, 7):
            try:
                # roi comes from range(), so interpolating the table name is safe.
                df = pd.read_sql_query(f"SELECT t, x, y, id FROM ROI_{roi}", conn)
            except (sqlite3.Error, pd.errors.DatabaseError):
                # Missing or unreadable ROI table: record "no estimate" and go on.
                estimates.append(RoiEstimate(roi, None, 0, np.nan, np.nan))
                continue
            df["ROI"] = roi
            estimates.append(detect_one_roi(df))

    valid = [e.opening_s for e in estimates if e.opening_s is not None]
    if not valid:
        return {"opening_s": None, "per_roi": estimates, "spread_s": None}
    return {
        "opening_s": float(np.median(valid)),
        "per_roi": estimates,
        "spread_s": float(np.max(valid) - np.min(valid)),
    }
def main() -> None:
    """Command-line entry point.

    With ``--db``, analyze that single tracking DB. Otherwise analyze every
    ``*_tracking.db`` in ``TRACKING_OUTPUT_DIR``. For each DB, print the
    median opening time, the spread across ROIs, and one line per ROI.

    A DB that cannot be opened is reported and skipped, and the remaining DBs
    are still processed. The process then exits with status 1.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--db", type=Path, help="single tracking DB to analyze")
    args = parser.parse_args()

    if args.db is not None:
        # Fail with a readable usage error instead of a sqlite traceback.
        if not args.db.is_file():
            parser.error(f"--db: no such file: {args.db}")
        dbs = [args.db]
    else:
        dbs = sorted(TRACKING_OUTPUT_DIR.glob("*_tracking.db"))
    print(f"analyzing {len(dbs)} DB(s)\n")

    n_failed = 0
    for db in dbs:
        try:
            result = detect_opening_time(db)
        except sqlite3.Error as err:
            # Keep going so that one unopenable DB does not abort a batch run.
            print(f"{db.name}\n  ERROR: {err}\n")
            n_failed += 1
            continue

        median_s = result["opening_s"]
        spread = result["spread_s"]
        median_txt = f"{median_s:.1f}s" if median_s is not None else "no estimate"
        spread_txt = f"{spread:.1f}s" if spread is not None else "n/a"
        print(f"{db.name}")
        print(f" median opening: {median_txt} spread: {spread_txt}")
        for e in result["per_roi"]:
            opening_txt = "-- " if e.opening_s is None else f"{e.opening_s:5.1f}s"
            print(
                f" ROI {e.roi}: {opening_txt}"
                f" pairs={e.n_pairs:>6d} pre={e.pre_min:5.1f} post={e.post_min:5.1f}"
            )
        print()

    if n_failed:
        raise SystemExit(1)


if __name__ == "__main__":
    main()