load_roi_data: filter on barrier_opening.csv and stamp opening_s
For every session (training and testing alike), the loader now looks
up the corresponding row in barrier_opening.csv and:
- drops the read if the ROI is in bad_rois (barrier never opened
for that fly so its tracking has no biological meaning)
- drops the read if the session is flagged unusable
- stamps the session's opening_s onto every sample so downstream
code can compute t_from_opening = t - opening_s
Tested against ETHOSCOPE_082 2024-09-17: training (bad_rois=1,3,5)
correctly drops ROIs 1/3/5; testing keeps all six; opening_s differs
between sessions as expected (646.8 vs 154.7).
Opt out with apply_barrier_filter=False if you need raw data.
This commit is contained in:
parent
b8f23a4884
commit
28b7a227c0
2 changed files with 93 additions and 4 deletions
|
|
@ -26,6 +26,11 @@ VIDEO_INFO_TSV = DATA_VOLUME / "all_video_info_merged.tsv"
|
|||
# A small CSV listing every video file we know about (built locally).
|
||||
INVENTORY_CSV = DATA_METADATA / "video_inventory.csv"
|
||||
|
||||
# Hand-annotated barrier-opening times (output of the picker app). One
|
||||
# row per testing session; columns: machine_name, session_date,
|
||||
# session_time, opening_s, trim_first_s, bad_rois, analyst, notes.
|
||||
BARRIER_OPENING_CSV = DATA_METADATA / "barrier_opening.csv"
|
||||
|
||||
# Where the ethoscope source tree is checked out (used by track_videos.py
|
||||
# and auto_detect_targets.py — host-side scripts that import ethoscope
|
||||
# from a local clone rather than from pip). Default assumes the standard
|
||||
|
|
|
|||
|
|
@ -8,12 +8,48 @@ The TSV is the single source of truth for what data exists and how it
|
|||
maps to flies and conditions.
|
||||
"""
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from config import VIDEO_INFO_TSV
|
||||
from config import BARRIER_OPENING_CSV, VIDEO_INFO_TSV
|
||||
|
||||
# DB filenames start with `YYYY-MM-DD_HH-MM-SS_<uuid>_...` — pull the
|
||||
# session date/time out so we can join against barrier_opening.csv.
|
||||
_DB_TIMESTAMP_RE = re.compile(r"(\d{4}-\d{2}-\d{2})_(\d{2}-\d{2}-\d{2})_")
|
||||
|
||||
|
||||
def _session_key(db_path: str) -> tuple[str, str] | None:
|
||||
"""Extract (session_date, session_time) from a tracking DB filename."""
|
||||
if not isinstance(db_path, str) or not db_path:
|
||||
return None
|
||||
m = _DB_TIMESTAMP_RE.search(Path(db_path).name)
|
||||
return (m.group(1), m.group(2)) if m else None
|
||||
|
||||
|
||||
def _load_barrier_lookup(csv_path: Path) -> dict[tuple[str, str, str], dict]:
|
||||
"""Build (machine, session_date, session_time) → opening/bad_rois lookup.
|
||||
|
||||
Returns an empty dict if the CSV is missing — callers should treat
|
||||
that as "no per-session annotations available" rather than an error.
|
||||
"""
|
||||
if not Path(csv_path).exists():
|
||||
return {}
|
||||
df = pd.read_csv(csv_path)
|
||||
lookup: dict[tuple[str, str, str], dict] = {}
|
||||
for r in df.itertuples(index=False):
|
||||
bad = set()
|
||||
if isinstance(r.bad_rois, str) and r.bad_rois.strip():
|
||||
bad = {int(x) for x in r.bad_rois.split(",") if x.strip()}
|
||||
lookup[(r.machine_name, r.session_date, r.session_time)] = {
|
||||
"opening_s": float(r.opening_s) if pd.notna(r.opening_s) else float("nan"),
|
||||
"trim_first_s": float(r.trim_first_s) if pd.notna(r.trim_first_s) else 0.0,
|
||||
"bad_rois": bad,
|
||||
"unusable": pd.isna(r.opening_s),
|
||||
}
|
||||
return lookup
|
||||
|
||||
# Reason: prefer the explicit Jupyter-widget tqdm when available (it
|
||||
# updates reliably in JupyterLab, where text \r-style bars sometimes
|
||||
|
|
@ -66,6 +102,7 @@ def _open_ro(db_path: str, cache: dict) -> sqlite3.Connection | None:
|
|||
def load_roi_data(
|
||||
meta: pd.DataFrame | None = None,
|
||||
progress: bool = True,
|
||||
apply_barrier_filter: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
"""Load ROI tracking data joined with experimental metadata.
|
||||
|
||||
|
|
@ -75,6 +112,14 @@ def load_roi_data(
|
|||
(``"training"`` or ``"testing"``). Rows with empty DB paths (unusable
|
||||
videos, or videos that didn't pass the completeness gate) are skipped.
|
||||
|
||||
Both training and testing reads are filtered against
|
||||
``barrier_opening.csv`` (the picker annotates both video types):
|
||||
flies whose ROI never released (listed in ``bad_rois``) and entire
|
||||
sessions flagged unusable are dropped. The session's ``opening_s``
|
||||
is stamped onto its samples so downstream code can compute
|
||||
``t_from_opening = t - opening_s``. Sessions missing from the CSV
|
||||
are still loaded, but with ``opening_s = NaN``.
|
||||
|
||||
Args:
|
||||
meta: optional DataFrame with the same schema as
|
||||
``all_video_info_merged.tsv``. Pass a filtered slice to load a
|
||||
|
|
@ -82,11 +127,17 @@ def load_roi_data(
|
|||
Defaults to the full TSV.
|
||||
progress: show a tqdm progress bar (one tick per fly/ROI row).
|
||||
Defaults to True. Set False for silent batch jobs.
|
||||
apply_barrier_filter: if True (default), drop session data for
|
||||
flies whose barrier never opened and stamp ``opening_s``
|
||||
onto every sample. Set False to load raw data without any
|
||||
barrier-derived filtering or columns.
|
||||
|
||||
Returns:
|
||||
DataFrame with columns ``id, t, x, y, w, h, phi, is_inferred,
|
||||
has_interacted, session, <metadata>`` — one row per tracking
|
||||
sample. Empty if nothing could be loaded.
|
||||
has_interacted, session, ROI, opening_s, <metadata>`` — one row
|
||||
per tracking sample. ``opening_s`` is NaN for sessions not
|
||||
covered by ``barrier_opening.csv``. Empty if nothing could be
|
||||
loaded.
|
||||
"""
|
||||
if meta is None:
|
||||
meta = pd.read_csv(VIDEO_INFO_TSV, sep="\t")
|
||||
|
|
@ -97,8 +148,12 @@ def load_roi_data(
|
|||
if "include" in meta.columns:
|
||||
meta = meta[meta["include"].astype(bool)]
|
||||
|
||||
barrier_lookup = _load_barrier_lookup(BARRIER_OPENING_CSV) if apply_barrier_filter else {}
|
||||
|
||||
db_cache: dict = {}
|
||||
chunks: list[pd.DataFrame] = []
|
||||
n_skipped_bad_roi = 0
|
||||
n_skipped_unusable = 0
|
||||
|
||||
n_rows = len(meta)
|
||||
if progress:
|
||||
|
|
@ -125,7 +180,28 @@ def load_roi_data(
|
|||
for row in meta.itertuples(index=False):
|
||||
for session in ("training", "testing"):
|
||||
pbar.set_postfix_str(f"{row.machine_name} ROI {int(row.roi)} {session}")
|
||||
conn = _open_ro(getattr(row, f"{session}_db_path"), db_cache)
|
||||
db_path = getattr(row, f"{session}_db_path")
|
||||
|
||||
# The picker annotates barrier_opening per video, and both
|
||||
# the training and testing videos have their own entries.
|
||||
# Apply the same per-session filter to both.
|
||||
opening_s = float("nan")
|
||||
if barrier_lookup:
|
||||
key = _session_key(db_path)
|
||||
if key is not None:
|
||||
bo = barrier_lookup.get((row.machine_name, key[0], key[1]))
|
||||
if bo is not None:
|
||||
if bo["unusable"]:
|
||||
n_skipped_unusable += 1
|
||||
pbar.update(1)
|
||||
continue
|
||||
if int(row.roi) in bo["bad_rois"]:
|
||||
n_skipped_bad_roi += 1
|
||||
pbar.update(1)
|
||||
continue
|
||||
opening_s = bo["opening_s"]
|
||||
|
||||
conn = _open_ro(db_path, db_cache)
|
||||
if conn is None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
|
@ -141,6 +217,7 @@ def load_roi_data(
|
|||
continue
|
||||
df["session"] = session
|
||||
df["ROI"] = int(row.roi)
|
||||
df["opening_s"] = opening_s
|
||||
for col in _META_COLS:
|
||||
df[col] = getattr(row, col)
|
||||
chunks.append(df)
|
||||
|
|
@ -148,6 +225,13 @@ def load_roi_data(
|
|||
|
||||
pbar.close()
|
||||
|
||||
if apply_barrier_filter and (n_skipped_bad_roi or n_skipped_unusable):
|
||||
print(
|
||||
f"Barrier filter: dropped {n_skipped_bad_roi} ROI loads (barrier "
|
||||
f"never opened) and {n_skipped_unusable} unusable sessions.",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
for conn in db_cache.values():
|
||||
if conn is not None:
|
||||
conn.close()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue