"""Plot and summarise inter-fly distance aligned to barrier opening time."""
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from config import DATA_METADATA, DATA_PROCESSED, FIGURES

# Half-width (ms) of the analysis window centred on barrier opening.
WINDOW_MS = 150000

# Load data
trained_distances = pd.read_csv(DATA_PROCESSED / 'trained_distances.csv')
untrained_distances = pd.read_csv(DATA_PROCESSED / 'untrained_distances.csv')
barrier_data = pd.read_csv(DATA_METADATA / '2025_07_15_barrier_opening.csv')

# Convert opening_time to milliseconds and create a machine -> opening-time mapping.
# If a machine appears more than once, the last row wins (dict(zip(...)) semantics).
barrier_data['opening_time_ms'] = barrier_data['opening_time'] * 1000
opening_times = dict(zip(barrier_data['machine'], barrier_data['opening_time_ms']))


def align_to_opening_time(df, opening_times, max_time=300000, window_ms=WINDOW_MS):
    """Align distance data to each machine's barrier opening time.

    A row is kept only if all of these hold:
    - its machine has an entry in ``opening_times``;
    - its raw time ``t`` is at most ``max_time``;
    - its aligned time ``t - opening_time`` lies within
      ``[-window_ms, window_ms]``.

    Args:
        df (pd.DataFrame): Distance data with ``machine_name`` and ``t``
            (ms) columns.
        opening_times (dict): Machine to opening time (ms) mapping.
        max_time (int): Maximum raw time ``t`` in ms to include. This filter
            is applied before alignment.
        window_ms (int): Half-width in ms of the window kept around the
            opening. Defaults to ``WINDOW_MS`` (+/-150 s).

    Returns:
        pd.DataFrame: A copy of ``df`` with a float ``aligned_time`` column,
        filtered as described above.
    """
    # Rows from machines without a recorded opening time cannot be aligned.
    # Warn instead of silently discarding them.
    missing = set(df['machine_name'].dropna().unique()) - set(opening_times)
    if missing:
        warnings.warn(
            'No barrier opening time for machine(s) '
            f'{sorted(map(str, missing))}; their rows are excluded.'
        )

    # Vectorised per-row lookup; machines absent from the mapping become NaN.
    opening = df['machine_name'].map(opening_times)
    aligned_time = (df['t'] - opening).where(df['t'] <= max_time).astype(float)

    df_aligned = df.assign(aligned_time=aligned_time)
    df_aligned = df_aligned.dropna(subset=['aligned_time'])
    in_window = df_aligned['aligned_time'].between(-window_ms, window_ms)
    return df_aligned[in_window]


# Align the data
trained_aligned = align_to_opening_time(trained_distances, opening_times)
untrained_aligned = align_to_opening_time(untrained_distances, opening_times)

# Calculate average distance over aligned time.
# NOTE(review): this pools flies only where aligned timestamps coincide exactly.
# Each machine is shifted by its own (possibly fractional) opening time, so
# samples from different machines may not share aligned_time values. If so,
# consider binning aligned_time before grouping -- confirm against sampling rate.
trained_avg = trained_aligned.groupby('aligned_time')['distance'].mean()
untrained_avg = untrained_aligned.groupby('aligned_time')['distance'].mean()

# Apply smoothing (centred rolling mean over `window_size` consecutive samples;
# the first/last few points are NaN because min_periods defaults to the window).
window_size = 50
trained_smooth = trained_avg.rolling(window=window_size, center=True).mean()
untrained_smooth = untrained_avg.rolling(window=window_size, center=True).mean()

# Create the plot (x axis converted from ms to seconds)
plt.figure(figsize=(12, 6))
plt.plot(trained_smooth.index / 1000, trained_smooth.values,
         label='Trained (smoothed)', color='blue', linewidth=2)
plt.plot(untrained_smooth.index / 1000, untrained_smooth.values,
         label='Untrained (smoothed)', color='red', linewidth=2)
plt.axvline(x=0, color='black', linestyle='--', alpha=0.7, label='Barrier Opening')
plt.xlabel('Time (seconds relative to barrier opening)')
plt.ylabel('Average Distance')
plt.title('Average Distance Between Flies Aligned to Barrier Opening Time')
plt.legend()
plt.grid(True, alpha=0.3)
plt.xlim(-WINDOW_MS / 1000, WINDOW_MS / 1000)
plt.tight_layout()
plt.savefig(FIGURES / 'avg_distance_aligned_to_opening.png', dpi=300, bbox_inches='tight')
plt.show()

# Print statistics
print("Trained flies (aligned to barrier opening):")
print(f"  Data points: {len(trained_aligned)}")
print(f"  Mean distance: {trained_aligned['distance'].mean():.2f}")
print(f"  Std distance: {trained_aligned['distance'].std():.2f}")

print("\nUntrained flies (aligned to barrier opening):")
print(f"  Data points: {len(untrained_aligned)}")
print(f"  Mean distance: {untrained_aligned['distance'].mean():.2f}")
print(f"  Std distance: {untrained_aligned['distance'].std():.2f}")

# Pre/post analysis (samples exactly at aligned_time == 0 fall in neither group)
trained_pre = trained_aligned[trained_aligned['aligned_time'] < 0]
trained_post = trained_aligned[trained_aligned['aligned_time'] > 0]
untrained_pre = untrained_aligned[untrained_aligned['aligned_time'] < 0]
untrained_post = untrained_aligned[untrained_aligned['aligned_time'] > 0]

print("\nPre-opening period (t < 0):")
print(f"  Trained mean distance: {trained_pre['distance'].mean():.2f}")
print(f"  Untrained mean distance: {untrained_pre['distance'].mean():.2f}")

print("\nPost-opening period (t > 0):")
print(f"  Trained mean distance: {trained_post['distance'].mean():.2f}")
print(f"  Untrained mean distance: {untrained_post['distance'].mean():.2f}")