"""Scatter-plot the distance between flies over time for two groups.

Reads the trained and untrained distance tables, drops rows whose
``distance`` is NaN, plots a random subsample of each group against the
``t`` column, saves the figure, and prints the mean and standard deviation
of ``distance`` for each group (computed on all non-NaN rows, not on the
plotted subsample).
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from config import DATA_PROCESSED, FIGURES

# Upper bound on points drawn per group, to avoid overcrowding the plot.
MAX_POINTS_PER_GROUP = 1000
# Fixed seed so every run draws the same subsample.
SAMPLE_SEED = 42
# The scatter markers use s=1, which makes the legend handles nearly
# invisible; enlarge them in the legend only (the data points are unchanged).
LEGEND_MARKER_SCALE = 6


def _subsample(frame, max_rows=MAX_POINTS_PER_GROUP, seed=SAMPLE_SEED):
    """Return ``frame`` as-is if it has at most ``max_rows`` rows.

    Otherwise return a random sample of exactly ``max_rows`` rows drawn
    with ``random_state=seed``.
    """
    if len(frame) > max_rows:
        return frame.sample(max_rows, random_state=seed)
    return frame


# Load data
trained_distances = pd.read_csv(DATA_PROCESSED / 'trained_distances.csv')
untrained_distances = pd.read_csv(DATA_PROCESSED / 'untrained_distances.csv')

# Remove NaN distances
trained_clean = trained_distances.dropna(subset=['distance'])
untrained_clean = untrained_distances.dropna(subset=['distance'])

# Sample a bounded number of points from each group for plotting
trained_sample = _subsample(trained_clean)
untrained_sample = _subsample(untrained_clean)

# Create the plot
plt.figure(figsize=(12, 6))

for group_label, group_sample, group_color in (
    ('Trained', trained_sample, 'blue'),
    ('Untrained', untrained_sample, 'red'),
):
    plt.scatter(
        group_sample['t'],
        group_sample['distance'],
        alpha=0.5,
        s=1,
        label=group_label,
        color=group_color,
    )

plt.xlabel('Time')
plt.ylabel('Distance')
plt.title('Distance Between Flies Over Time')
plt.legend(markerscale=LEGEND_MARKER_SCALE)
plt.grid(True, alpha=0.3)
plt.tight_layout()

# Save before show(): the original script does the same, and the file must
# be written while the figure is still the current one.
plt.savefig(FIGURES / 'distance_over_time.png', dpi=300, bbox_inches='tight')
plt.show()

# Summary statistics use the full cleaned tables, not the plotted subsample.
for heading, group_clean in (
    ("Trained flies:", trained_clean),
    ("\nUntrained flies:", untrained_clean),
):
    group_distance = group_clean['distance']
    print(heading)
    print(f" Mean distance: {group_distance.mean():.2f}")
    print(f" Std distance: {group_distance.std():.2f}")