"""Headless offline tracker. Reads target JSONs produced by `pick_targets.py`, builds the 6 ROIs of the HD mating arena from the L-shape reference points, runs ethoscope's `MultiFlyTracker` against the merged.mp4 file via `MovieVirtualCamera`, and writes a SQLite DB to `TRACKING_OUTPUT_DIR/_tracking.db`. Idempotent: skips videos whose tracking DB already exists (unless --redo). Usage: python track_videos.py # process all videos with target JSON python track_videos.py --redo # re-track even if DB exists python track_videos.py --jobs 4 # run up to 4 videos in parallel python track_videos.py --max-duration 1800 # cap each video at 30 min (sec) """ from __future__ import annotations import argparse import json import logging import os import sys import traceback from concurrent.futures import ProcessPoolExecutor, as_completed from pathlib import Path import numpy as np from config import ETHOSCOPE_SRC, TARGETS_DIR, TRACKING_OUTPUT_DIR # Import ethoscope from the local source tree (no pip install). sys.path.insert(0, str(ETHOSCOPE_SRC)) from tracking_geometry import HD_FG_DATA, compute_roi_polygons # noqa: E402 def build_rois_from_targets(reference_points): """Wrap the shared geometry into ethoscope `ROI` objects.""" from ethoscope.core.roi import ROI polys = compute_roi_polygons(reference_points) return [ROI(poly.reshape((1, 4, 2)), idx=i + 1) for i, poly in enumerate(polys)] def track_one(json_path: Path, output_dir: Path, max_duration: float | None, redo: bool) -> tuple[str, str]: """Track a single video. Returns (status, message). Run in subprocess. Statuses: "ok", "skip", "error". """ # Re-import inside subprocess so each worker has its own ethoscope state. import sys as _sys _sys.path.insert(0, str(ETHOSCOPE_SRC)) import cv2 from ethoscope.core.monitor import Monitor from ethoscope.hardware.input.cameras import MovieVirtualCamera from ethoscope.io.sqlite import SQLiteResultWriter from ethoscope.trackers.multi_fly_tracker import MultiFlyTracker import time as _time class BGRMovieCamera(MovieVirtualCamera): """MovieVirtualCamera that keeps BGR frames AND retries on transient read failures. Two reasons for the override: 1. MultiFlyTracker calls cv2.cvtColor(img, COLOR_BGR2GRAY) without checking whether img is already grayscale, so we must feed it 3-channel input. 2. cv2.VideoCapture.read() can return False on transient I/O hiccups (NFS contention when 8 workers pull big mp4s in parallel) without the file actually being at EOF. A naive "False -> StopIteration" handling makes the tracker silently exit mid-video and write a short, lying DB. We retry a few times and only treat persistent failures within the *interior* of the video as real EOF. """ _retry_count = 5 _retry_backoff_s = 0.25 _eof_safety_frames = 50 # near end-of-file, treat False as legitimate def _next_image(self): for attempt in range(self._retry_count): ret, frame = self.capture.read() if ret and frame is not None: return frame # BGR, untouched # If we're near the genuine end of the file, accept it. if ( self._has_end_of_file and self._frame_idx >= self._total_n_frames - self._eof_safety_frames ): return None # Otherwise, this is a suspected transient hiccup — back off # and try again. The capture is still open; cv2 will pick up # the next decoded frame. _time.sleep(self._retry_backoff_s) return None # truly persistent failure payload = json.loads(json_path.read_text()) if payload.get("unusable"): reason = payload.get("reason") or "no reason given" return "skip", f"marked unusable: {reason}" video_path = Path(payload["video_path"]) if not video_path.exists(): return "error", f"video missing: {video_path}" out_db = output_dir / f"{video_path.stem}_tracking.db" if out_db.exists() and not redo: return "skip", f"DB exists: {out_db.name}" if out_db.exists(): out_db.unlink() rois = build_rois_from_targets(payload["reference_points"]) cam_kwargs = {"use_wall_clock": False} if max_duration is not None: cam_kwargs["max_duration"] = max_duration cam = BGRMovieCamera(str(video_path), **cam_kwargs) metadata = { "machine_id": payload.get("machine_uuid", "unknown"), "machine_name": payload.get("machine_name", "unknown"), "date_time": int(payload.get("session_epoch", 0)), "frame_width": cam.width, "frame_height": cam.height, "version": "offline-tracker-1", "experimental_info": "{}", "selected_options": json.dumps({ "tracker": "MultiFlyTracker", "template": "HD_Mating_Arena_6_ROIS", "fg_data": HD_FG_DATA, "maxN": 2, }), "hardware_info": "{}", "reference_points": str([list(map(int, p)) for p in payload["reference_points"]]), "backup_filename": out_db.name, "result_writer_type": "SQLite3", "sqlite_source_path": str(out_db), } tracker_data = { "maxN": 2, "visualise": False, "fg_data": HD_FG_DATA, "adaptive_threshold": True, "min_fg_threshold": 10, "max_fg_threshold": 50, } db_credentials = {"name": str(out_db)} rw = SQLiteResultWriter( db_credentials, rois, metadata=metadata, make_dam_like_table=False, take_frame_shots=False, erase_old_db=True, ) monit = Monitor( cam, MultiFlyTracker, rois, reference_points=payload["reference_points"], data=tracker_data, ) try: with rw as result_writer: monit.run(result_writer=result_writer, drawer=None, verbose=False) except Exception: return "error", traceback.format_exc(limit=5) finally: try: cam._close() except Exception: pass if not out_db.exists(): return "error", "tracking finished but DB was not created" # Post-tracking sanity check: did we cover most of the source video? # If not (cv2 retry exhausted, codec corruption, etc.), reject the DB so # it doesn't get cached as "done" — better an explicit failure than a # silent partial write. expected_ms = (cam._total_n_frames / 25.0) * 1000.0 if max_duration is not None: expected_ms = min(expected_ms, max_duration * 1000.0) completeness_threshold = 0.90 # require ≥ 90 % of expected duration # Use MAX(t) across all ROIs — a single ROI can run dry early if its fly # stops moving, so the latest detection anywhere in the arena is the # better signal of how far the iterator actually got. import sqlite3 as _sqlite3 try: _con = _sqlite3.connect(f"file:{out_db}?mode=ro", uri=True) t_max = 0 for _i in range(1, 7): _v = _con.execute(f"SELECT MAX(t) FROM ROI_{_i}").fetchone()[0] if _v and _v > t_max: t_max = _v _con.close() except Exception: t_max = 0 if expected_ms > 0 and t_max < expected_ms * completeness_threshold: out_db.unlink() for sidecar in (str(out_db) + "-wal", str(out_db) + "-shm"): Path(sidecar).unlink(missing_ok=True) ratio = t_max / expected_ms if expected_ms else 0 return ( "error", f"short output: t_max={t_max} ms vs expected {int(expected_ms)} ms " f"({ratio*100:.0f}%); DB removed", ) return "ok", str(out_db) def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--redo", action="store_true", help="re-track even if DB exists") parser.add_argument("--jobs", type=int, default=1, help="parallel workers") parser.add_argument( "--max-duration", type=float, default=None, help="cap each video at this many seconds (default: full video)", ) parser.add_argument("--limit", type=int, default=None, help="process only first N") parser.add_argument("--video", type=str, default=None, help="track a single video (mp4 path); requires its target JSON") args = parser.parse_args() TRACKING_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) if args.video: stem = Path(args.video).stem json_path = TARGETS_DIR / f"{stem}.json" if not json_path.exists(): sys.exit(f"No target JSON for {args.video}: expected {json_path}") jsons = [json_path] else: jsons = sorted(TARGETS_DIR.glob("*.json")) if args.limit: jsons = jsons[: args.limit] if not jsons: print("No target JSONs found. Run pick_targets.py first.") return print(f"Tracking {len(jsons)} videos (jobs={args.jobs}, redo={args.redo}).") n_ok = n_skip = n_err = 0 if args.jobs <= 1: for jp in jsons: print(f" → {jp.name}", flush=True) status, msg = track_one(jp, TRACKING_OUTPUT_DIR, args.max_duration, args.redo) print(f" {status}: {msg.splitlines()[-1] if msg else ''}", flush=True) n_ok += status == "ok" n_skip += status == "skip" n_err += status == "error" else: with ProcessPoolExecutor(max_workers=args.jobs) as ex: futs = { ex.submit(track_one, jp, TRACKING_OUTPUT_DIR, args.max_duration, args.redo): jp for jp in jsons } for fut in as_completed(futs): jp = futs[fut] try: status, msg = fut.result() except Exception as e: status, msg = "error", f"future raised: {e}" print(f" {jp.name}: {status} — {msg.splitlines()[-1] if msg else ''}", flush=True) n_ok += status == "ok" n_skip += status == "skip" n_err += status == "error" print(f"\nDone. ok={n_ok} skipped={n_skip} errors={n_err}") sys.exit(0 if n_err == 0 else 1) if __name__ == "__main__": logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") main()