diff --git a/scripts/load_roi_data.py b/scripts/load_roi_data.py index 9af8a09..ee2263c 100644 --- a/scripts/load_roi_data.py +++ b/scripts/load_roi_data.py @@ -15,14 +15,21 @@ import pandas as pd from config import VIDEO_INFO_TSV -# Reason: tqdm.auto picks the right backend automatically — Jupyter widget -# inside a notebook, plain text on the CLI. Fall back to a no-op wrapper -# if tqdm isn't installed so the loader still works in minimal environments. +# Reason: prefer the explicit Jupyter-widget tqdm when available (it +# updates reliably in JupyterLab, where text \r-style bars sometimes +# don't refresh in-place). Fall back to tqdm.auto, then to a no-op. try: - from tqdm.auto import tqdm + from tqdm.notebook import tqdm except ImportError: - def tqdm(iterable, **_kwargs): # type: ignore[no-redef] - return iterable + try: + from tqdm.auto import tqdm + except ImportError: + def tqdm(*_args, **_kwargs): # type: ignore[no-redef] + class _NoOpBar: + def update(self, _n=1): pass + def set_postfix_str(self, _s): pass + def close(self): pass + return _NoOpBar() # Metadata columns to copy onto every tracking sample. These are the xlsx @@ -100,20 +107,27 @@ def load_roi_data( # knows whether to grab a coffee. print( f"Loading ROI data for {n_rows} flies × 2 sessions " - f"({2 * n_rows} DB queries). This typically takes 1–3 minutes." + f"({2 * n_rows} DB queries). This typically takes 2–5 minutes.", + flush=True, ) - iterator = tqdm( - meta.itertuples(index=False), - total=n_rows, + # Reason: tick per-session (2 per fly) instead of per-fly so the bar + # advances roughly every second, not every ~2s. set_postfix_str shows + # what's being processed — gives the user something visibly changing + # even when total ticks are slow. + pbar = tqdm( + total=2 * n_rows, desc="loading flies", - unit="fly", + unit="session", disable=not progress, + mininterval=0.5, ) - for row in iterator: + for row in meta.itertuples(index=False): for session in ("training", "testing"): + pbar.set_postfix_str(f"{row.machine_name} ROI {int(row.roi)} {session}") conn = _open_ro(getattr(row, f"{session}_db_path"), db_cache) if conn is None: + pbar.update(1) continue try: df = pd.read_sql_query( @@ -123,12 +137,16 @@ def load_roi_data( # Reason: a DB may be missing a ROI table if tracking was # partial — skip rather than abort the whole batch. print(f" ROI_{row.roi} from {session} DB: {e}") + pbar.update(1) continue df["session"] = session df["ROI"] = int(row.roi) for col in _META_COLS: df[col] = getattr(row, col) chunks.append(df) + pbar.update(1) + + pbar.close() for conn in db_cache.values(): if conn is not None: