"""Plot the average inter-fly distance over the first 200 s, trained vs. untrained.

Reads ``trained_distances.csv`` and ``untrained_distances.csv`` from
``DATA_PROCESSED``, drops rows with a missing ``distance``, keeps rows with
``t`` inside the time limit, averages ``distance`` per timestamp ``t``,
smooths each curve with a centred rolling mean, saves the figure under
``FIGURES`` and prints mean/std of the filtered distances.
"""
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from config import DATA_PROCESSED, FIGURES

# ``t`` is treated as milliseconds: the original cut-off 200000 corresponds to
# "first 200 seconds" and the x-axis divides by 1000 to label seconds.
MS_PER_SECOND = 1000
TIME_LIMIT_S = 200
TIME_LIMIT_MS = TIME_LIMIT_S * MS_PER_SECOND

# Rolling window length, counted in samples (unique timestamps), not seconds.
window_size = 50


def _first_seconds(clean, time_limit_ms=TIME_LIMIT_MS):
    """Return the rows of ``clean`` whose ``t`` is <= ``time_limit_ms``."""
    return clean[clean['t'] <= time_limit_ms]


def _smooth(avg, window=window_size):
    """Return a centred rolling mean of ``avg``.

    ``min_periods=1`` keeps the edges of the curve. With the pandas default
    (``min_periods=window``), the first and last ~window/2 points become NaN
    and silently disappear from the plot. A series shorter than ``window``
    would be entirely NaN.
    """
    return avg.rolling(window=window, center=True, min_periods=1).mean()


def _print_summary(label, filtered, prefix=""):
    """Print mean and std of ``filtered['distance']`` under a ``label`` header."""
    print(f"{prefix}{label} flies (first {TIME_LIMIT_S} seconds):")
    print(f"  Mean distance: {filtered['distance'].mean():.2f}")
    print(f"  Std distance: {filtered['distance'].std():.2f}")


# Load data
trained_distances = pd.read_csv(DATA_PROCESSED / 'trained_distances.csv')
untrained_distances = pd.read_csv(DATA_PROCESSED / 'untrained_distances.csv')

# Remove NaN distances and restrict to the time window
trained_clean = trained_distances.dropna(subset=['distance'])
untrained_clean = untrained_distances.dropna(subset=['distance'])
trained_filtered = _first_seconds(trained_clean)
untrained_filtered = _first_seconds(untrained_clean)

# Average distance per timestamp
trained_avg = trained_filtered.groupby('t')['distance'].mean()
untrained_avg = untrained_filtered.groupby('t')['distance'].mean()

# Smooth each curve
trained_smooth = _smooth(trained_avg)
untrained_smooth = _smooth(untrained_avg)

# Plot, converting the ms index to seconds for the x-axis
plt.figure(figsize=(12, 6))
plt.plot(trained_smooth.index / MS_PER_SECOND, trained_smooth.values,
         label='Trained (smoothed)', color='blue', linewidth=2)
plt.plot(untrained_smooth.index / MS_PER_SECOND, untrained_smooth.values,
         label='Untrained (smoothed)', color='red', linewidth=2)
plt.xlabel('Time (seconds)')
plt.ylabel('Average Distance')
plt.title(f'Average Distance Between Flies Over Time (First {TIME_LIMIT_S} Seconds)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(FIGURES / f'avg_distance_over_time_first_{TIME_LIMIT_S}s.png',
            dpi=300, bbox_inches='tight')
plt.show()

_print_summary("Trained", trained_filtered)
_print_summary("Untrained", untrained_filtered, prefix="\n")