Make load_roi_data progress bar refresh reliably in JupyterLab

Prefer tqdm.notebook (HTML widget) over tqdm.auto so JupyterLab gets a
proper updating bar even when its text-mode \r refresh doesn't render
in-place. Tick per session (2× per fly) instead of per fly so the bar
advances roughly every second, and add a postfix showing the current
machine + ROI + session — gives visible motion even on slow rows.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Giorgio Gilestro 2026-05-01 09:43:12 +01:00
parent 8abb3d5955
commit b273255dea

View file

@ -15,14 +15,21 @@ import pandas as pd
from config import VIDEO_INFO_TSV from config import VIDEO_INFO_TSV
# Reason: tqdm.auto picks the right backend automatically — Jupyter widget # Reason: prefer the explicit Jupyter-widget tqdm when available (it
# inside a notebook, plain text on the CLI. Fall back to a no-op wrapper # updates reliably in JupyterLab, where text \r-style bars sometimes
# if tqdm isn't installed so the loader still works in minimal environments. # don't refresh in-place). Fall back to tqdm.auto, then to a no-op.
try:
from tqdm.notebook import tqdm
except ImportError:
try: try:
from tqdm.auto import tqdm from tqdm.auto import tqdm
except ImportError: except ImportError:
def tqdm(iterable, **_kwargs): # type: ignore[no-redef] def tqdm(*_args, **_kwargs): # type: ignore[no-redef]
return iterable class _NoOpBar:
def update(self, _n=1): pass
def set_postfix_str(self, _s): pass
def close(self): pass
return _NoOpBar()
# Metadata columns to copy onto every tracking sample. These are the xlsx # Metadata columns to copy onto every tracking sample. These are the xlsx
@ -100,20 +107,27 @@ def load_roi_data(
# knows whether to grab a coffee. # knows whether to grab a coffee.
print( print(
f"Loading ROI data for {n_rows} flies × 2 sessions " f"Loading ROI data for {n_rows} flies × 2 sessions "
f"({2 * n_rows} DB queries). This typically takes 13 minutes." f"({2 * n_rows} DB queries). This typically takes 25 minutes.",
flush=True,
) )
iterator = tqdm( # Reason: tick per-session (2 per fly) instead of per-fly so the bar
meta.itertuples(index=False), # advances roughly every second, not every ~2s. set_postfix_str shows
total=n_rows, # what's being processed — gives the user something visibly changing
# even when total ticks are slow.
pbar = tqdm(
total=2 * n_rows,
desc="loading flies", desc="loading flies",
unit="fly", unit="session",
disable=not progress, disable=not progress,
mininterval=0.5,
) )
for row in iterator: for row in meta.itertuples(index=False):
for session in ("training", "testing"): for session in ("training", "testing"):
pbar.set_postfix_str(f"{row.machine_name} ROI {int(row.roi)} {session}")
conn = _open_ro(getattr(row, f"{session}_db_path"), db_cache) conn = _open_ro(getattr(row, f"{session}_db_path"), db_cache)
if conn is None: if conn is None:
pbar.update(1)
continue continue
try: try:
df = pd.read_sql_query( df = pd.read_sql_query(
@ -123,12 +137,16 @@ def load_roi_data(
# Reason: a DB may be missing a ROI table if tracking was # Reason: a DB may be missing a ROI table if tracking was
# partial — skip rather than abort the whole batch. # partial — skip rather than abort the whole batch.
print(f" ROI_{row.roi} from {session} DB: {e}") print(f" ROI_{row.roi} from {session} DB: {e}")
pbar.update(1)
continue continue
df["session"] = session df["session"] = session
df["ROI"] = int(row.roi) df["ROI"] = int(row.roi)
for col in _META_COLS: for col in _META_COLS:
df[col] = getattr(row, col) df[col] = getattr(row, col)
chunks.append(df) chunks.append(df)
pbar.update(1)
pbar.close()
for conn in db_cache.values(): for conn in db_cache.values():
if conn is not None: if conn is not None: