diff --git a/scripts/config.py b/scripts/config.py index 18e89ef..9b72a29 100644 --- a/scripts/config.py +++ b/scripts/config.py @@ -26,6 +26,11 @@ VIDEO_INFO_TSV = DATA_VOLUME / "all_video_info_merged.tsv" # A small CSV listing every video file we know about (built locally). INVENTORY_CSV = DATA_METADATA / "video_inventory.csv" +# Hand-annotated barrier-opening times (output of the picker app). One +# row per testing session; columns: machine_name, session_date, +# session_time, opening_s, trim_first_s, bad_rois, analyst, notes. +BARRIER_OPENING_CSV = DATA_METADATA / "barrier_opening.csv" + # Where the ethoscope source tree is checked out (used by track_videos.py # and auto_detect_targets.py — host-side scripts that import ethoscope # from a local clone rather than from pip). Default assumes the standard diff --git a/scripts/load_roi_data.py b/scripts/load_roi_data.py index ee2263c..309709a 100644 --- a/scripts/load_roi_data.py +++ b/scripts/load_roi_data.py @@ -8,12 +8,48 @@ The TSV is the single source of truth for what data exists and how it maps to flies and conditions. """ +import re import sqlite3 from pathlib import Path import pandas as pd -from config import VIDEO_INFO_TSV +from config import BARRIER_OPENING_CSV, VIDEO_INFO_TSV + +# DB filenames start with `YYYY-MM-DD_HH-MM-SS__...` — pull the +# session date/time out so we can join against barrier_opening.csv. +_DB_TIMESTAMP_RE = re.compile(r"(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_") + + +def _session_key(db_path: str) -> tuple[str, str] | None: + """Extract (session_date, session_time) from a tracking DB filename.""" + if not isinstance(db_path, str) or not db_path: + return None + m = _DB_TIMESTAMP_RE.search(Path(db_path).name) + return (m.group(1), m.group(2)) if m else None + + +def _load_barrier_lookup(csv_path: Path) -> dict[tuple[str, str, str], dict]: + """Build (machine, session_date, session_time) → opening/bad_rois lookup. + + Returns an empty dict if the CSV is missing — callers should treat + that as "no per-session annotations available" rather than an error. + """ + if not Path(csv_path).exists(): + return {} + df = pd.read_csv(csv_path) + lookup: dict[tuple[str, str, str], dict] = {} + for r in df.itertuples(index=False): + bad = set() + if isinstance(r.bad_rois, str) and r.bad_rois.strip(): + bad = {int(x) for x in r.bad_rois.split(",") if x.strip()} + lookup[(r.machine_name, r.session_date, r.session_time)] = { + "opening_s": float(r.opening_s) if pd.notna(r.opening_s) else float("nan"), + "trim_first_s": float(r.trim_first_s) if pd.notna(r.trim_first_s) else 0.0, + "bad_rois": bad, + "unusable": pd.isna(r.opening_s), + } + return lookup # Reason: prefer the explicit Jupyter-widget tqdm when available (it # updates reliably in JupyterLab, where text \r-style bars sometimes @@ -66,6 +102,7 @@ def _open_ro(db_path: str, cache: dict) -> sqlite3.Connection | None: def load_roi_data( meta: pd.DataFrame | None = None, progress: bool = True, + apply_barrier_filter: bool = True, ) -> pd.DataFrame: """Load ROI tracking data joined with experimental metadata. @@ -75,6 +112,14 @@ def load_roi_data( (``"training"`` or ``"testing"``). Rows with empty DB paths (unusable videos, or videos that didn't pass the completeness gate) are skipped. + Both training and testing reads are filtered against + ``barrier_opening.csv`` (the picker annotates both video types): + flies whose ROI never released (listed in ``bad_rois``) and entire + sessions flagged unusable are dropped. The session's ``opening_s`` + is stamped onto its samples so downstream code can compute + ``t_from_opening = t - opening_s``. Sessions missing from the CSV + are still loaded, but with ``opening_s = NaN``. + Args: meta: optional DataFrame with the same schema as ``all_video_info_merged.tsv``. Pass a filtered slice to load a @@ -82,11 +127,17 @@ def load_roi_data( Defaults to the full TSV. progress: show a tqdm progress bar (one tick per fly/ROI row). Defaults to True. Set False for silent batch jobs. + apply_barrier_filter: if True (default), drop session data for + flies whose barrier never opened and stamp ``opening_s`` + onto every sample. Set False to load raw data without any + barrier-derived filtering or columns. Returns: DataFrame with columns ``id, t, x, y, w, h, phi, is_inferred, - has_interacted, session, `` — one row per tracking - sample. Empty if nothing could be loaded. + has_interacted, session, ROI, opening_s, `` — one row + per tracking sample. ``opening_s`` is NaN for sessions not + covered by ``barrier_opening.csv``. Empty if nothing could be + loaded. """ if meta is None: meta = pd.read_csv(VIDEO_INFO_TSV, sep="\t") @@ -97,8 +148,12 @@ def load_roi_data( if "include" in meta.columns: meta = meta[meta["include"].astype(bool)] + barrier_lookup = _load_barrier_lookup(BARRIER_OPENING_CSV) if apply_barrier_filter else {} + db_cache: dict = {} chunks: list[pd.DataFrame] = [] + n_skipped_bad_roi = 0 + n_skipped_unusable = 0 n_rows = len(meta) if progress: @@ -125,7 +180,28 @@ def load_roi_data( for row in meta.itertuples(index=False): for session in ("training", "testing"): pbar.set_postfix_str(f"{row.machine_name} ROI {int(row.roi)} {session}") - conn = _open_ro(getattr(row, f"{session}_db_path"), db_cache) + db_path = getattr(row, f"{session}_db_path") + + # The picker annotates barrier_opening per video, and both + # the training and testing videos have their own entries. + # Apply the same per-session filter to both. + opening_s = float("nan") + if barrier_lookup: + key = _session_key(db_path) + if key is not None: + bo = barrier_lookup.get((row.machine_name, key[0], key[1])) + if bo is not None: + if bo["unusable"]: + n_skipped_unusable += 1 + pbar.update(1) + continue + if int(row.roi) in bo["bad_rois"]: + n_skipped_bad_roi += 1 + pbar.update(1) + continue + opening_s = bo["opening_s"] + + conn = _open_ro(db_path, db_cache) if conn is None: pbar.update(1) continue @@ -141,6 +217,7 @@ def load_roi_data( continue df["session"] = session df["ROI"] = int(row.roi) + df["opening_s"] = opening_s for col in _META_COLS: df[col] = getattr(row, col) chunks.append(df) @@ -148,6 +225,13 @@ def load_roi_data( pbar.close() + if apply_barrier_filter and (n_skipped_bad_roi or n_skipped_unusable): + print( + f"Barrier filter: dropped {n_skipped_bad_roi} ROI loads (barrier " + f"never opened) and {n_skipped_unusable} unusable sessions.", + flush=True, + ) + for conn in db_cache.values(): if conn is not None: conn.close()