cupido/scripts/track_videos.py
Giorgio Gilestro 231c7a437f Remove hardcoded /home/gg paths so the project is portable
Notebooks now use Path.home() / "cupido" for the repo root (works for
any user inside the JupyterLab container), and the offline-tracking
scripts read the ethoscope source-tree location from the new
ETHOSCOPE_SRC config constant — defaulting to ~/Code/ethoscope_project/...
and overridable via the ETHOSCOPE_SRC environment variable.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-01 08:55:44 +01:00

283 lines
10 KiB
Python

"""Headless offline tracker.
Reads target JSONs produced by `pick_targets.py`, builds the 6 ROIs of the
HD mating arena from the L-shape reference points, runs ethoscope's
`MultiFlyTracker` against the merged.mp4 file via `MovieVirtualCamera`, and
writes a SQLite DB to `TRACKING_OUTPUT_DIR/<video_basename>_tracking.db`.
Idempotent: skips videos whose tracking DB already exists (unless --redo).
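Usage:
    python track_videos.py                      # process all videos with target JSON
    python track_videos.py --redo               # re-track even if DB exists
    python track_videos.py --jobs 4             # run up to 4 videos in parallel
    python track_videos.py --max-duration 1800  # cap each video at 30 min (sec)
    python track_videos.py --limit 10           # process only the first 10 target JSONs
    python track_videos.py --video clip.mp4     # track one video (requires its target JSON)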
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import sys
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
import numpy as np
from config import ETHOSCOPE_SRC, TARGETS_DIR, TRACKING_OUTPUT_DIR
# Import ethoscope from the local source tree (no pip install).
sys.path.insert(0, str(ETHOSCOPE_SRC))
from tracking_geometry import HD_FG_DATA, compute_roi_polygons # noqa: E402
def build_rois_from_targets(reference_points):
    """Turn the arena reference points into a list of ethoscope `ROI` objects.

    The polygons come from the shared `compute_roi_polygons` helper; each one
    is reshaped to `(1, 4, 2)` before being handed to `ROI`. ROI indices are
    1-based and follow the order in which the polygons are returned.
    """
    # Imported lazily: ethoscope is only importable once ETHOSCOPE_SRC is on
    # sys.path, and callers may run inside a worker process.
    from ethoscope.core.roi import ROI

    rois = []
    for number, polygon in enumerate(compute_roi_polygons(reference_points), start=1):
        contour = polygon.reshape((1, 4, 2))
        rois.append(ROI(contour, idx=number))
    return rois
def _discard_db(db_path: Path) -> None:
    """Delete a tracking DB and any SQLite `-wal` / `-shm` sidecars next to it."""
    for leftover in (db_path, Path(f"{db_path}-wal"), Path(f"{db_path}-shm")):
        leftover.unlink(missing_ok=True)


def _latest_timestamp_ms(db_path: Path, n_rois: int) -> int:
    """Return the largest `t` found in tables ROI_1..ROI_<n_rois>, or 0.

    The DB is opened read-only. Any failure while reading (missing table,
    unreadable file, ...) is treated as "no coverage" and yields 0, so the
    caller's completeness check rejects the DB.
    """
    import sqlite3
    from contextlib import closing

    latest = 0
    try:
        # closing() guarantees the connection is released even when a query
        # raises, so the file is not held open while the caller deletes it.
        with closing(sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)) as con:
            for roi_idx in range(1, n_rois + 1):
                value = con.execute(f"SELECT MAX(t) FROM ROI_{roi_idx}").fetchone()[0]
                if value and value > latest:
                    latest = value
    except Exception:
        # Deliberate best-effort: an unreadable DB counts as empty.
        return 0
    return latest


def track_one(json_path: Path, output_dir: Path, max_duration: float | None,
              redo: bool) -> tuple[str, str]:
    """Track a single video described by a target JSON. Runs in a subprocess.

    Args:
        json_path: target JSON holding at least `video_path` and
            `reference_points`; may carry `unusable` / `reason` and the
            machine / session metadata copied into the DB.
        output_dir: directory receiving `<video stem>_tracking.db`.
        max_duration: optional cap, in seconds, forwarded to the camera and
            used to bound the expected duration in the completeness check.
        redo: when True an existing DB is deleted and the video re-tracked.

    Returns:
        `(status, message)` with status one of "ok", "skip", "error".
        Any exception raised while building the writer / monitor or while
        tracking is reported as "error" (message = truncated traceback) and
        the partial DB is removed, so a later run without --redo does not
        mistake it for a finished one.

    Raises:
        Whatever reading / parsing the target JSON raises (unreadable file,
        invalid JSON, missing `video_path` or `reference_points` key); those
        steps run before the guarded section.
    """
    # Re-import inside the subprocess so each worker has its own ethoscope state.
    import sys as _sys
    _sys.path.insert(0, str(ETHOSCOPE_SRC))
    import time as _time

    import cv2  # noqa: F401  (kept from the original; not referenced directly)
    from ethoscope.core.monitor import Monitor
    from ethoscope.hardware.input.cameras import MovieVirtualCamera
    from ethoscope.io.sqlite import SQLiteResultWriter
    from ethoscope.trackers.multi_fly_tracker import MultiFlyTracker

    class BGRMovieCamera(MovieVirtualCamera):
        """MovieVirtualCamera that keeps BGR frames AND retries failed reads.

        Two reasons for the override:
          1. MultiFlyTracker calls cv2.cvtColor(img, COLOR_BGR2GRAY) without
             checking whether img is already grayscale, so it must be fed
             3-channel input.
          2. cv2.VideoCapture.read() can return False on transient I/O
             hiccups (e.g. NFS contention when several workers read large
             mp4s in parallel) without the file being at EOF. Treating the
             first False as end-of-stream would silently truncate the
             output, so reads are retried a few times; a failure is only
             accepted immediately when it happens close to the real end of
             the file.
        """

        _retry_count = 5
        _retry_backoff_s = 0.25
        _eof_safety_frames = 50  # near end-of-file, treat False as legitimate

        def _next_image(self):
            # NOTE(review): this override reads `self._frame_idx` but never
            # advances it; presumably the base class does — confirm, otherwise
            # the near-EOF shortcut below can never trigger.
            for attempt in range(self._retry_count):
                ret, frame = self.capture.read()
                if ret and frame is not None:
                    return frame  # BGR, untouched
                near_end = (
                    self._has_end_of_file
                    and self._frame_idx >= self._total_n_frames - self._eof_safety_frames
                )
                if near_end:
                    return None
                # Suspected transient hiccup: back off and retry, but do not
                # waste a sleep after the last attempt.
                if attempt + 1 < self._retry_count:
                    _time.sleep(self._retry_backoff_s)
            return None  # persistent failure

    payload = json.loads(json_path.read_text())
    if payload.get("unusable"):
        reason = payload.get("reason") or "no reason given"
        return "skip", f"marked unusable: {reason}"

    video_path = Path(payload["video_path"])
    if not video_path.exists():
        return "error", f"video missing: {video_path}"

    out_db = output_dir / f"{video_path.stem}_tracking.db"
    if out_db.exists():
        if not redo:
            return "skip", f"DB exists: {out_db.name}"
        # Also drop stale WAL/SHM sidecars so they cannot pair with the new DB.
        _discard_db(out_db)

    reference_points = payload["reference_points"]
    rois = build_rois_from_targets(reference_points)

    cam_kwargs = {"use_wall_clock": False}
    if max_duration is not None:
        cam_kwargs["max_duration"] = max_duration
    cam = BGRMovieCamera(str(video_path), **cam_kwargs)

    # Everything from here on runs with the capture open, so it all sits
    # under one try/finally that closes the camera.
    try:
        metadata = {
            "machine_id": payload.get("machine_uuid", "unknown"),
            "machine_name": payload.get("machine_name", "unknown"),
            "date_time": int(payload.get("session_epoch", 0)),
            "frame_width": cam.width,
            "frame_height": cam.height,
            "version": "offline-tracker-1",
            "experimental_info": "{}",
            "selected_options": json.dumps({
                "tracker": "MultiFlyTracker",
                "template": "HD_Mating_Arena_6_ROIS",
                "fg_data": HD_FG_DATA,
                "maxN": 2,
            }),
            "hardware_info": "{}",
            "reference_points": str([list(map(int, p)) for p in reference_points]),
            "backup_filename": out_db.name,
            "result_writer_type": "SQLite3",
            "sqlite_source_path": str(out_db),
        }
        tracker_data = {
            "maxN": 2,
            "visualise": False,
            "fg_data": HD_FG_DATA,
            "adaptive_threshold": True,
            "min_fg_threshold": 10,
            "max_fg_threshold": 50,
        }
        db_credentials = {"name": str(out_db)}
        rw = SQLiteResultWriter(
            db_credentials, rois, metadata=metadata,
            make_dam_like_table=False, take_frame_shots=False, erase_old_db=True,
        )
        monit = Monitor(
            cam, MultiFlyTracker, rois,
            reference_points=reference_points,
            data=tracker_data,
        )
        with rw as result_writer:
            monit.run(result_writer=result_writer, drawer=None, verbose=False)
    except Exception:
        # Capture the traceback first, then drop the partial DB: leaving it
        # behind would make the next non --redo run report "DB exists".
        failure = traceback.format_exc(limit=5)
        _discard_db(out_db)
        return "error", failure
    finally:
        try:
            cam._close()
        except Exception:
            pass  # best-effort cleanup; a close failure must not mask the result

    if not out_db.exists():
        return "error", "tracking finished but DB was not created"

    # Post-tracking sanity check: did we cover most of the source video?
    # If not (retries exhausted, codec corruption, ...), reject the DB so it
    # is not cached as "done" — an explicit failure beats a silent partial
    # write.
    # NOTE(review): the frame rate is assumed to be 25 fps rather than read
    # from the video — confirm this holds for every input.
    assumed_fps = 25.0
    expected_ms = (cam._total_n_frames / assumed_fps) * 1000.0
    if max_duration is not None:
        expected_ms = min(expected_ms, max_duration * 1000.0)
    completeness_threshold = 0.90  # require >= 90 % of expected duration

    # MAX(t) across all ROIs: a single ROI can stop producing rows early, so
    # the latest detection anywhere is the better signal of how far the
    # frame iterator actually got. ROI tables are numbered 1..len(rois),
    # matching the indices assigned in build_rois_from_targets.
    t_max = _latest_timestamp_ms(out_db, len(rois))

    if expected_ms > 0 and t_max < expected_ms * completeness_threshold:
        _discard_db(out_db)
        ratio = t_max / expected_ms
        return (
            "error",
            f"short output: t_max={t_max} ms vs expected {int(expected_ms)} ms "
            f"({ratio*100:.0f}%); DB removed",
        )
    return "ok", str(out_db)
def main() -> None:
    """Command-line entry point: pick target JSONs, track them, print a summary.

    Returns without exiting when there is nothing to process. Otherwise
    terminates via `sys.exit` with status 0 when no video ended in "error"
    and 1 when at least one did. A single failing video never aborts the
    batch, in serial or parallel mode.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--redo", action="store_true", help="re-track even if DB exists")
    parser.add_argument("--jobs", type=int, default=1, help="parallel workers")
    parser.add_argument(
        "--max-duration", type=float, default=None,
        help="cap each video at this many seconds (default: full video)",
    )
    parser.add_argument("--limit", type=int, default=None, help="process only first N")
    parser.add_argument("--video", type=str, default=None,
                        help="track a single video (mp4 path); requires its target JSON")
    args = parser.parse_args()

    TRACKING_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    if args.video:
        stem = Path(args.video).stem
        json_path = TARGETS_DIR / f"{stem}.json"
        if not json_path.exists():
            sys.exit(f"No target JSON for {args.video}: expected {json_path}")
        jsons = [json_path]
    else:
        jsons = sorted(TARGETS_DIR.glob("*.json"))
    if args.limit:
        jsons = jsons[: args.limit]
    if not jsons:
        print("No target JSONs found. Run pick_targets.py first.")
        return

    print(f"Tracking {len(jsons)} videos (jobs={args.jobs}, redo={args.redo}).")

    tally = {"ok": 0, "skip": 0, "error": 0}

    def record(status: str) -> None:
        # Unknown statuses are left uncounted, as before.
        if status in tally:
            tally[status] += 1

    def last_line(msg: str) -> str:
        # Tracebacks are multi-line; only the final line is shown.
        return msg.splitlines()[-1] if msg else ''

    if args.jobs <= 1:
        for jp in jsons:
            print(f"{jp.name}", flush=True)
            try:
                status, msg = track_one(jp, TRACKING_OUTPUT_DIR, args.max_duration, args.redo)
            except Exception as exc:
                # Same policy as the parallel path: one bad video must not
                # abort the rest of the batch.
                status, msg = "error", f"track_one raised: {exc}"
            print(f" {status}: {last_line(msg)}", flush=True)
            record(status)
    else:
        with ProcessPoolExecutor(max_workers=args.jobs) as ex:
            futs = {
                ex.submit(track_one, jp, TRACKING_OUTPUT_DIR, args.max_duration, args.redo): jp
                for jp in jsons
            }
            for fut in as_completed(futs):
                jp = futs[fut]
                try:
                    status, msg = fut.result()
                except Exception as e:
                    status, msg = "error", f"future raised: {e}"
                # Status and message are separated by ": ", as in serial mode.
                print(f" {jp.name}: {status}: {last_line(msg)}", flush=True)
                record(status)

    print(f"\nDone. ok={tally['ok']} skipped={tally['skip']} errors={tally['error']}")
    sys.exit(0 if tally["error"] == 0 else 1)
if __name__ == "__main__":
    # Script entry: configure INFO-level logging as "<LEVEL> <message>",
    # then hand over to the CLI.
    logging.basicConfig(
        format="%(levelname)s %(message)s",
        level=logging.INFO,
    )
    main()