"""Load ROI tracking data from all sessions into one DataFrame.

Drives off the merged TSV (one row per ROI/fly across training + testing
phases). For each TSV row, opens the corresponding tracking DB and pulls the
matching ROI table, then attaches the experimental metadata.

The TSV is the single source of truth for what data exists and how it maps
to flies and conditions.
"""

import re
import sqlite3
import sys
from pathlib import Path

import pandas as pd

from config import BARRIER_OPENING_CSV, VIDEO_INFO_TSV

# DB filenames start with `YYYY-MM-DD_HH-MM-SS__...` — pull the
# session date/time out so we can join against barrier_opening.csv.
_DB_TIMESTAMP_RE = re.compile(r"(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_")


def _session_key(db_path: str) -> tuple[str, str] | None:
    """Extract (session_date, session_time) from a tracking DB filename.

    Returns None when ``db_path`` is empty or not a string (e.g. a NaN cell
    from the TSV), or when the filename contains no
    ``YYYY-MM-DD_HH-MM-SS_`` timestamp.
    """
    if not isinstance(db_path, str) or not db_path:
        return None
    m = _DB_TIMESTAMP_RE.search(Path(db_path).name)
    return (m.group(1), m.group(2)) if m else None


def _parse_bad_rois(cell: object) -> set[int]:
    """Parse a ``bad_rois`` cell such as ``"3"`` or ``"1,4"`` into ROI ints.

    Non-string cells (NaN for an empty cell) and blank strings yield an
    empty set.
    """
    if not isinstance(cell, str) or not cell.strip():
        return set()
    return {int(x) for x in cell.split(",") if x.strip()}


def _load_barrier_lookup(csv_path: Path) -> dict[tuple[str, str, str], dict]:
    """Build (machine, session_date, session_time) → opening/bad_rois lookup.

    Each value is a dict with:
      - ``opening_s``: annotated barrier opening time, NaN if not annotated;
      - ``trim_first_s``: float, 0.0 when the cell is empty (not consumed by
        ``load_roi_data``);
      - ``bad_rois``: set of ROI numbers whose barrier never released;
      - ``unusable``: True when ``opening_s`` is missing, i.e. the whole
        session should be dropped.

    Returns an empty dict if the CSV is missing — callers should treat that
    as "no per-session annotations available" rather than an error.
    """
    if not Path(csv_path).exists():
        return {}
    # Reason: read bad_rois as text. If no session lists more than one bad
    # ROI, pandas would otherwise infer a numeric column (3.0 / NaN), and the
    # string parse in _parse_bad_rois would silently drop every entry.
    df = pd.read_csv(csv_path, dtype={"bad_rois": str})
    lookup: dict[tuple[str, str, str], dict] = {}
    for r in df.itertuples(index=False):
        lookup[(r.machine_name, r.session_date, r.session_time)] = {
            "opening_s": (
                float(r.opening_s) if pd.notna(r.opening_s) else float("nan")
            ),
            "trim_first_s": (
                float(r.trim_first_s) if pd.notna(r.trim_first_s) else 0.0
            ),
            "bad_rois": _parse_bad_rois(r.bad_rois),
            "unusable": pd.isna(r.opening_s),
        }
    return lookup


# Reason: inside a Jupyter kernel, prefer the explicit widget tqdm (it updates
# reliably in JupyterLab, where text \r-style bars sometimes don't refresh
# in-place). Outside a kernel (e.g. running this module as a script) the
# widget bar cannot render — tqdm.notebook raises ImportError at bar creation
# when ipywidgets is missing — so use tqdm.auto there. Fall back to a no-op
# bar when tqdm itself is unavailable.
try:
    if "ipykernel" in sys.modules:
        from tqdm.notebook import tqdm
    else:
        from tqdm.auto import tqdm
except ImportError:
    try:
        from tqdm.auto import tqdm
    except ImportError:
        def tqdm(*_args, **_kwargs):  # type: ignore[no-redef]
            class _NoOpBar:
                def update(self, _n=1):
                    pass

                def set_postfix_str(self, _s):
                    pass

                def close(self):
                    pass

            return _NoOpBar()


# Metadata columns to copy onto every tracking sample. These are the xlsx
# fields that describe the experimental condition behind each fly/ROI.
# Reason: the ROI is deliberately not listed here — load_roi_data stamps it
# separately as uppercase "ROI" for backwards compatibility with the existing
# analysis pipeline (calculate_distances.py, notebooks).
_META_COLS = (
    "date",
    "machine_name",
    "species",
    "male",
    "training_date_time",
    "testing_date_time",
    "training_length_hr",
    "consolidation_length_hr",
    "memory",
    "age",
)


def _open_ro(db_path: str, cache: dict) -> sqlite3.Connection | None:
    """Return a cached read-only sqlite connection for ``db_path``.

    Returns None for empty/non-string paths, or when the DB cannot be opened.
    Failures are cached as None, so each bad path is attempted (and reported)
    only once per ``cache``.
    """
    if not isinstance(db_path, str) or not db_path:
        return None
    if db_path not in cache:
        try:
            cache[db_path] = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
        except sqlite3.Error as e:
            print(f"failed to open {Path(db_path).name}: {e}")
            cache[db_path] = None
    return cache[db_path]


def _barrier_status(
    barrier_lookup: dict,
    machine_name,
    db_path,
    roi: int,
) -> tuple[str | None, float]:
    """Classify one session load against the barrier annotations.

    Returns ``(skip_reason, opening_s)``:
      - ``skip_reason`` is ``"unusable"`` when the whole session is flagged
        unusable, ``"bad_roi"`` when this fly's barrier never opened, and
        None when the session should be loaded;
      - ``opening_s`` is the annotated opening time, or NaN when the session
        is not annotated (or is skipped).
    An empty ``barrier_lookup`` means "no filtering": always ``(None, NaN)``.
    """
    nan = float("nan")
    if not barrier_lookup:
        return None, nan
    key = _session_key(db_path)
    if key is None:
        return None, nan
    bo = barrier_lookup.get((machine_name, key[0], key[1]))
    if bo is None:
        return None, nan
    if bo["unusable"]:
        return "unusable", nan
    if roi in bo["bad_rois"]:
        return "bad_roi", nan
    return None, bo["opening_s"]


def _read_session(
    conn: sqlite3.Connection,
    row,
    session: str,
    roi: int,
    opening_s: float,
) -> pd.DataFrame | None:
    """Read table ``ROI_<roi>`` from one session DB and stamp it.

    ``row`` is one record from ``meta.itertuples()``. Adds ``session``,
    ``ROI``, ``opening_s`` and every ``_META_COLS`` value from ``row`` to
    each sample. Returns None (after printing the error) if the table cannot
    be read.
    """
    try:
        df = pd.read_sql_query(f"SELECT * FROM ROI_{roi}", conn)
    except Exception as e:
        # Reason: a DB may be missing a ROI table if tracking was
        # partial — skip rather than abort the whole batch.
        print(f" ROI_{row.roi} from {session} DB: {e}")
        return None
    df["session"] = session
    df["ROI"] = roi
    df["opening_s"] = opening_s
    for col in _META_COLS:
        df[col] = getattr(row, col)
    return df


def load_roi_data(
    meta: pd.DataFrame | None = None,
    progress: bool = True,
    apply_barrier_filter: bool = True,
) -> pd.DataFrame:
    """Load ROI tracking data joined with experimental metadata.

    For each row in ``meta``, reads the matching ROI table from both the
    training DB and the testing DB (whichever exist), and stamps every sample
    with the row's metadata plus a ``session`` column (``"training"`` or
    ``"testing"``). Rows with empty DB paths (unusable videos, or videos that
    didn't pass the completeness gate) are skipped. If ``meta`` has an
    ``include`` column, rows where it is falsy are dropped first.

    Both training and testing reads are filtered against
    ``barrier_opening.csv`` (the picker annotates both video types): flies
    whose ROI never released (listed in ``bad_rois``) and entire sessions
    flagged unusable are dropped. The session's ``opening_s`` is stamped onto
    its samples so downstream code can compute
    ``t_from_opening = t - opening_s``. Sessions missing from the CSV are
    still loaded, but with ``opening_s = NaN``.

    Args:
        meta: optional DataFrame with the same schema as
            ``all_video_info_merged.tsv``. Pass a filtered slice to load a
            subset (e.g. ``meta[meta.species == 'Melanogaster/CS']``).
            Defaults to the full TSV.
        progress: show a tqdm progress bar (one tick per session load, i.e.
            two per fly/ROI row) and an up-front timing notice. Defaults to
            True. Set False for silent batch jobs.
        apply_barrier_filter: if True (default), drop session data for flies
            whose barrier never opened (and sessions flagged unusable) and
            stamp the annotated ``opening_s`` onto every sample. If False, no
            barrier-derived filtering is done; the ``opening_s`` column is
            still present but NaN on every sample.

    Returns:
        DataFrame with the tracking-table columns (``id, t, x, y, w, h, phi,
        is_inferred, has_interacted``) plus ``session``, ``ROI``,
        ``opening_s`` and the metadata columns ``date, machine_name, species,
        male, training_date_time, testing_date_time, training_length_hr,
        consolidation_length_hr, memory, age`` — one row per tracking sample.
        ``opening_s`` is NaN for sessions not covered by
        ``barrier_opening.csv``. Empty if nothing could be loaded.
    """
    if meta is None:
        meta = pd.read_csv(VIDEO_INFO_TSV, sep="\t")
    # Honor the per-row `include` flag if the TSV has one. Rows with
    # include=False are dropped (typically too-noisy videos the analyst
    # has marked out). Missing column → load everything (back-compat).
    if "include" in meta.columns:
        meta = meta[meta["include"].astype(bool)]

    barrier_lookup = (
        _load_barrier_lookup(BARRIER_OPENING_CSV) if apply_barrier_filter else {}
    )

    db_cache: dict = {}
    chunks: list[pd.DataFrame] = []
    n_skipped_bad_roi = 0
    n_skipped_unusable = 0
    n_rows = len(meta)

    if progress:
        # Reason: this is a slow operation (one SQL query per session per
        # ROI; the full batch is ~minutes). Print up front so the user
        # knows whether to grab a coffee.
        print(
            f"Loading ROI data for {n_rows} flies × 2 sessions "
            f"({2 * n_rows} DB queries). This typically takes 2–5 minutes.",
            flush=True,
        )

    # Reason: tick per-session (2 per fly) instead of per-fly so the bar
    # advances roughly every second, not every ~2s. set_postfix_str shows
    # what's being processed — gives the user something visibly changing
    # even when total ticks are slow.
    pbar = tqdm(
        total=2 * n_rows,
        desc="loading flies",
        unit="session",
        disable=not progress,
        mininterval=0.5,
    )
    # Reason: try/finally so the progress bar and every cached sqlite
    # connection are closed even if the load is interrupted (e.g. Ctrl-C
    # during a multi-minute run) or a malformed TSV row raises.
    try:
        for row in meta.itertuples(index=False):
            roi = int(row.roi)
            for session in ("training", "testing"):
                pbar.set_postfix_str(f"{row.machine_name} ROI {roi} {session}")
                db_path = getattr(row, f"{session}_db_path")

                # The picker annotates barrier_opening per video, and both
                # the training and testing videos have their own entries.
                # Apply the same per-session filter to both.
                skip, opening_s = _barrier_status(
                    barrier_lookup, row.machine_name, db_path, roi
                )
                if skip == "unusable":
                    n_skipped_unusable += 1
                elif skip == "bad_roi":
                    n_skipped_bad_roi += 1
                else:
                    conn = _open_ro(db_path, db_cache)
                    if conn is not None:
                        df = _read_session(conn, row, session, roi, opening_s)
                        if df is not None:
                            chunks.append(df)
                pbar.update(1)
    finally:
        pbar.close()
        for conn in db_cache.values():
            if conn is not None:
                conn.close()

    if apply_barrier_filter and (n_skipped_bad_roi or n_skipped_unusable):
        print(
            f"Barrier filter: dropped {n_skipped_bad_roi} ROI loads (barrier "
            f"never opened) and {n_skipped_unusable} unusable sessions.",
            flush=True,
        )

    return pd.concat(chunks, ignore_index=True) if chunks else pd.DataFrame()


if __name__ == "__main__":
    data = load_roi_data()
    print(f"shape: {data.shape}")
    if not data.empty:
        print(f"columns: {list(data.columns)}")
        print(f"sessions: {data['session'].value_counts().to_dict()}")
        print(f"unique machines: {data['machine_name'].nunique()}")
        # Reason: load_roi_data stamps the fly's ROI as uppercase "ROI" and
        # does not copy a lowercase "roi" column (it is not in _META_COLS),
        # so grouping on 'roi' raised KeyError.
        print(
            f"unique flies (date,machine,roi): "
            f"{data.groupby(['date', 'machine_name', 'ROI']).ngroups}"
        )