cupido/scripts/load_roi_data.py
Giorgio Gilestro 8abb3d5955 Add tqdm progress bar to load_roi_data
Loading the full batch issues 968 SQL queries and takes minutes — show
a tqdm progress bar (one tick per fly/ROI row) and print an upfront
"this takes 1-3 minutes" notice so the user knows to wait. Uses
tqdm.auto so it picks the Jupyter widget when run from a notebook and
plain text on the CLI. New `progress=True` parameter on load_roi_data,
flip to False for silent batch use. tqdm + ipywidgets added to
requirements.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-01 09:34:42 +01:00

150 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Load ROI tracking data from all sessions into one DataFrame.
Drives off the merged TSV (one row per ROI/fly across training + testing
phases). For each TSV row, opens the corresponding tracking DB and pulls
the matching ROI table, then attaches the experimental metadata.
The TSV is the single source of truth for what data exists and how it
maps to flies and conditions.
"""
import sqlite3
from pathlib import Path
import pandas as pd
from config import VIDEO_INFO_TSV
# Reason: tqdm.auto selects its own backend (Jupyter widget in a notebook,
# plain text on the CLI). When tqdm is not installed, a pass-through
# stand-in with the same call shape keeps the loader usable in minimal
# environments.
try:
    from tqdm.auto import tqdm
except ImportError:

    def tqdm(iterable, **_ignored):  # type: ignore[no-redef]
        """Stand-in for ``tqdm``: hand back ``iterable`` untouched.

        Every keyword option (``total``, ``desc``, ``unit``, ``disable``,
        ...) is accepted and discarded, so call sites need no special case.
        """
        return iterable
# Metadata columns to copy onto every tracking sample. These are the xlsx
# fields that describe the experimental condition behind each fly/ROI.
# load_roi_data reads each of these names off the metadata row and writes
# it as a column of the same name on the loaded samples.
# Note: the ROI number is not listed here; load_roi_data adds it separately
# as its own column.
# Reason: the ROI column is uppercase ("ROI") for backwards compatibility
# with the existing analysis pipeline (calculate_distances.py, notebooks).
_META_COLS = (
    "date",
    "machine_name",
    "species",
    "male",
    "training_date_time",
    "testing_date_time",
    "training_length_hr",
    "consolidation_length_hr",
    "memory",
    "age",
)
def _open_ro(db_path: str, cache: dict) -> sqlite3.Connection | None:
"""Cached read-only sqlite connection. Returns None on failure."""
if not isinstance(db_path, str) or not db_path:
return None
if db_path not in cache:
try:
cache[db_path] = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
except sqlite3.Error as e:
print(f"failed to open {Path(db_path).name}: {e}")
cache[db_path] = None
return cache[db_path]
def load_roi_data(
    meta: pd.DataFrame | None = None,
    progress: bool = True,
) -> pd.DataFrame:
    """Load ROI tracking data joined with experimental metadata.

    For each row in ``meta``, reads the matching ROI table from both the
    training DB and the testing DB (whichever exist), and stamps every
    sample with the row's metadata plus a ``session`` column
    (``"training"`` or ``"testing"``). Rows with empty DB paths (unusable
    videos, or videos that didn't pass the completeness gate) are skipped.

    Args:
        meta: optional DataFrame with the same schema as
            ``all_video_info_merged.tsv``. Pass a filtered slice to load a
            subset (e.g. ``meta[meta.species == 'Melanogaster/CS']``).
            Defaults to the full TSV.
        progress: show a tqdm progress bar (one tick per fly/ROI row) and
            print an upfront notice. Defaults to True. Set False for
            silent batch jobs.

    Returns:
        DataFrame with columns ``id, t, x, y, w, h, phi, is_inferred,
        has_interacted, session, ROI, <metadata>`` — one row per tracking
        sample. Empty if nothing could be loaded.
    """
    if meta is None:
        meta = pd.read_csv(VIDEO_INFO_TSV, sep="\t")

    # Honor the per-row `include` flag if the TSV has one. Rows with
    # include=False are dropped (typically too-noisy videos the analyst
    # has marked out). Missing column → load everything (back-compat).
    if "include" in meta.columns:
        meta = meta[meta["include"].astype(bool)]

    db_cache: dict = {}
    chunks: list[pd.DataFrame] = []
    n_rows = len(meta)

    if progress:
        # Reason: this is a slow operation (one SQL query per session per
        # ROI; the full batch is ~minutes). Print up front so the user
        # knows whether to grab a coffee.
        print(
            f"Loading ROI data for {n_rows} flies × 2 sessions "
            f"({2 * n_rows} DB queries). This typically takes 1-3 minutes."
        )

    iterator = tqdm(
        meta.itertuples(index=False),
        total=n_rows,
        desc="loading flies",
        unit="fly",
        disable=not progress,
    )

    # Reason: try/finally guarantees the cached connections are closed even
    # if something unexpected raises mid-batch (e.g. a metadata column
    # missing from a caller-supplied ``meta``).
    try:
        for row in iterator:
            for session in ("training", "testing"):
                conn = _open_ro(getattr(row, f"{session}_db_path"), db_cache)
                if conn is None:
                    continue
                try:
                    # int() keeps the table name a plain integer suffix
                    # even if the TSV column was parsed as float.
                    roi = int(row.roi)
                    df = pd.read_sql_query(f"SELECT * FROM ROI_{roi}", conn)
                except Exception as e:
                    # Reason: a DB may be missing a ROI table if tracking was
                    # partial — skip rather than abort the whole batch.
                    print(f" ROI_{row.roi} from {session} DB: {e}")
                    continue
                df["session"] = session
                df["ROI"] = roi
                for col in _META_COLS:
                    df[col] = getattr(row, col)
                chunks.append(df)
    finally:
        for conn in db_cache.values():
            if conn is not None:
                conn.close()

    return pd.concat(chunks, ignore_index=True) if chunks else pd.DataFrame()
if __name__ == "__main__":
    # Smoke test: load the full batch and print a short summary.
    data = load_roi_data()
    print(f"shape: {data.shape}")
    if not data.empty:
        print(f"columns: {list(data.columns)}")
        print(f"sessions: {data['session'].value_counts().to_dict()}")
        print(f"unique machines: {data['machine_name'].nunique()}")
        # Reason: load_roi_data writes the ROI number to an uppercase "ROI"
        # column; grouping on a lowercase 'roi' key raises KeyError.
        print(
            f"unique flies (date,machine,roi): "
            f"{data.groupby(['date', 'machine_name', 'ROI']).ngroups}"
        )