"""Detect the barrier-opening time from tracking data. Idea: before the barrier is removed, the two flies in a ROI are stuck on opposite sides of a divider. Their inter-fly distance is bounded below by ~the barrier width (typically 100–250 px). After removal they can walk up to each other and the minimum distance drops near zero. We detect the first time the sliding-window MIN drops below a threshold and call that the opening moment. Per-ROI estimates are aggregated (median) across the 6 ROIs of one video for a single video-level opening time. Disagreeing ROIs are flagged so the analyst can double-check by eye. This module exposes ``detect_opening_time(db_path)`` for callers, and runs as a CLI to produce a TSV with one row per DB. Use:: python detect_barrier_opening.py --db # single python detect_barrier_opening.py # all DBs in TRACKING_OUTPUT_DIR """ from __future__ import annotations import argparse import sqlite3 from dataclasses import dataclass from pathlib import Path import numpy as np import pandas as pd from config import TRACKING_OUTPUT_DIR # Tunables (calibrated on machine 076 / 16-03-10, ground truth 52s). # We use windowed MEAN distance (not min) because the min is too easily # tripped by isolated tracking artifacts in the first few seconds. The # mean drops cleanly when the barrier opens because the flies start # spending real time near each other instead of being held apart. WINDOW_S = 30.0 # sliding-window length for the distance signal STEP_S = 1.0 # step between window centres SEARCH_END_S = 300.0 # opening always happens in the first 5 minutes @dataclass class RoiEstimate: roi: int opening_s: float | None n_pairs: int # how many 2-fly frames we had pre_min: float # median min-dist in pre-opening window (sanity) post_min: float # median min-dist in post-opening window (sanity) def per_frame_distance(df: pd.DataFrame) -> pd.DataFrame: """Frames with exactly 2 detections → (t_s, dist_px). Empty if none.""" if df.empty: return df.assign(dist_px=np.nan).iloc[:0] n = df.groupby("t").size() two = n[n == 2].index sub = df[df["t"].isin(two)].sort_values(["t", "id"]) if sub.empty: return pd.DataFrame(columns=["t_s", "dist_px"]) pairs = sub.groupby("t").agg( x1=("x", "first"), y1=("y", "first"), x2=("x", "last"), y2=("y", "last"), t_s=("t", "first"), ) pairs["t_s"] = pairs["t_s"] / 1000.0 pairs["dist_px"] = np.hypot(pairs["x1"] - pairs["x2"], pairs["y1"] - pairs["y2"]) return pairs[["t_s", "dist_px"]].reset_index(drop=True) def sliding_mean(dist: pd.DataFrame, window_s: float, step_s: float, t_max: float) -> pd.DataFrame: """Return (mid_t, mean_dist) over sliding windows up to t_max.""" if dist.empty: return pd.DataFrame(columns=["mid_t", "mean_dist"]) rows = [] for start in np.arange(0, t_max - window_s, step_s): sub = dist[(dist["t_s"] >= start) & (dist["t_s"] < start + window_s)] if sub.empty: continue rows.append({"mid_t": start + window_s / 2, "mean_dist": sub["dist_px"].mean()}) return pd.DataFrame(rows) def detect_one_roi(df_roi: pd.DataFrame) -> RoiEstimate: """Per-ROI detection. Strategy: compute sliding-window mean distance, find the time of the largest *drop* (windowed mean before vs after each candidate t). The opening corresponds to the candidate that maximises (pre - post). """ roi_id = int(df_roi["ROI"].iloc[0]) if "ROI" in df_roi.columns and not df_roi.empty else -1 dist = per_frame_distance(df_roi) n_pairs = len(dist) if n_pairs < 100: return RoiEstimate(roi_id, None, n_pairs, np.nan, np.nan) smean = sliding_mean(dist, WINDOW_S, STEP_S, SEARCH_END_S) if len(smean) < 4: return RoiEstimate(roi_id, None, n_pairs, np.nan, np.nan) # Reason: scan candidate split points; for each, compute the median # of the sliding mean BEFORE vs AFTER. The opening is the candidate # that maximises (pre_median - post_median). Median (not mean) makes # this robust to tracking artifacts at either end. Skip the very # ends of the window so we have enough samples on each side. pad = max(1, int(WINDOW_S / STEP_S)) # don't split too close to edges if len(smean) < 2 * pad + 1: return RoiEstimate(roi_id, None, n_pairs, np.nan, np.nan) best_drop = -np.inf best_t = None best_pre = best_post = np.nan for i in range(pad, len(smean) - pad): pre = smean["mean_dist"].iloc[:i].median() post = smean["mean_dist"].iloc[i:].median() drop = pre - post if drop > best_drop: best_drop = drop best_t = float(smean["mid_t"].iloc[i]) best_pre, best_post = float(pre), float(post) # Reason: require a substantive drop — at least 30 px, and post must # be below ~70% of pre. Otherwise the signal is too flat (probably # the barrier was already open when recording started, or the # session is unusable). if best_drop < 30 or best_post > 0.7 * best_pre: return RoiEstimate(roi_id, None, n_pairs, best_pre, best_post) # Adjust: best_t was the centre of the post-window starting at index i; # shift back by half a window so we report the actual transition moment. opening_s = max(0.0, best_t - WINDOW_S / 2) return RoiEstimate(roi_id, opening_s, n_pairs, best_pre, best_post) def detect_opening_time(db_path: Path) -> dict: """Estimate barrier-opening time for one tracking DB. Returns dict with: - opening_s : float | None (median across ROIs that produced an estimate) - per_roi : list[RoiEstimate] - spread_s : max - min of per-ROI estimates (smaller = more agreement) """ estimates: list[RoiEstimate] = [] with sqlite3.connect(f"file:{db_path}?mode=ro", uri=True) as conn: for roi in range(1, 7): try: df = pd.read_sql_query( f"SELECT t, x, y, id FROM ROI_{roi}", conn ) except Exception: estimates.append(RoiEstimate(roi, None, 0, np.nan, np.nan)) continue df["ROI"] = roi estimates.append(detect_one_roi(df)) valid = [e.opening_s for e in estimates if e.opening_s is not None] if not valid: return {"opening_s": None, "per_roi": estimates, "spread_s": None} return { "opening_s": float(np.median(valid)), "per_roi": estimates, "spread_s": float(np.max(valid) - np.min(valid)), } def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--db", type=Path, help="single tracking DB to analyze") args = parser.parse_args() dbs = [args.db] if args.db else sorted(TRACKING_OUTPUT_DIR.glob("*_tracking.db")) print(f"analyzing {len(dbs)} DB(s)\n") for db in dbs: result = detect_opening_time(db) median_s = result["opening_s"] spread = result["spread_s"] print(f"{db.name}") print( f" median opening: " f"{f'{median_s:.1f}s' if median_s is not None else 'no estimate'}" f" spread: {f'{spread:.1f}s' if spread is not None else 'n/a'}" ) for e in result["per_roi"]: print( f" ROI {e.roi}: " f"{'-- ' if e.opening_s is None else f'{e.opening_s:5.1f}s'}" f" pairs={e.n_pairs:>6d} pre={e.pre_min:5.1f} post={e.post_min:5.1f}" ) print() if __name__ == "__main__": main()