diff --git a/requirements.txt b/requirements.txt index 3fa3d16..142146a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ seaborn>=0.12 scipy>=1.10 scikit-learn>=1.3 jupyter>=1.0 +tqdm>=4.66 +ipywidgets>=8.0 diff --git a/scripts/load_roi_data.py b/scripts/load_roi_data.py index 82381b3..9af8a09 100644 --- a/scripts/load_roi_data.py +++ b/scripts/load_roi_data.py @@ -15,6 +15,15 @@ import pandas as pd from config import VIDEO_INFO_TSV +# Reason: tqdm.auto picks the right backend automatically — Jupyter widget +# inside a notebook, plain text on the CLI. Fall back to a no-op wrapper +# if tqdm isn't installed so the loader still works in minimal environments. +try: + from tqdm.auto import tqdm +except ImportError: + def tqdm(iterable, **_kwargs): # type: ignore[no-redef] + return iterable + # Metadata columns to copy onto every tracking sample. These are the xlsx # fields that describe the experimental condition behind each fly/ROI. @@ -47,7 +56,10 @@ def _open_ro(db_path: str, cache: dict) -> sqlite3.Connection | None: return cache[db_path] -def load_roi_data(meta: pd.DataFrame | None = None) -> pd.DataFrame: +def load_roi_data( + meta: pd.DataFrame | None = None, + progress: bool = True, +) -> pd.DataFrame: """Load ROI tracking data joined with experimental metadata. For each row in ``meta``, reads the matching ROI table from both the @@ -61,6 +73,8 @@ def load_roi_data(meta: pd.DataFrame | None = None) -> pd.DataFrame: ``all_video_info_merged.tsv``. Pass a filtered slice to load a subset (e.g. ``meta[meta.species == 'Melanogaster/CS']``). Defaults to the full TSV. + progress: show a tqdm progress bar (one tick per fly/ROI row). + Defaults to True. Set False for silent batch jobs. Returns: DataFrame with columns ``id, t, x, y, w, h, phi, is_inferred, @@ -79,7 +93,24 @@ def load_roi_data(meta: pd.DataFrame | None = None) -> pd.DataFrame: db_cache: dict = {} chunks: list[pd.DataFrame] = [] - for row in meta.itertuples(index=False): + n_rows = len(meta) + if progress: + # Reason: this is a slow operation (one SQL query per session per + # ROI; the full batch is ~minutes). Print up front so the user + # knows whether to grab a coffee. + print( + f"Loading ROI data for {n_rows} flies × 2 sessions " + f"({2 * n_rows} DB queries). This typically takes 1–3 minutes." + ) + iterator = tqdm( + meta.itertuples(index=False), + total=n_rows, + desc="loading flies", + unit="fly", + disable=not progress, + ) + + for row in iterator: for session in ("training", "testing"): conn = _open_ro(getattr(row, f"{session}_db_path"), db_cache) if conn is None: