load_roi_data: filter on barrier_opening.csv and stamp opening_s
For every session (training and testing alike), the loader now looks
up the corresponding row in barrier_opening.csv and:
- drops the read if the ROI is in bad_rois (barrier never opened
for that fly so its tracking has no biological meaning)
- drops the read if the session is flagged unusable
- stamps the session's opening_s onto every sample so downstream
code can compute t_from_opening = t - opening_s
Tested against ETHOSCOPE_082 2024-09-17: training (bad_rois=1,3,5)
correctly drops ROIs 1/3/5; testing keeps all six; opening_s differs
between sessions as expected (646.8 vs 154.7).
Opt out with apply_barrier_filter=False if you need raw data.
This commit is contained in:
parent
b8f23a4884
commit
28b7a227c0
2 changed files with 93 additions and 4 deletions
|
|
@ -26,6 +26,11 @@ VIDEO_INFO_TSV = DATA_VOLUME / "all_video_info_merged.tsv"
|
||||||
# A small CSV listing every video file we know about (built locally).
|
# A small CSV listing every video file we know about (built locally).
|
||||||
INVENTORY_CSV = DATA_METADATA / "video_inventory.csv"
|
INVENTORY_CSV = DATA_METADATA / "video_inventory.csv"
|
||||||
|
|
||||||
|
# Hand-annotated barrier-opening times (output of the picker app). One
|
||||||
|
# row per annotated video (training or testing session); columns: machine_name, session_date,
|
||||||
|
# session_time, opening_s, trim_first_s, bad_rois, analyst, notes.
|
||||||
|
BARRIER_OPENING_CSV = DATA_METADATA / "barrier_opening.csv"
|
||||||
|
|
||||||
# Where the ethoscope source tree is checked out (used by track_videos.py
|
# Where the ethoscope source tree is checked out (used by track_videos.py
|
||||||
# and auto_detect_targets.py — host-side scripts that import ethoscope
|
# and auto_detect_targets.py — host-side scripts that import ethoscope
|
||||||
# from a local clone rather than from pip). Default assumes the standard
|
# from a local clone rather than from pip). Default assumes the standard
|
||||||
|
|
|
||||||
|
|
@ -8,12 +8,48 @@ The TSV is the single source of truth for what data exists and how it
|
||||||
maps to flies and conditions.
|
maps to flies and conditions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from config import VIDEO_INFO_TSV
|
from config import BARRIER_OPENING_CSV, VIDEO_INFO_TSV
|
||||||
|
|
||||||
|
# DB filenames start with `YYYY-MM-DD_HH-MM-SS_<uuid>_...` — pull the
|
||||||
|
# session date/time out so we can join against barrier_opening.csv.
|
||||||
|
_DB_TIMESTAMP_RE = re.compile(r"(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_")
|
||||||
|
|
||||||
|
|
||||||
|
def _session_key(db_path: str) -> tuple[str, str] | None:
    """Return the ``(session_date, session_time)`` pair found in a DB filename.

    Only the final path component is inspected. ``None`` is returned when
    ``db_path`` is anything other than a non-empty string, or when the
    filename carries no ``YYYY-MM-DD_HH-MM-SS_`` stamp.
    """
    if isinstance(db_path, str) and db_path:
        stamp = _DB_TIMESTAMP_RE.search(Path(db_path).name)
        if stamp is not None:
            session_date, session_time = stamp.groups()
            return (session_date, session_time)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _load_barrier_lookup(csv_path: Path) -> dict[tuple[str, str, str], dict]:
    """Build (machine, session_date, session_time) → opening/bad_rois lookup.

    Each value is a dict with the keys:

    - ``"opening_s"``: float, NaN when the CSV cell is empty.
    - ``"trim_first_s"``: float, 0.0 when the CSV cell is empty.
    - ``"bad_rois"``: set of int ROI numbers parsed from the
      comma-separated ``bad_rois`` cell (empty set when the cell is empty).
    - ``"unusable"``: True when ``opening_s`` is missing for the session.

    Returns an empty dict if the CSV is missing — callers should treat
    that as "no per-session annotations available" rather than an error.
    """
    if not Path(csv_path).exists():
        return {}

    # Reason: without an explicit dtype pandas infers a numeric column when
    # every non-empty ``bad_rois`` cell holds a single ROI (e.g. ``3``), and
    # the value would then fail the string check below and be silently
    # ignored. Reading the join-key columns as text likewise keeps them
    # comparable with the strings extracted from DB filenames.
    df = pd.read_csv(
        csv_path,
        dtype={
            "machine_name": str,
            "session_date": str,
            "session_time": str,
            "bad_rois": str,
        },
    )

    lookup: dict[tuple[str, str, str], dict] = {}
    for r in df.itertuples(index=False):
        bad: set[int] = set()
        # Empty cells come back as NaN (a float), so only real text is parsed.
        if isinstance(r.bad_rois, str):
            # int(float(...)) also accepts tokens written as "3.0".
            bad = {int(float(tok)) for tok in r.bad_rois.split(",") if tok.strip()}

        has_opening = bool(pd.notna(r.opening_s))
        lookup[(r.machine_name, r.session_date, r.session_time)] = {
            "opening_s": float(r.opening_s) if has_opening else float("nan"),
            "trim_first_s": float(r.trim_first_s) if pd.notna(r.trim_first_s) else 0.0,
            "bad_rois": bad,
            # A session without an annotated opening time is flagged unusable.
            "unusable": not has_opening,
        }
    return lookup
|
||||||
|
|
||||||
# Reason: prefer the explicit Jupyter-widget tqdm when available (it
|
# Reason: prefer the explicit Jupyter-widget tqdm when available (it
|
||||||
# updates reliably in JupyterLab, where text \r-style bars sometimes
|
# updates reliably in JupyterLab, where text \r-style bars sometimes
|
||||||
|
|
@ -66,6 +102,7 @@ def _open_ro(db_path: str, cache: dict) -> sqlite3.Connection | None:
|
||||||
def load_roi_data(
|
def load_roi_data(
|
||||||
meta: pd.DataFrame | None = None,
|
meta: pd.DataFrame | None = None,
|
||||||
progress: bool = True,
|
progress: bool = True,
|
||||||
|
apply_barrier_filter: bool = True,
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
"""Load ROI tracking data joined with experimental metadata.
|
"""Load ROI tracking data joined with experimental metadata.
|
||||||
|
|
||||||
|
|
@ -75,6 +112,14 @@ def load_roi_data(
|
||||||
(``"training"`` or ``"testing"``). Rows with empty DB paths (unusable
|
(``"training"`` or ``"testing"``). Rows with empty DB paths (unusable
|
||||||
videos, or videos that didn't pass the completeness gate) are skipped.
|
videos, or videos that didn't pass the completeness gate) are skipped.
|
||||||
|
|
||||||
|
Both training and testing reads are filtered against
|
||||||
|
``barrier_opening.csv`` (the picker annotates both video types):
|
||||||
|
flies whose ROI never released (listed in ``bad_rois``) and entire
|
||||||
|
sessions flagged unusable are dropped. The session's ``opening_s``
|
||||||
|
is stamped onto its samples so downstream code can compute
|
||||||
|
``t_from_opening = t - opening_s``. Sessions missing from the CSV
|
||||||
|
are still loaded, but with ``opening_s = NaN``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
meta: optional DataFrame with the same schema as
|
meta: optional DataFrame with the same schema as
|
||||||
``all_video_info_merged.tsv``. Pass a filtered slice to load a
|
``all_video_info_merged.tsv``. Pass a filtered slice to load a
|
||||||
|
|
@ -82,11 +127,17 @@ def load_roi_data(
|
||||||
Defaults to the full TSV.
|
Defaults to the full TSV.
|
||||||
progress: show a tqdm progress bar (one tick per fly/ROI row).
|
progress: show a tqdm progress bar (one tick per fly/ROI row).
|
||||||
Defaults to True. Set False for silent batch jobs.
|
Defaults to True. Set False for silent batch jobs.
|
||||||
|
apply_barrier_filter: if True (default), drop session data for
|
||||||
|
flies whose barrier never opened and stamp ``opening_s``
|
||||||
|
onto every sample. Set False to load raw data without any
|
||||||
|
barrier-derived filtering; the ``opening_s`` column is then NaN for every sample.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DataFrame with columns ``id, t, x, y, w, h, phi, is_inferred,
|
DataFrame with columns ``id, t, x, y, w, h, phi, is_inferred,
|
||||||
has_interacted, session, <metadata>`` — one row per tracking
|
has_interacted, session, ROI, opening_s, <metadata>`` — one row
|
||||||
sample. Empty if nothing could be loaded.
|
per tracking sample. ``opening_s`` is NaN for sessions not
|
||||||
|
covered by ``barrier_opening.csv``. Empty if nothing could be
|
||||||
|
loaded.
|
||||||
"""
|
"""
|
||||||
if meta is None:
|
if meta is None:
|
||||||
meta = pd.read_csv(VIDEO_INFO_TSV, sep="\t")
|
meta = pd.read_csv(VIDEO_INFO_TSV, sep="\t")
|
||||||
|
|
@ -97,8 +148,12 @@ def load_roi_data(
|
||||||
if "include" in meta.columns:
|
if "include" in meta.columns:
|
||||||
meta = meta[meta["include"].astype(bool)]
|
meta = meta[meta["include"].astype(bool)]
|
||||||
|
|
||||||
|
barrier_lookup = _load_barrier_lookup(BARRIER_OPENING_CSV) if apply_barrier_filter else {}
|
||||||
|
|
||||||
db_cache: dict = {}
|
db_cache: dict = {}
|
||||||
chunks: list[pd.DataFrame] = []
|
chunks: list[pd.DataFrame] = []
|
||||||
|
n_skipped_bad_roi = 0
|
||||||
|
n_skipped_unusable = 0
|
||||||
|
|
||||||
n_rows = len(meta)
|
n_rows = len(meta)
|
||||||
if progress:
|
if progress:
|
||||||
|
|
@ -125,7 +180,28 @@ def load_roi_data(
|
||||||
for row in meta.itertuples(index=False):
|
for row in meta.itertuples(index=False):
|
||||||
for session in ("training", "testing"):
|
for session in ("training", "testing"):
|
||||||
pbar.set_postfix_str(f"{row.machine_name} ROI {int(row.roi)} {session}")
|
pbar.set_postfix_str(f"{row.machine_name} ROI {int(row.roi)} {session}")
|
||||||
conn = _open_ro(getattr(row, f"{session}_db_path"), db_cache)
|
db_path = getattr(row, f"{session}_db_path")
|
||||||
|
|
||||||
|
# The picker annotates barrier_opening per video, and both
|
||||||
|
# the training and testing videos have their own entries.
|
||||||
|
# Apply the same per-session filter to both.
|
||||||
|
opening_s = float("nan")
|
||||||
|
if barrier_lookup:
|
||||||
|
key = _session_key(db_path)
|
||||||
|
if key is not None:
|
||||||
|
bo = barrier_lookup.get((row.machine_name, key[0], key[1]))
|
||||||
|
if bo is not None:
|
||||||
|
if bo["unusable"]:
|
||||||
|
n_skipped_unusable += 1
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
if int(row.roi) in bo["bad_rois"]:
|
||||||
|
n_skipped_bad_roi += 1
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
opening_s = bo["opening_s"]
|
||||||
|
|
||||||
|
conn = _open_ro(db_path, db_cache)
|
||||||
if conn is None:
|
if conn is None:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
continue
|
continue
|
||||||
|
|
@ -141,6 +217,7 @@ def load_roi_data(
|
||||||
continue
|
continue
|
||||||
df["session"] = session
|
df["session"] = session
|
||||||
df["ROI"] = int(row.roi)
|
df["ROI"] = int(row.roi)
|
||||||
|
df["opening_s"] = opening_s
|
||||||
for col in _META_COLS:
|
for col in _META_COLS:
|
||||||
df[col] = getattr(row, col)
|
df[col] = getattr(row, col)
|
||||||
chunks.append(df)
|
chunks.append(df)
|
||||||
|
|
@ -148,6 +225,13 @@ def load_roi_data(
|
||||||
|
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
|
if apply_barrier_filter and (n_skipped_bad_roi or n_skipped_unusable):
|
||||||
|
print(
|
||||||
|
f"Barrier filter: dropped {n_skipped_bad_roi} ROI loads (barrier "
|
||||||
|
f"never opened) and {n_skipped_unusable} unusable sessions.",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
for conn in db_cache.values():
|
for conn in db_cache.values():
|
||||||
if conn is not None:
|
if conn is not None:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue