Add tqdm progress bar to load_roi_data
Loading the full batch issues 968 SQL queries and takes minutes, so show a tqdm progress bar (one tick per fly/ROI row) and print an upfront "this takes 1-3 minutes" notice so the user knows to wait. Uses tqdm.auto so it picks the Jupyter widget when run from a notebook and plain text on the CLI. New `progress=True` parameter on load_roi_data; set it to False for silent batch use. tqdm + ipywidgets added to requirements. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
ac3b8c13f0
commit
8abb3d5955
2 changed files with 35 additions and 2 deletions
|
|
@ -5,3 +5,5 @@ seaborn>=0.12
|
||||||
scipy>=1.10
|
scipy>=1.10
|
||||||
scikit-learn>=1.3
|
scikit-learn>=1.3
|
||||||
jupyter>=1.0
|
jupyter>=1.0
|
||||||
|
tqdm>=4.66
|
||||||
|
ipywidgets>=8.0
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,15 @@ import pandas as pd
|
||||||
|
|
||||||
from config import VIDEO_INFO_TSV
|
from config import VIDEO_INFO_TSV
|
||||||
|
|
||||||
|
# Reason: tqdm.auto picks the right backend automatically — Jupyter widget
|
||||||
|
# inside a notebook, plain text on the CLI. Fall back to a no-op wrapper
|
||||||
|
# if tqdm isn't installed so the loader still works in minimal environments.
|
||||||
|
try:
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
except ImportError:
|
||||||
|
def tqdm(iterable, **_kwargs): # type: ignore[no-redef]
|
||||||
|
return iterable
|
||||||
|
|
||||||
|
|
||||||
# Metadata columns to copy onto every tracking sample. These are the xlsx
|
# Metadata columns to copy onto every tracking sample. These are the xlsx
|
||||||
# fields that describe the experimental condition behind each fly/ROI.
|
# fields that describe the experimental condition behind each fly/ROI.
|
||||||
|
|
@ -47,7 +56,10 @@ def _open_ro(db_path: str, cache: dict) -> sqlite3.Connection | None:
|
||||||
return cache[db_path]
|
return cache[db_path]
|
||||||
|
|
||||||
|
|
||||||
def load_roi_data(meta: pd.DataFrame | None = None) -> pd.DataFrame:
|
def load_roi_data(
|
||||||
|
meta: pd.DataFrame | None = None,
|
||||||
|
progress: bool = True,
|
||||||
|
) -> pd.DataFrame:
|
||||||
"""Load ROI tracking data joined with experimental metadata.
|
"""Load ROI tracking data joined with experimental metadata.
|
||||||
|
|
||||||
For each row in ``meta``, reads the matching ROI table from both the
|
For each row in ``meta``, reads the matching ROI table from both the
|
||||||
|
|
@ -61,6 +73,8 @@ def load_roi_data(meta: pd.DataFrame | None = None) -> pd.DataFrame:
|
||||||
``all_video_info_merged.tsv``. Pass a filtered slice to load a
|
``all_video_info_merged.tsv``. Pass a filtered slice to load a
|
||||||
subset (e.g. ``meta[meta.species == 'Melanogaster/CS']``).
|
subset (e.g. ``meta[meta.species == 'Melanogaster/CS']``).
|
||||||
Defaults to the full TSV.
|
Defaults to the full TSV.
|
||||||
|
progress: show a tqdm progress bar (one tick per fly/ROI row).
|
||||||
|
Defaults to True. Set False for silent batch jobs.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DataFrame with columns ``id, t, x, y, w, h, phi, is_inferred,
|
DataFrame with columns ``id, t, x, y, w, h, phi, is_inferred,
|
||||||
|
|
@ -79,7 +93,24 @@ def load_roi_data(meta: pd.DataFrame | None = None) -> pd.DataFrame:
|
||||||
db_cache: dict = {}
|
db_cache: dict = {}
|
||||||
chunks: list[pd.DataFrame] = []
|
chunks: list[pd.DataFrame] = []
|
||||||
|
|
||||||
for row in meta.itertuples(index=False):
|
n_rows = len(meta)
|
||||||
|
if progress:
|
||||||
|
# Reason: this is a slow operation (one SQL query per session per
|
||||||
|
# ROI; the full batch takes minutes). Print up front so the user
|
||||||
|
# knows whether to grab a coffee.
|
||||||
|
print(
|
||||||
|
f"Loading ROI data for {n_rows} flies × 2 sessions "
|
||||||
|
f"({2 * n_rows} DB queries). This typically takes 1–3 minutes."
|
||||||
|
)
|
||||||
|
iterator = tqdm(
|
||||||
|
meta.itertuples(index=False),
|
||||||
|
total=n_rows,
|
||||||
|
desc="loading flies",
|
||||||
|
unit="fly",
|
||||||
|
disable=not progress,
|
||||||
|
)
|
||||||
|
|
||||||
|
for row in iterator:
|
||||||
for session in ("training", "testing"):
|
for session in ("training", "testing"):
|
||||||
conn = _open_ro(getattr(row, f"{session}_db_path"), db_cache)
|
conn = _open_ro(getattr(row, f"{session}_db_path"), db_cache)
|
||||||
if conn is None:
|
if conn is None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue