import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter
def _flat_corr(a, b):
    """Return the Pearson correlation between the flattened arrays *a* and *b*.

    Same value as ``np.corrcoef(a.flatten(), b.flatten())[0, 1]``, except that
    a zero-variance input (e.g. a weight matrix that is still uniform) yields
    ``nan`` directly instead of triggering numpy's invalid-divide warning on
    every animation frame.
    """
    a = np.ravel(a)
    b = np.ravel(b)
    if a.std() == 0 or b.std() == 0:
        return float('nan')
    return float(np.corrcoef(a, b)[0, 1])


def _draw_signal_row(fig, gs, bold_sig, raw_ts, J_i_vals, cmap, colors):
    """Draw row 1: all BOLD traces, excitatory activity S_e, and J_i per region."""
    blue, gold, mid = colors

    # BOLD: bold_sig is indexed [time, region]; each region gets its own shade.
    ax1 = fig.add_subplot(gs[0, 0])
    n_bold_regions = bold_sig.shape[1]
    region_colors = cmap(np.linspace(0.2, 0.9, n_bold_regions))
    for trace, color in zip(bold_sig.T, region_colors):
        ax1.plot(trace, alpha=0.3, linewidth=0.8, color=color)
    ax1.set_xlabel('BOLD time point (TR)')
    ax1.set_ylabel('BOLD signal')
    ax1.set_title('BOLD Signal (all regions)')
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim(0, 2)

    # Excitatory activity with the module-level target_fic drawn as a dashed
    # reference line (presumably the FIC target for S_e -- defined elsewhere).
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.plot(raw_ts, alpha=0.5, linewidth=1, color=blue)
    ax2.axhline(target_fic, color=gold, linestyle='--', linewidth=2)
    ax2.set_xlabel('Time step')
    ax2.set_ylabel('S_e')
    ax2.set_title('Excitatory Activity (S_e)')
    ax2.set_ylim(0, 1)
    ax2.grid(True, alpha=0.3)

    # One bar per region; derive the count from the data rather than a global.
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.bar(np.arange(len(J_i_vals)), J_i_vals, alpha=0.7, color=mid)
    ax3.set_xlabel('Region')
    ax3.set_ylabel('J_i')
    ax3.set_title('Inhibitory Weights (J_i)')
    ax3.grid(True, alpha=0.3, axis='y')
    ax3.set_ylim(0, 2.2)


def _draw_fc_row(fig, gs, fc_pred_mat, fc_corr_val):
    """Draw row 2: target FC, predicted FC, and their difference."""
    ax4 = fig.add_subplot(gs[1, 0])
    im4 = ax4.imshow(fc_target, vmin=0, vmax=1.0, cmap='cividis')
    ax4.set_title('Target FC')
    fig.colorbar(im4, ax=ax4, fraction=0.046)

    ax5 = fig.add_subplot(gs[1, 1])
    im5 = ax5.imshow(fc_pred_mat, vmin=0, vmax=1.0, cmap='cividis')
    ax5.set_title(f'Predicted FC r = {fc_corr_val:.2f}')
    fig.colorbar(im5, ax=ax5, fraction=0.046)

    # Diverging colormap so over- and under-estimation are visually distinct.
    ax6 = fig.add_subplot(gs[1, 2])
    im6 = ax6.imshow(fc_pred_mat - fc_target, vmin=-0.5, vmax=0.5, cmap='RdBu_r')
    ax6.set_title('FC Difference')
    fig.colorbar(im6, ax=ax6, fraction=0.046)


def _draw_weight_row(fig, gs, wLRE_mat, wFFI_mat):
    """Draw row 3, columns 1-2: wLRE and wFFI with their correlation to target FC."""
    ax7 = fig.add_subplot(gs[2, 0])
    im7 = ax7.imshow(wLRE_mat, vmin=0.8, vmax=2.2, cmap='cividis')
    ax7.set_title(f'wLRE (r={_flat_corr(wLRE_mat, fc_target):.3f})')
    fig.colorbar(im7, ax=ax7, fraction=0.046)

    ax8 = fig.add_subplot(gs[2, 1])
    im8 = ax8.imshow(wFFI_mat, vmin=0, vmax=1.2, cmap='cividis')
    ax8.set_title(f'wFFI (r={_flat_corr(wFFI_mat, fc_target):.3f})')
    fig.colorbar(im8, ax=ax8, fraction=0.046)


def _draw_convergence(fig, gs, iteration, colors):
    """Draw row 3, column 3: full FC-correlation/RMSE curves plus a 'you are here' marker."""
    blue, gold, mid = colors
    ax9 = fig.add_subplot(gs[2, 2])
    ax9_twin = ax9.twinx()  # RMSE on its own y-axis, sharing the iteration axis

    line1 = ax9.plot(fc_correlations, linewidth=1.5, alpha=0.7, color=blue, label='FC Correlation')
    line2 = ax9_twin.plot(fc_rmse_values, linewidth=1.5, alpha=0.7, color=gold, label='FC RMSE')

    # NOTE(review): assumes fc_correlations / fc_rmse_values are indexed by
    # iteration number and cover every snapshot iteration -- confirm upstream.
    ax9.axvline(iteration, color=mid, linestyle='--', linewidth=2, alpha=0.8)
    ax9.plot(iteration, fc_correlations[iteration], 'o', color=blue, markersize=8, zorder=5)
    ax9_twin.plot(iteration, fc_rmse_values[iteration], 'o', color=gold, markersize=8, zorder=5)

    ax9.set_xlabel('Iteration')
    ax9.set_ylabel('FC Correlation', color=blue)
    ax9.tick_params(axis='y', labelcolor=blue)
    ax9_twin.set_ylabel('FC RMSE', color=gold)
    ax9_twin.tick_params(axis='y', labelcolor=gold)
    ax9.set_title('Convergence Trajectory')
    ax9.grid(True, alpha=0.3)

    # A single legend covering the lines from both y-axes.
    lines = line1 + line2
    ax9.legend(lines, [line.get_label() for line in lines], loc='center left', fontsize=8)


def create_animation_gif(output_path='ei_tuning_animation.gif', fps=2, last_frame_repeats=5):
    """Create an animated GIF showing the E/I tuning progression.

    One frame is rendered per entry of ``snapshots['iterations']``. Each frame
    is a 3x3 grid: BOLD/S_e/J_i, target/predicted/difference FC, wLRE/wFFI and
    the convergence trajectory. The final snapshot is repeated
    ``last_frame_repeats`` extra times, so the GIF pauses on the end state for
    ``last_frame_repeats / fps`` seconds.

    Relies on names defined elsewhere in the file/notebook:
    ``snapshots`` (keys ``'iterations'``, ``'bold_signal'``,
    ``'raw_timeseries'``, ``'J_i'``, ``'wLRE'``, ``'wFFI'``, ``'fc_pred'``,
    ``'fc_corr'``), ``target_fic``, ``fc_target``, ``fc_correlations``,
    ``fc_rmse_values`` and ``n_eib_steps``.

    Args:
        output_path: Destination path of the GIF.
        fps: Frames per second of the GIF; must be positive.
        last_frame_repeats: Extra copies of the last frame used as an end
            pause; must be >= 0. Defaults to 5, the previously hard-coded value.

    Returns:
        ``output_path``.

    Raises:
        ValueError: if ``fps`` is not positive, ``last_frame_repeats`` is
            negative, or ``snapshots['iterations']`` is empty.
    """
    # Validate before creating the figure so bad input never leaks a figure.
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")
    if last_frame_repeats < 0:
        raise ValueError(f"last_frame_repeats must be >= 0, got {last_frame_repeats!r}")
    n_frames = len(snapshots['iterations'])
    if n_frames == 0:
        raise ValueError("snapshots['iterations'] is empty; nothing to animate")

    # Frame-invariant palette: (blue, gold, mid) accents drawn from cividis.
    cmap = plt.cm.cividis
    colors = (cmap(0.3), cmap(0.85), cmap(0.6))

    fig = plt.figure(figsize=(10, 8), layout="tight")

    def update_frame(snapshot_idx):
        """Clear the figure and redraw all panels for snapshot *snapshot_idx*."""
        fig.clear()
        gs = fig.add_gridspec(3, 3, hspace=0.4, wspace=0.5)
        iteration = snapshots['iterations'][snapshot_idx]

        _draw_signal_row(
            fig, gs,
            snapshots['bold_signal'][snapshot_idx],
            snapshots['raw_timeseries'][snapshot_idx],
            snapshots['J_i'][snapshot_idx],
            cmap, colors,
        )
        _draw_fc_row(fig, gs, snapshots['fc_pred'][snapshot_idx], snapshots['fc_corr'][snapshot_idx])
        _draw_weight_row(fig, gs, snapshots['wLRE'][snapshot_idx], snapshots['wFFI'][snapshot_idx])
        _draw_convergence(fig, gs, iteration, colors)

        fig.suptitle(f'Iteration {iteration}/{n_eib_steps}', fontsize=16, fontweight='bold')

    # Play every snapshot once, then hold on the final one.
    frame_sequence = list(range(n_frames)) + [n_frames - 1] * last_frame_repeats

    try:
        anim = FuncAnimation(
            fig,
            update_frame,
            frames=frame_sequence,
            interval=1000 / fps,  # ms between frames for interactive playback
            repeat=True,
        )
        anim.save(output_path, writer=PillowWriter(fps=fps))
    finally:
        # Close this specific figure (not merely "the current" one), even if
        # saving fails, so repeated calls don't accumulate open figures.
        plt.close(fig)

    print(f"Animation saved to {output_path}")
    return output_path
# Render and save the tuning GIF; this can take a minute when there are many snapshots.
gif_path = create_animation_gif(output_path='ei_tuning_animation.gif', fps=2)