"""Plot and summarise the average distance between flies over time, by group.

Reads ``trained_distances.csv`` and ``untrained_distances.csv`` from
``DATA_PROCESSED``, drops rows whose ``distance`` is missing, plots the mean
``distance`` at each value of ``t`` for both groups, saves the figure to
``FIGURES / 'avg_distance_over_time.png'``, and prints the overall mean and
standard deviation of ``distance`` for each group.
"""
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from config import DATA_PROCESSED, FIGURES


def _mean_distance_over_time(df):
    """Return the mean ``distance`` for each ``t`` value (a Series indexed by ``t``).

    ``groupby`` sorts its keys by default, so the result is ordered by ``t``
    and can be plotted directly as a line.
    """
    return df.groupby('t')['distance'].mean()


def _print_distance_summary(title, df):
    """Print *title* followed by the mean and std of ``df['distance']`` (2 d.p.)."""
    print(title)
    print(f" Mean distance: {df['distance'].mean():.2f}")
    print(f" Std distance: {df['distance'].std():.2f}")


# Load data
trained_distances = pd.read_csv(DATA_PROCESSED / 'trained_distances.csv')
untrained_distances = pd.read_csv(DATA_PROCESSED / 'untrained_distances.csv')

# Remove NaN distances so they do not leak into the averages below
trained_clean = trained_distances.dropna(subset=['distance'])
untrained_clean = untrained_distances.dropna(subset=['distance'])

# Calculate average distance over time (one value per time point 't')
trained_avg = _mean_distance_over_time(trained_clean)
untrained_avg = _mean_distance_over_time(untrained_clean)

# Create the plot with the object-oriented API so the figure can be closed
# explicitly once it has been saved and shown.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(trained_avg.index, trained_avg.values,
        label='Trained (avg)', color='blue', linewidth=1)
ax.plot(untrained_avg.index, untrained_avg.values,
        label='Untrained (avg)', color='red', linewidth=1)
ax.set_xlabel('Time')
ax.set_ylabel('Average Distance')
ax.set_title('Average Distance Between Flies Over Time by Group')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(FIGURES / 'avg_distance_over_time.png', dpi=300, bbox_inches='tight')
plt.show()
# Under a non-interactive backend plt.show() returns immediately and the
# figure would otherwise stay registered with pyplot; release it explicitly.
plt.close(fig)

# Overall summary statistics per group (over all rows, not per time point)
_print_distance_summary("Trained flies:", trained_clean)
_print_distance_summary("\nUntrained flies:", untrained_clean)