For every session (training and testing alike), the loader now looks
up the corresponding row in barrier_opening.csv and:
- drops the read if the ROI is in bad_rois (barrier never opened
for that fly so its tracking has no biological meaning)
- drops the read if the session is flagged unusable
- stamps the session's opening_s onto every sample so downstream
code can compute t_from_opening = t - opening_s
Tested against ETHOSCOPE_082 2024-09-17: training (bad_rois=1,3,5)
correctly drops ROIs 1/3/5; testing keeps all six; opening_s differs
between sessions as expected (646.8 vs 154.7).
Opt out with apply_barrier_filter=False if you need raw data.
252 lines
9.6 KiB
Python
252 lines
9.6 KiB
Python
"""Load ROI tracking data from all sessions into one DataFrame.
|
||
|
||
Drives off the merged TSV (one row per ROI/fly across training + testing
|
||
phases). For each TSV row, opens the corresponding tracking DB and pulls
|
||
the matching ROI table, then attaches the experimental metadata.
|
||
|
||
The TSV is the single source of truth for what data exists and how it
|
||
maps to flies and conditions.
|
||
"""
|
||
|
||
import re
|
||
import sqlite3
|
||
from pathlib import Path
|
||
|
||
import pandas as pd
|
||
|
||
from config import BARRIER_OPENING_CSV, VIDEO_INFO_TSV
|
||
|
||
# DB filenames start with `YYYY-MM-DD_HH-MM-SS_<uuid>_...` — pull the
|
||
# session date/time out so we can join against barrier_opening.csv.
|
||
_DB_TIMESTAMP_RE = re.compile(r"(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_")
|
||
|
||
|
||
def _session_key(db_path: str) -> tuple[str, str] | None:
|
||
"""Extract (session_date, session_time) from a tracking DB filename."""
|
||
if not isinstance(db_path, str) or not db_path:
|
||
return None
|
||
m = _DB_TIMESTAMP_RE.search(Path(db_path).name)
|
||
return (m.group(1), m.group(2)) if m else None
|
||
|
||
|
||
def _load_barrier_lookup(csv_path: Path) -> dict[tuple[str, str, str], dict]:
    """Build (machine, session_date, session_time) → opening/bad_rois lookup.

    Returns an empty dict if the CSV is missing — callers should treat
    that as "no per-session annotations available" rather than an error.

    Args:
        csv_path: location of the barrier-opening CSV. The columns read are
            ``machine_name``, ``session_date``, ``session_time``,
            ``opening_s``, ``trim_first_s`` and ``bad_rois``.

    Returns:
        One entry per CSV row, keyed by
        ``(machine_name, session_date, session_time)``. Each value holds:

        - ``opening_s``: float, NaN when the cell is empty
        - ``trim_first_s``: float, ``0.0`` when the cell is empty
        - ``bad_rois``: set of ints parsed from the comma-separated cell
        - ``unusable``: True when ``opening_s`` is empty
    """
    csv_path = Path(csv_path)
    if not csv_path.exists():
        return {}

    # Reason: without explicit dtypes pandas infers a numeric column when
    # every non-empty `bad_rois` cell holds a single ROI (e.g. "3"), and the
    # string check below would then silently discard those ROIs. The key
    # columns are pinned to str for the same reason: they are compared
    # against strings parsed out of DB filenames.
    table = pd.read_csv(
        csv_path,
        dtype={
            "machine_name": str,
            "session_date": str,
            "session_time": str,
            "bad_rois": str,
        },
    )

    lookup: dict[tuple[str, str, str], dict] = {}
    for r in table.itertuples(index=False):
        bad: set[int] = set()
        # Empty cells come through as NaN (a float), not as a string.
        if isinstance(r.bad_rois, str):
            bad = {int(tok) for tok in r.bad_rois.split(",") if tok.strip()}
        has_opening = bool(pd.notna(r.opening_s))
        lookup[(r.machine_name, r.session_date, r.session_time)] = {
            "opening_s": float(r.opening_s) if has_opening else float("nan"),
            "trim_first_s": float(r.trim_first_s) if pd.notna(r.trim_first_s) else 0.0,
            "bad_rois": bad,
            "unusable": not has_opening,
        }
    return lookup
|
||
|
||
# Progress-bar selection, in order of preference:
#   1. tqdm.notebook — the explicit Jupyter-widget bar (text \r-style bars
#      sometimes don't refresh in-place in JupyterLab),
#   2. tqdm.auto,
#   3. a local stand-in that accepts the same calls and does nothing, so the
#      loader still runs when tqdm is not installed.
try:
    from tqdm.notebook import tqdm
except ImportError:
    try:
        from tqdm.auto import tqdm
    except ImportError:

        def tqdm(*_args, **_kwargs):  # type: ignore[no-redef]
            """Return a do-nothing bar exposing the methods the loader calls."""

            class _NoOpBar:
                def update(self, _n=1):
                    return None

                def set_postfix_str(self, _s):
                    return None

                def close(self):
                    return None

            return _NoOpBar()
|
||
|
||
|
||
# Per-fly metadata fields that load_roi_data copies from each metadata row
# onto every tracking sample it loads for that row, in this column order.
# They describe the experimental condition behind each fly/ROI.
# Note: the ROI number is not listed here; the loader stamps it separately
# as an uppercase "ROI" column, kept for backwards compatibility with the
# existing analysis pipeline (calculate_distances.py, notebooks).
_META_COLS = (
    "date", "machine_name",
    "species", "male",
    "training_date_time", "testing_date_time",
    "training_length_hr", "consolidation_length_hr",
    "memory", "age",
)
|
||
|
||
|
||
def _open_ro(db_path: str, cache: dict) -> sqlite3.Connection | None:
|
||
"""Cached read-only sqlite connection. Returns None on failure."""
|
||
if not isinstance(db_path, str) or not db_path:
|
||
return None
|
||
if db_path not in cache:
|
||
try:
|
||
cache[db_path] = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
|
||
except sqlite3.Error as e:
|
||
print(f"failed to open {Path(db_path).name}: {e}")
|
||
cache[db_path] = None
|
||
return cache[db_path]
|
||
|
||
|
||
def load_roi_data(
    meta: pd.DataFrame | None = None,
    progress: bool = True,
    apply_barrier_filter: bool = True,
) -> pd.DataFrame:
    """Load ROI tracking data joined with experimental metadata.

    For each row in ``meta``, reads the matching ROI table from both the
    training DB and the testing DB (whichever exist), and stamps every
    sample with the row's metadata plus a ``session`` column
    (``"training"`` or ``"testing"``). Rows with empty DB paths (unusable
    videos, or videos that didn't pass the completeness gate) are skipped.

    Both training and testing reads are filtered against
    ``barrier_opening.csv`` (the picker annotates both video types):
    flies whose ROI never released (listed in ``bad_rois``) and entire
    sessions flagged unusable are dropped. The session's ``opening_s``
    is stamped onto its samples so downstream code can compute
    ``t_from_opening = t - opening_s``. Sessions missing from the CSV
    are still loaded, but with ``opening_s = NaN``.

    Args:
        meta: optional DataFrame with the same schema as
            ``all_video_info_merged.tsv``. Pass a filtered slice to load a
            subset (e.g. ``meta[meta.species == 'Melanogaster/CS']``).
            Defaults to the full TSV.
        progress: show a tqdm progress bar (one tick per fly/ROI session)
            and an up-front time estimate. Defaults to True. Set False for
            silent batch jobs.
        apply_barrier_filter: if True (default), drop session data for
            flies whose barrier never opened and stamp ``opening_s``
            onto every sample. Set False to load raw data without any
            barrier-derived filtering; the ``opening_s`` column is still
            present in that case but is NaN everywhere.

    Returns:
        DataFrame with columns ``id, t, x, y, w, h, phi, is_inferred,
        has_interacted, session, ROI, opening_s, <metadata>`` — one row
        per tracking sample. ``opening_s`` is NaN for sessions not
        covered by ``barrier_opening.csv``. Empty if nothing could be
        loaded.
    """
    if meta is None:
        meta = pd.read_csv(VIDEO_INFO_TSV, sep="\t")

    # Honor the per-row `include` flag if the TSV has one. Rows with
    # include=False are dropped (typically too-noisy videos the analyst
    # has marked out). Missing column → load everything (back-compat).
    if "include" in meta.columns:
        meta = meta[meta["include"].astype(bool)]

    barrier_lookup = (
        _load_barrier_lookup(BARRIER_OPENING_CSV) if apply_barrier_filter else {}
    )

    db_cache: dict = {}
    chunks: list[pd.DataFrame] = []
    n_skipped_bad_roi = 0
    # Reason: a set of session keys rather than a counter, so one unusable
    # session shared by several ROIs is reported once, not once per fly.
    unusable_sessions: set[tuple] = set()

    n_rows = len(meta)
    if progress:
        # Reason: this is a slow operation (one SQL query per session per
        # ROI; the full batch is ~minutes). Print up front so the user
        # knows whether to grab a coffee.
        print(
            f"Loading ROI data for {n_rows} flies × 2 sessions "
            f"({2 * n_rows} DB queries). This typically takes 2–5 minutes.",
            flush=True,
        )
    # Reason: tick per-session (2 per fly) instead of per-fly so the bar
    # advances roughly every second, not every ~2s. set_postfix_str shows
    # what's being processed — gives the user something visibly changing
    # even when total ticks are slow.
    pbar = tqdm(
        total=2 * n_rows,
        desc="loading flies",
        unit="session",
        disable=not progress,
        mininterval=0.5,
    )

    # Reason: the outer try/finally guarantees the bar and every cached
    # connection are released even if a row blows up mid-batch (e.g. a
    # metadata column is missing from `meta`).
    try:
        for row in meta.itertuples(index=False):
            roi = int(row.roi)
            for session in ("training", "testing"):
                # Reason: the inner finally ticks the bar exactly once per
                # session no matter which `continue` ends the iteration.
                try:
                    pbar.set_postfix_str(f"{row.machine_name} ROI {roi} {session}")
                    db_path = getattr(row, f"{session}_db_path")

                    # The picker annotates barrier_opening per video, and
                    # both the training and testing videos have their own
                    # entries. Apply the same per-session filter to both.
                    opening_s = float("nan")
                    session_key = _session_key(db_path) if barrier_lookup else None
                    if session_key is not None:
                        lookup_key = (row.machine_name, *session_key)
                        bo = barrier_lookup.get(lookup_key)
                        if bo is not None:
                            if bo["unusable"]:
                                unusable_sessions.add(lookup_key)
                                continue
                            if roi in bo["bad_rois"]:
                                n_skipped_bad_roi += 1
                                continue
                            opening_s = bo["opening_s"]

                    conn = _open_ro(db_path, db_cache)
                    if conn is None:
                        continue
                    try:
                        # The table name is built from an int, so the
                        # f-string cannot inject arbitrary SQL.
                        df = pd.read_sql_query(f"SELECT * FROM ROI_{roi}", conn)
                    except Exception as e:
                        # Reason: a DB may be missing a ROI table if tracking
                        # was partial — skip rather than abort the whole batch.
                        print(f" ROI_{row.roi} from {session} DB: {e}")
                        continue

                    df["session"] = session
                    df["ROI"] = roi
                    df["opening_s"] = opening_s
                    for col in _META_COLS:
                        df[col] = getattr(row, col)
                    chunks.append(df)
                finally:
                    pbar.update(1)
    finally:
        pbar.close()
        for conn in db_cache.values():
            if conn is not None:
                conn.close()

    if apply_barrier_filter and (n_skipped_bad_roi or unusable_sessions):
        print(
            f"Barrier filter: dropped {n_skipped_bad_roi} ROI loads (barrier "
            f"never opened) and {len(unusable_sessions)} unusable sessions.",
            flush=True,
        )

    return pd.concat(chunks, ignore_index=True) if chunks else pd.DataFrame()
|
||
|
||
|
||
if __name__ == "__main__":
    # Smoke test: load everything with default settings and print a summary.
    data = load_roi_data()
    print(f"shape: {data.shape}")
    if not data.empty:
        print(f"columns: {list(data.columns)}")
        print(f"sessions: {data['session'].value_counts().to_dict()}")
        print(f"unique machines: {data['machine_name'].nunique()}")
        # Reason: the loader writes the ROI number under the uppercase
        # "ROI" column; there is no lowercase "roi" column in its output,
        # so grouping by "roi" raised KeyError.
        print(
            f"unique flies (date,machine,roi): "
            f"{data.groupby(['date','machine_name','ROI']).ngroups}"
        )
|