- merge_2025_07_15_into_xlsx.py: pivot the legacy 2025_07_15_metadata_fixed.csv into the unified xlsx schema (one row per fly, training_date_time + testing_date_time). Backs up the xlsx before writing. 24 new rows across machines 076 / 139 / 145 / 268. - pick_targets.py: --video flag to bypass the inventory's in_xlsx filter, so a specific mp4 can be picked outside the normal flow. - explore_barrier_signal.py: visualises raw y(t), per-frame inter-fly distance, and sliding min/mean distance against a known barrier-opening time. Used for prototyping the detector. - detect_barrier_opening.py: per-ROI sliding-window mean-distance change-point estimator (median across ROIs). Currently noisy on a one-video calibration set; will be re-tuned once the 4 missing 2025-07-15 videos are re-tracked. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
195 lines
7.6 KiB
Python
195 lines
7.6 KiB
Python
"""Detect the barrier-opening time from tracking data.
|
||
|
||
Idea: before the barrier is removed, the two flies in a ROI are stuck on
|
||
opposite sides of a divider. Their inter-fly distance is bounded below
|
||
by ~the barrier width (typically 100–250 px). After removal they can
|
||
walk up to each other and the distance between them drops. We compute a
sliding-window MEAN of the inter-fly distance and take the split point
with the largest drop (median before minus median after) as the opening
moment.
|
||
|
||
Per-ROI estimates are aggregated (median) across the 6 ROIs of one
|
||
video for a single video-level opening time. Disagreeing ROIs are
|
||
flagged so the analyst can double-check by eye.
|
||
|
||
This module exposes ``detect_opening_time(db_path)`` for callers, and
runs as a CLI that prints a short report per DB (video-level median,
spread across ROIs, and the per-ROI estimates). Use::
|
||
|
||
python detect_barrier_opening.py --db <one.db> # single
|
||
python detect_barrier_opening.py # all DBs in TRACKING_OUTPUT_DIR
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import sqlite3
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
from config import TRACKING_OUTPUT_DIR
|
||
|
||
# Tunables (calibrated on machine 076 / 16-03-10, ground truth 52s).
# We use windowed MEAN distance (not min) because the min is too easily
# tripped by isolated tracking artifacts in the first few seconds. The
# mean drops cleanly when the barrier opens because the flies start
# spending real time near each other instead of being held apart.
# All three values are in seconds. Window starts are generated up to (but
# not including) SEARCH_END_S - WINDOW_S, so no window extends past
# SEARCH_END_S.
WINDOW_S = 30.0       # sliding-window length for the distance signal
STEP_S = 1.0          # step between window centres
SEARCH_END_S = 300.0  # opening always happens in the first 5 minutes
|
||
|
||
|
||
@dataclass
|
||
class RoiEstimate:
|
||
roi: int
|
||
opening_s: float | None
|
||
n_pairs: int # how many 2-fly frames we had
|
||
pre_min: float # median min-dist in pre-opening window (sanity)
|
||
post_min: float # median min-dist in post-opening window (sanity)
|
||
|
||
|
||
def per_frame_distance(df: pd.DataFrame) -> pd.DataFrame:
    """Return the inter-fly distance for every frame with exactly two rows.

    Args:
        df: Tracking rows with columns ``t``, ``x``, ``y`` and ``id``.
            ``t`` is divided by 1000 to obtain ``t_s``, i.e. it is treated
            as milliseconds.

    Returns:
        A frame with float columns ``t_s`` and ``dist_px``, one row per
        usable frame, ordered by time. The result always has exactly these
        two columns, also when it is empty. Frames in which either
        detection lacks a coordinate are dropped rather than reported with
        a made-up distance.
    """
    columns = ["t_s", "dist_px"]
    if df.empty:
        # Same schema as the non-empty result, so callers never have to
        # special-case which columns they get back.
        return pd.DataFrame({name: pd.Series(dtype=float) for name in columns})

    rows_per_frame = df.groupby("t").size()
    paired_frames = rows_per_frame[rows_per_frame == 2].index
    sub = df[df["t"].isin(paired_frames)].sort_values(["t", "id"])

    # Every remaining frame has exactly two rows and the rows are sorted by
    # (t, id), so even positions hold one fly and odd positions the other.
    # Positional pairing (instead of GroupBy first/last) matters because
    # first/last skip NaN: with one missing coordinate they would silently
    # reuse the other fly's value and yield a bogus distance.
    first = sub.iloc[0::2]
    second = sub.iloc[1::2]
    dx = first["x"].to_numpy(dtype=float) - second["x"].to_numpy(dtype=float)
    dy = first["y"].to_numpy(dtype=float) - second["y"].to_numpy(dtype=float)

    pairs = pd.DataFrame({
        "t_s": first["t"].to_numpy(dtype=float) / 1000.0,
        "dist_px": np.hypot(dx, dy),
    })
    return pairs.dropna(subset=["dist_px"]).reset_index(drop=True)
|
||
|
||
|
||
def sliding_mean(dist: pd.DataFrame, window_s: float, step_s: float,
                 t_max: float) -> pd.DataFrame:
    """Return (mid_t, mean_dist) over sliding windows up to t_max.

    Args:
        dist: Frame with columns ``t_s`` and ``dist_px``; rows need not be
            sorted.
        window_s: Window length. Each window covers the half-open interval
            ``[start, start + window_s)``.
        step_s: Spacing between consecutive window starts.
        t_max: Window starts run from 0 up to, but not including,
            ``t_max - window_s``.

    Returns:
        A frame with columns ``mid_t`` (window centre) and ``mean_dist``
        (mean of ``dist_px`` inside the window). Windows containing no rows
        are omitted. The two columns are present even when no window
        produced a row.
    """
    columns = ["mid_t", "mean_dist"]
    if dist.empty:
        return pd.DataFrame(columns=columns)

    # Sort once and locate each window with a binary search instead of
    # masking the whole table per window. Rows without a timestamp can
    # never fall inside a window, so they are dropped up front.
    ordered = dist.dropna(subset=["t_s"]).sort_values("t_s")
    times = ordered["t_s"].to_numpy(dtype=float)
    values = ordered["dist_px"].reset_index(drop=True)

    starts = np.arange(0, t_max - window_s, step_s)
    lo = np.searchsorted(times, starts, side="left")
    hi = np.searchsorted(times, starts + window_s, side="left")

    rows = [
        {"mid_t": start + window_s / 2,
         "mean_dist": values.iloc[first:stop].mean()}
        for start, first, stop in zip(starts, lo, hi)
        if stop > first
    ]
    # Passing ``columns`` keeps the schema stable when ``rows`` is empty.
    return pd.DataFrame(rows, columns=columns)
|
||
|
||
|
||
def detect_one_roi(df_roi: pd.DataFrame) -> RoiEstimate:
    """Estimate the barrier-opening time for one ROI.

    The inter-fly distance is smoothed into a sliding-window mean. Every
    admissible split of that series is scored by the median before the
    split minus the median from the split onwards, and the split with the
    largest score is taken as the opening.

    Args:
        df_roi: Tracking rows of one ROI. An optional ``ROI`` column
            supplies the ROI number; without it (or with an empty frame)
            the estimate is labelled -1.

    Returns:
        A ``RoiEstimate``. ``opening_s`` is None when there are fewer than
        100 paired frames, too few windows to split, or the best drop is
        not pronounced enough.
    """
    if "ROI" in df_roi.columns and not df_roi.empty:
        roi_id = int(df_roi["ROI"].iloc[0])
    else:
        roi_id = -1

    distances = per_frame_distance(df_roi)
    n_pairs = len(distances)

    def no_estimate(pre: float = np.nan, post: float = np.nan) -> RoiEstimate:
        # Shared "could not decide" result for all early exits.
        return RoiEstimate(roi_id, None, n_pairs, pre, post)

    if n_pairs < 100:
        return no_estimate()

    windowed = sliding_mean(distances, WINDOW_S, STEP_S, SEARCH_END_S)

    # Keep splits at least one full window (in steps) away from both ends
    # so each side has enough samples for a meaningful median.
    pad = max(1, int(WINDOW_S / STEP_S))
    if len(windowed) < 4 or len(windowed) < 2 * pad + 1:
        return no_estimate()

    means = windowed["mean_dist"]
    centres = windowed["mid_t"]

    # Medians (not means) keep the score robust to tracking artifacts at
    # either end of the series.
    best_drop = -np.inf
    best_t = None
    best_pre = best_post = np.nan
    for split in range(pad, len(windowed) - pad):
        before = means.iloc[:split].median()
        after = means.iloc[split:].median()
        drop = before - after
        if drop > best_drop:
            best_drop = drop
            best_t = float(centres.iloc[split])
            best_pre = float(before)
            best_post = float(after)

    # Demand a substantive drop: at least 30 px, with the "after" level at
    # no more than ~70% of the "before" level. A flatter signal probably
    # means the barrier was already open when recording started, or the
    # session is unusable.
    if best_drop < 30 or best_post > 0.7 * best_pre:
        return no_estimate(best_pre, best_post)

    # best_t is the centre of the first "after" window; step back half a
    # window to report where that window starts, i.e. the transition.
    opening_s = max(0.0, best_t - WINDOW_S / 2)
    return RoiEstimate(roi_id, opening_s, n_pairs, best_pre, best_post)
|
||
|
||
|
||
def detect_opening_time(db_path: Path) -> dict:
    """Estimate barrier-opening time for one tracking DB.

    The database is opened read-only and tables ``ROI_1`` .. ``ROI_6`` are
    read in turn. A table that cannot be read contributes an empty
    estimate instead of aborting the whole video.

    Args:
        db_path: Path to the SQLite tracking database.

    Returns:
        dict with:
        - opening_s : float | None  (median across ROIs that produced an estimate)
        - per_roi   : list[RoiEstimate]
        - spread_s  : max - min of per-ROI estimates (smaller = more agreement),
                      None when no ROI produced an estimate

    Raises:
        sqlite3.Error: if the database itself cannot be opened.
    """
    # Build the URI from the resolved path so characters that are special
    # in URIs ('?', '#', '%', spaces) are percent-encoded rather than
    # misread as part of the query string.
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"

    estimates: list[RoiEstimate] = []
    # sqlite3's connection context manager only commits / rolls back; it
    # does not close the connection, so close it explicitly.
    conn = sqlite3.connect(uri, uri=True)
    try:
        for roi in range(1, 7):
            try:
                df = pd.read_sql_query(
                    f"SELECT t, x, y, id FROM ROI_{roi}", conn
                )
            except (sqlite3.Error, pd.errors.DatabaseError):
                # Missing or unreadable ROI table: record "no data" and
                # carry on with the remaining ROIs.
                estimates.append(RoiEstimate(roi, None, 0, np.nan, np.nan))
                continue
            df["ROI"] = roi
            estimates.append(detect_one_roi(df))
    finally:
        conn.close()

    valid = [e.opening_s for e in estimates if e.opening_s is not None]
    if not valid:
        return {"opening_s": None, "per_roi": estimates, "spread_s": None}
    return {
        "opening_s": float(np.median(valid)),
        "per_roi": estimates,
        "spread_s": float(np.max(valid) - np.min(valid)),
    }
|
||
|
||
|
||
def main() -> None:
    """Command-line entry point: print a barrier-opening report per DB.

    With ``--db`` a single database is analysed; otherwise every
    ``*_tracking.db`` file found in ``TRACKING_OUTPUT_DIR`` is processed in
    sorted order. For each database the video-level median, the spread
    across ROIs and one line per ROI are written to stdout.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--db", type=Path, help="single tracking DB to analyze")
    args = parser.parse_args()

    if args.db:
        targets = [args.db]
    else:
        targets = sorted(TRACKING_OUTPUT_DIR.glob("*_tracking.db"))
    print(f"analyzing {len(targets)} DB(s)\n")

    for db_file in targets:
        outcome = detect_opening_time(db_file)
        opening = outcome["opening_s"]
        spread = outcome["spread_s"]

        # Render the optional numbers first so the print calls stay flat.
        opening_txt = "no estimate" if opening is None else f"{opening:.1f}s"
        spread_txt = "n/a" if spread is None else f"{spread:.1f}s"

        print(db_file.name)
        print(f" median opening: {opening_txt} spread: {spread_txt}")

        for est in outcome["per_roi"]:
            if est.opening_s is None:
                roi_txt = "-- "
            else:
                roi_txt = f"{est.opening_s:5.1f}s"
            stats_txt = (
                f" pairs={est.n_pairs:>6d}"
                f" pre={est.pre_min:5.1f}"
                f" post={est.post_min:5.1f}"
            )
            print(f" ROI {est.roi}: {roi_txt}{stats_txt}")
        print()
|
||
|
||
|
||
# Run the CLI only when this file is executed as a script, so importing the
# module has no side effects.
if __name__ == "__main__":
    main()
|