cupido/notebooks/getting_started/03_compare_trained_vs_naive.ipynb
Giorgio Gilestro ec56e51bf9 Add beginner tutorial notebooks for incoming students
Four guided notebooks under notebooks/getting_started/ aimed at someone
new to Python and data science. The series progresses: project orientation
→ Python/pandas crash course → exploring one tracking DB → first
trained-vs-naive comparison using load_roi_data + Mann-Whitney U.

Each notebook leans heavily on markdown explanations, includes exercises
with empty cells, and links out to canonical references (JupyterLab,
official Python tutorial, pandas 10-min guide, Wikipedia for stats
concepts).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-30 18:14:17 +01:00

{
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 03 \u00b7 Your first real analysis: trained vs naive\n",
"\n",
"In notebook 02 we explored a single database. Now we'll work with **all\n",
"of them at once**, compute a simple per-fly metric, and ask the central\n",
"question of the project:\n",
"\n",
"> **Do trained males behave differently from na\u00efve males in the testing\n",
"> session?**\n",
"\n",
"By the end you'll have:\n",
"\n",
"- loaded every (fly, session) trace into one big DataFrame using the\n",
" project's helper function;\n",
"- reduced each trace to one number per fly (the *median inter-fly\n",
" distance*);\n",
"- compared the trained group against the na\u00efve group with a histogram\n",
" and a non-parametric statistical test;\n",
"- learnt enough to start asking your own questions.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from scipy import stats\n",
"\n",
"# Tell Python where to find the project's helper modules.\n",
"PROJECT_ROOT = Path(\"..\").resolve().parent # this notebook is in notebooks/getting_started/\n",
"sys.path.insert(0, str(PROJECT_ROOT / \"scripts\"))\n",
"\n",
"from load_roi_data import load_roi_data\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading everything at once \u2014 but carefully\n",
"\n",
"`load_roi_data()` opens every tracking DB referenced by the metadata rows you\n",
"pass it and returns one big DataFrame. **It can be slow and memory-hungry**\n",
"(the full batch is ~200 million rows). Always start small.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"# Load the metadata TSV first \u2014 it's small and fast.\n",
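"# Build the path from PROJECT_ROOT so it works on any machine with the same\n",
"# repo layout; adjust it if your copy of the TSV lives somewhere else.\n",
"tsv_path = PROJECT_ROOT / \"all_video_info_merged.tsv\"\n",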
"meta = pd.read_csv(tsv_path, sep=\"\\t\")\n",
"print(f\"metadata rows: {len(meta)}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pre-filter the metadata before passing it to `load_roi_data`. We'll start\n",
"with **just one species**, and once the data are loaded we'll keep **only the\n",
"testing session**, because:\n",
"\n",
"1. mixing species is a confound (different species behave differently);\n",
"2. the question is about behaviour after training, so the testing session\n",
" is the relevant one;\n",
"3. starting small means we can iterate quickly.\n",
"\n",
"You can come back later and broaden this filter.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"# Pick one species. 'Melanogaster/CS' has the most rows (127), so it's a good default.\n",
"sub = meta[meta[\"species\"] == \"Melanogaster/CS\"].copy()\n",
"\n",
"# We're loading every session for these flies, but the loader stamps each\n",
"# row with a 'session' column so we can filter to testing afterwards.\n",
"print(f\"selected metadata rows: {len(sub)}\")\n",
"print(sub[\"male\"].value_counts())\n"
]
},
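{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the full load is too slow while you are still debugging, one option is to\n",
"run the notebook on a small **pilot** subset first, e.g. a few metadata rows\n",
"per group. The next cell sketches that idea on a made-up metadata table;\n",
"`make_pilot` and `toy_meta` are hypothetical names, not part of the project's\n",
"helpers. On the real data you might write\n",
"`data = load_roi_data(make_pilot(sub, 3))`, assuming `load_roi_data` accepts\n",
"any subset of metadata rows (it is already given a filtered one here).\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def make_pilot(meta_df, n_per_group, group_col=\"male\"):\n",
"    \"\"\"Keep only the first n_per_group rows of each group (a quick pilot subset).\"\"\"\n",
"    return meta_df.groupby(group_col).head(n_per_group)\n",
"\n",
"# Hypothetical metadata: 5 trained and 5 naive rows.\n",
"toy_meta = pd.DataFrame({\n",
"    \"species\": [\"Melanogaster/CS\"] * 10,\n",
"    \"male\": [\"trained\"] * 5 + [\"naive\"] * 5,\n",
"    \"ROI\": list(range(1, 11)),\n",
"})\n",
"\n",
"toy_pilot = make_pilot(toy_meta, 2)  # first 2 rows of each group\n",
"print(toy_pilot[\"male\"].value_counts())\n"
]
},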
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"# This will take a minute or two and use a chunk of RAM. Be patient.\n",
"data = load_roi_data(sub)\n",
"print(f\"loaded shape: {data.shape}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## What did we get?\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"data.head(3)\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"# How big is each session, in tracking samples?\n",
"data.groupby([\"session\", \"male\"]).size().unstack(fill_value=0)\n"
]
},
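{
"cell_type": "markdown",
"metadata": {},
"source": [
"If `groupby(...).size().unstack(...)` is new to you, here is what it does on a\n",
"tiny made-up table (`toy_frames` is invented for illustration). `size()` counts\n",
"the rows for each (session, male) combination, and `unstack` turns the inner\n",
"key (`male`) into columns; `fill_value=0` fills in combinations that never\n",
"occur (here, naive males in the training session).\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Made-up tracking samples: 3 training rows and 3 testing rows.\n",
"toy_frames = pd.DataFrame({\n",
"    \"session\": [\"training\"] * 3 + [\"testing\"] * 3,\n",
"    \"male\": [\"trained\"] * 4 + [\"naive\"] * 2,\n",
"})\n",
"\n",
"toy_counts = toy_frames.groupby([\"session\", \"male\"]).size().unstack(fill_value=0)\n",
"print(toy_counts)  # rows: sessions, columns: male groups\n"
]
},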
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Restrict to the testing session\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"testing = data[data[\"session\"] == \"testing\"].copy()\n",
"print(f\"testing samples: {len(testing):,}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reduce each trace to one number\n",
"\n",
"Right now each fly contributes **tens of thousands** of (t, x, y) rows.\n",
"Comparing millions of frames directly would treat every frame as an\n",
"independent observation, but consecutive frames of the same pair are\n",
"strongly correlated: the real unit of replication is the fly (treating\n",
"frames as replicates is called *pseudoreplication*). So we **collapse each\n",
"(date, machine_name, ROI) trace into a single summary number**: here, the\n",
"median distance between the two flies during testing.\n",
"\n",
"Why median rather than mean? Because tracker glitches (e.g. one fly\n",
"temporarily lost) can produce huge spikes that drag the mean around but\n",
"barely move the median.\n",
"[Why medians beat means in noisy data\n",
"(2-min read)](https://en.wikipedia.org/wiki/Median#Robustness).\n"
]
},
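{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the difference concretely, compare the two on a made-up distance trace\n",
"before and after adding a single glitch frame (all numbers are invented):\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Invented per-frame distances (px) for one pair...\n",
"toy_clean = np.array([40.0, 45.0, 50.0, 55.0, 60.0])\n",
"# ...and the same trace plus one frame where the tracker lost a fly.\n",
"toy_glitchy = np.append(toy_clean, 900.0)\n",
"\n",
"# The mean jumps from 50.0 to about 191.7 px; the median only moves from 50.0 to 52.5 px.\n",
"print(f\"clean:   mean={toy_clean.mean():.1f}  median={np.median(toy_clean):.1f}\")\n",
"print(f\"glitchy: mean={toy_glitchy.mean():.1f}  median={np.median(toy_glitchy):.1f}\")\n"
]
},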
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"# Step 1 \u2014 per-frame distance.\n",
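"# Keep only frames where exactly 2 flies were detected (so we have a real distance).\n",
"# transform(\"size\") counts the rows in each frame and copies that count to every\n",
"# row, which is much faster than calling a Python lambda once per frame.\n",
"frame_keys = [\"date\", \"machine_name\", \"ROI\", \"t\"]\n",
"n_in_frame = testing.groupby(frame_keys)[\"x\"].transform(\"size\")\n",
"two_fly = testing[n_in_frame == 2]\n",
"\n",
"# For each (date, machine_name, ROI, t) frame, compute the distance between the two rows.\n",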
"def distance_for_frame(g):\n",
" g = g.sort_values(\"id\").reset_index(drop=True)\n",
" return np.hypot(g.loc[0, \"x\"] - g.loc[1, \"x\"], g.loc[0, \"y\"] - g.loc[1, \"y\"])\n",
"\n",
"# This is the slow step. With ~3 M frames it takes a while.\n",
"per_frame = (\n",
" two_fly\n",
" .groupby([\"date\", \"machine_name\", \"ROI\", \"t\", \"male\"])\n",
" .apply(distance_for_frame)\n",
" .reset_index(name=\"distance_px\")\n",
")\n",
"print(f\"per-frame distance rows: {len(per_frame):,}\")\n"
]
},
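{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional, if Step 1 is too slow for you: the same per-frame distances can be\n",
"computed without calling a Python function for every frame. The sketch below\n",
"(one possible approach, not the project's own method) numbers the two rows of\n",
"each frame 0 and 1 in `id` order, reshapes them so each frame becomes a single\n",
"row with both flies' coordinates side by side, and applies `np.hypot` to whole\n",
"columns at once. Like Step 1, it assumes every frame has exactly two rows with\n",
"`id`, `x` and `y` columns; `per_frame_distance_vectorized` is a hypothetical\n",
"name and the toy table is made up. Before trusting it on the real data, check\n",
"that `per_frame_distance_vectorized(two_fly)` matches `per_frame` on a small\n",
"subset.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"def per_frame_distance_vectorized(two_fly_df,\n",
"                                  keys=(\"date\", \"machine_name\", \"ROI\", \"t\", \"male\")):\n",
"    \"\"\"Distance between the two flies of every frame, computed column-wise.\"\"\"\n",
"    keys = list(keys)\n",
"    df = two_fly_df.sort_values(keys + [\"id\"])\n",
"    # Label the two rows of each frame 0 and 1 (in id order).\n",
"    df = df.assign(slot=df.groupby(keys).cumcount())\n",
"    # One row per frame; columns are (x, 0), (x, 1), (y, 0), (y, 1).\n",
"    wide = df.set_index(keys + [\"slot\"])[[\"x\", \"y\"]].unstack(\"slot\")\n",
"    dist = np.hypot(wide[(\"x\", 0)] - wide[(\"x\", 1)], wide[(\"y\", 0)] - wide[(\"y\", 1)])\n",
"    return dist.rename(\"distance_px\").reset_index()\n",
"\n",
"# Made-up frames: fly 2 is 5 px from fly 1 at t=0 and 10 px away at t=1.\n",
"toy_two_fly = pd.DataFrame({\n",
"    \"date\": [\"d1\"] * 4,\n",
"    \"machine_name\": [\"m1\"] * 4,\n",
"    \"ROI\": [1] * 4,\n",
"    \"t\": [0, 0, 1, 1],\n",
"    \"male\": [\"trained\"] * 4,\n",
"    \"id\": [2, 1, 1, 2],\n",
"    \"x\": [3.0, 0.0, 0.0, 6.0],\n",
"    \"y\": [4.0, 0.0, 0.0, 8.0],\n",
"})\n",
"toy_per_frame = per_frame_distance_vectorized(toy_two_fly)\n",
"print(toy_per_frame)\n"
]
},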
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"# Step 2 \u2014 one number per (date, machine_name, ROI).\n",
"per_fly = (\n",
" per_frame\n",
" .groupby([\"date\", \"machine_name\", \"ROI\", \"male\"])[\"distance_px\"]\n",
" .median()\n",
" .reset_index(name=\"median_distance_px\")\n",
")\n",
"\n",
"# Each row is now one ROI (one male and his partner) during testing, with its median distance.\n",
"print(per_fly.shape)\n",
"per_fly.head()\n"
]
},
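{
"cell_type": "markdown",
"metadata": {},
"source": [
"Step 2 groups by `male` as well as by (date, machine_name, ROI), so it gives\n",
"one row per ROI only if every ROI carries a single `male` label. A quick check\n",
"for metadata mistakes is sketched below on a made-up table;\n",
"`rois_with_mixed_labels` is a hypothetical helper. On the real data,\n",
"`rois_with_mixed_labels(per_frame)` should come back empty.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def rois_with_mixed_labels(df, roi_keys=(\"date\", \"machine_name\", \"ROI\"), label_col=\"male\"):\n",
"    \"\"\"Return the ROIs whose rows carry more than one distinct label.\"\"\"\n",
"    n_labels = df.groupby(list(roi_keys))[label_col].nunique()\n",
"    return n_labels[n_labels > 1]\n",
"\n",
"# Made-up per-frame table in which ROI 2 has been mislabelled.\n",
"toy_pf = pd.DataFrame({\n",
"    \"date\": [\"d1\"] * 4,\n",
"    \"machine_name\": [\"m1\"] * 4,\n",
"    \"ROI\": [1, 1, 2, 2],\n",
"    \"male\": [\"trained\", \"trained\", \"naive\", \"trained\"],\n",
"    \"distance_px\": [10.0, 12.0, 30.0, 31.0],\n",
"})\n",
"toy_bad = rois_with_mixed_labels(toy_pf)\n",
"print(toy_bad)  # lists ROI 2 only\n"
]
},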
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sanity check: how many flies per group?\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"per_fly[\"male\"].value_counts()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the two counts are very different, the smaller group limits how big a\n",
"difference the test can reliably detect (its statistical power). Note the\n",
"counts down.\n",
"\n",
"## Plot the distributions\n",
"\n",
"The first thing to do with two groups is to **look at them**. Don't trust\n",
"a p-value before you've seen the histogram.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(10, 5))\n",
"\n",
"bins = np.linspace(0, per_fly[\"median_distance_px\"].max(), 40)\n",
"\n",
"for label, color in [(\"trained\", \"steelblue\"), (\"naive\", \"darkorange\")]:\n",
"    # Use a new name so we don't overwrite `sub` (the metadata subset from earlier).\n",
"    vals = per_fly[per_fly[\"male\"] == label][\"median_distance_px\"]\n",
"    ax.hist(vals, bins=bins, alpha=0.6, label=f\"{label} (n={len(vals)})\", color=color)\n",
"\n",
"ax.set_xlabel(\"median inter-fly distance during testing (px)\")\n",
"ax.set_ylabel(\"number of flies\")\n",
"ax.set_title(\"Trained vs na\u00efve \u2014 Melanogaster/CS \u2014 testing session\")\n",
"ax.legend()\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**What you might see:**\n",
"\n",
"- If the trained group's distribution is shifted to **higher** distances,\n",
"  trained males are typically keeping farther from the female, which would\n",
"  fit the idea that they learned to give up.\n",
"- If the two distributions look identical, no learning effect was\n",
" measurable with this metric \u2014 but that doesn't mean there's no effect,\n",
" just that this particular summary didn't capture it.\n",
"- A **bimodal** trained distribution (two humps) could mean some males\n",
" learned and others didn't \u2014 the \"individual differences\" story in\n",
" `docs/bimodal_hypothesis.md`.\n"
]
},
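{
"cell_type": "markdown",
"metadata": {},
"source": [
"Humps can be hard to judge by eye when there are only a few dozen flies. One\n",
"rough, informal check (a heuristic, **not** a statistical test of bimodality)\n",
"is to smooth the values with a kernel density estimate and count its local\n",
"maxima. The sketch below uses `scipy.stats.gaussian_kde` with its default\n",
"bandwidth on made-up data; `count_kde_peaks` is a hypothetical helper. The\n",
"count depends on the bandwidth, so treat it as a reason to look closer rather\n",
"than as a verdict. On the real data you could try\n",
"`count_kde_peaks(per_fly.loc[per_fly[\"male\"] == \"trained\", \"median_distance_px\"])`.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"def count_kde_peaks(values, n_grid=512):\n",
"    \"\"\"Count local maxima of a Gaussian KDE evaluated on an even grid.\"\"\"\n",
"    values = np.asarray(values, dtype=float)\n",
"    kde = stats.gaussian_kde(values)  # default bandwidth: Scott's rule\n",
"    grid = np.linspace(values.min(), values.max(), n_grid)\n",
"    density = kde(grid)\n",
"    # A grid point is a peak if it is higher than both of its neighbours.\n",
"    is_peak = (density[1:-1] > density[:-2]) & (density[1:-1] > density[2:])\n",
"    return int(is_peak.sum())\n",
"\n",
"toy_rng = np.random.default_rng(0)\n",
"# Made-up values: two well-separated groups of 50 flies each.\n",
"toy_bimodal = np.concatenate([toy_rng.normal(20, 2, 50), toy_rng.normal(60, 2, 50)])\n",
"print(count_kde_peaks(toy_bimodal))\n"
]
},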
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Add a stat test\n",
"\n",
"A formal comparison. Because group sizes are small and we don't know if\n",
"the data are normally distributed, the\n",
"[Mann-Whitney U test](https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test)\n",
"is a safer default than the classic t-test.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"trained_vals = per_fly[per_fly[\"male\"] == \"trained\"][\"median_distance_px\"]\n",
"naive_vals = per_fly[per_fly[\"male\"] == \"naive\"][\"median_distance_px\"]\n",
"\n",
"stat, pvalue = stats.mannwhitneyu(trained_vals, naive_vals, alternative=\"two-sided\")\n",
"\n",
"print(f\"trained median: {trained_vals.median():.1f} px (n={len(trained_vals)})\")\n",
"print(f\"naive median: {naive_vals.median():.1f} px (n={len(naive_vals)})\")\n",
"print(f\"Mann-Whitney U: {stat:.0f} p-value: {pvalue:.4f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**How to read this**: the p-value is the probability of seeing a\n",
"difference at least this big *if there were really no difference*. By\n",
"convention p < 0.05 is \"interesting\", p < 0.01 is \"fairly convincing\".\n",
"But never trust a p-value without:\n",
"\n",
"1. eyeballing the histogram first (you did);\n",
"2. reporting the **effect size**, not just the p-value (e.g. the\n",
" difference of medians);\n",
"3. understanding that p-values\n",
" [say nothing about practical importance](https://www.nature.com/articles/d41586-019-00857-9).\n"
]
},
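{
"cell_type": "markdown",
"metadata": {},
"source": [
"For point 2, here is a sketch of two effect sizes you could report next to the\n",
"p-value, shown on made-up numbers. The first is the difference of medians (in\n",
"px). The second is the *rank-biserial correlation*, computed from the U\n",
"statistic as `r = 2U / (n1 * n2) - 1`: it runs from -1 to +1 and is positive\n",
"when the first group tends to have the larger values. This relies on\n",
"`stats.mannwhitneyu` returning U for its **first** argument, which recent SciPy\n",
"versions (1.7 and later) do. `effect_sizes` is a hypothetical helper; with the\n",
"variables above you would call `effect_sizes(trained_vals, naive_vals)`.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"def effect_sizes(a, b):\n",
"    \"\"\"Difference of medians and rank-biserial correlation for group a vs group b.\"\"\"\n",
"    a = np.asarray(a, dtype=float)\n",
"    b = np.asarray(b, dtype=float)\n",
"    u_a, _ = stats.mannwhitneyu(a, b, alternative=\"two-sided\")  # U for sample a\n",
"    diff_medians = np.median(a) - np.median(b)\n",
"    rank_biserial = 2 * u_a / (len(a) * len(b)) - 1\n",
"    return diff_medians, rank_biserial\n",
"\n",
"# Made-up per-fly medians (px): every value in toy_a exceeds every value in toy_b.\n",
"toy_a = [50.0, 60.0, 70.0]\n",
"toy_b = [10.0, 20.0, 30.0]\n",
"toy_diff, toy_r = effect_sizes(toy_a, toy_b)\n",
"print(f\"difference of medians: {toy_diff:.1f} px, rank-biserial r: {toy_r:+.2f}\")\n"
]
},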
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## What's next?\n",
"\n",
"- **Pick a different metric**: instead of median distance, try fraction\n",
" of time the flies were within 50 px (a \"close-proximity\" metric), or\n",
" the maximum velocity per fly. (Velocity needs identity tracking, which\n",
" is harder \u2014 see `flies_analysis_simple.ipynb` cell 16 for an example.)\n",
"- **Look at it per species**: re-run with `species == \"Sechellia\"` and\n",
" compare. Does the effect generalize? Where is it strongest?\n",
"- **Look at the bimodality**: a kernel density plot\n",
" ([seaborn.kdeplot](https://seaborn.pydata.org/generated/seaborn.kdeplot.html))\n",
" will show humps better than a histogram.\n",
"- **Time inside the session**: maybe the difference only shows up in the\n",
" first few minutes (right after the female is introduced). Slice\n",
" `per_frame` by `t` before aggregating.\n",
"- **Consult `docs/bimodal_hypothesis.md`**: it lays out a formal plan for\n",
" testing the \"some flies learn, others don't\" hypothesis.\n",
"\n",
"When you write your own analysis, **save it as a new notebook** (don't\n",
"edit this one). Copy the setup cells, change the question, change the\n",
"plot. That's how analysis projects grow.\n"
]
},
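{
"cell_type": "markdown",
"metadata": {},
"source": [
"The close-proximity metric and the time-window idea can both start from\n",
"`per_frame`. The sketch below, on a made-up `per_frame`-like table, computes\n",
"the fraction of frames in which the pair is closer than a threshold,\n",
"optionally using only frames with `t` below a cut-off. The 50 px threshold is\n",
"the one suggested above; any cut-off you pass must be in whatever unit `t` is\n",
"stored in, which you should check in your own data (it is not assumed here).\n",
"`close_fraction` is a hypothetical helper; on the real data you might start\n",
"with `close_fraction(per_frame)` and then reuse the histogram and\n",
"Mann-Whitney cells on its `close_fraction` column.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def close_fraction(per_frame_df, threshold_px=50.0, t_max=None,\n",
"                   keys=(\"date\", \"machine_name\", \"ROI\", \"male\")):\n",
"    \"\"\"Fraction of frames per ROI with distance_px below threshold_px.\n",
"\n",
"    If t_max is given, only frames with t < t_max are used (same unit as t).\n",
"    \"\"\"\n",
"    df = per_frame_df\n",
"    if t_max is not None:\n",
"        df = df[df[\"t\"] < t_max]\n",
"    return (\n",
"        df.assign(is_close=df[\"distance_px\"] < threshold_px)\n",
"        .groupby(list(keys))[\"is_close\"]\n",
"        .mean()  # mean of True/False = fraction of close frames\n",
"        .reset_index(name=\"close_fraction\")\n",
"    )\n",
"\n",
"# Made-up per-frame distances for two ROIs, three frames each.\n",
"toy_pf2 = pd.DataFrame({\n",
"    \"date\": [\"d1\"] * 6,\n",
"    \"machine_name\": [\"m1\"] * 6,\n",
"    \"ROI\": [1, 1, 1, 2, 2, 2],\n",
"    \"male\": [\"trained\"] * 3 + [\"naive\"] * 3,\n",
"    \"t\": [0, 1, 2, 0, 1, 2],\n",
"    \"distance_px\": [10.0, 80.0, 90.0, 20.0, 30.0, 100.0],\n",
"})\n",
"toy_all = close_fraction(toy_pf2)\n",
"toy_early = close_fraction(toy_pf2, t_max=2)  # frames with t < 2 only\n",
"print(toy_all)\n",
"print(toy_early)\n"
]
},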
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A note on iteration speed\n",
"\n",
"The pipeline above is correct but **slow** because Step 1 calls a Python\n",
"function once for every (date, machine_name, ROI, t) frame. If you find\n",
"yourself re-running the same expensive computation a lot, save the\n",
"intermediate result to disk:\n",
"\n",
"```python\n",
"per_frame.to_parquet(\"per_frame_distance.parquet\")\n",
"# next time:\n",
"per_frame = pd.read_parquet(\"per_frame_distance.parquet\")\n",
"```\n",
"\n",
"`parquet` is a fast columnar format. `pip install pyarrow` if your\n",
"environment doesn't have it.\n",
"\n",
"There are also vectorized ways to compute these distances ~100\u00d7 faster\n",
"that avoid `groupby().apply()`. Don't worry about that yet \u2014 get a\n",
"correct answer first, optimize only if you find yourself waiting.\n"
]
}
]
}