---
title: "Reduced Wong Wang BOLD FC Optimization"
format:
  html:
    code-fold: false
    toc: true
    echo: false
jupyter: python3
---
```{python}
#| output: false
#| code-fold: true
#| code-summary: "Imports"
#| echo: true
# Set up environment
import os
import time
# Mock devices to force JAX to parallelize on CPU.
# The flag has to be in the environment before JAX initializes its backends,
# which is why it is set ahead of the `import jax` further down.
cpu = True
if cpu:
    N = 8
    # The flag must be one `name=value` token: whitespace around the value
    # would split it into separate arguments and the device count would not
    # be applied.
    os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={N}"
# Import all required libraries
from scipy import io
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import copy
import optax
# Import from tvboptim
from tvboptim import jaxify
from tvboptim.types import Parameter, GridSpace
from tvboptim.types.stateutils import show_free_parameters
from tvboptim.utils import set_cache_path, cache
from tvboptim import observation as obs
from tvboptim.execution import ParallelExecution, SequentialExecution
from tvboptim.optim.optax import OptaxOptimizer
from tvboptim.optim.callbacks import MultiCallback, DefaultPrintCallback, SavingCallback
# Import from tvbo
from tvbo.export.experiment import SimulationExperiment
from tvbo.datamodel import tvbo_datamodel
from tvbo.utils import numbered_print
# Set cache path for tvboptim.
# NOTE(review): presumably this is the directory the `@cache`-decorated steps
# further down (exploration, optimization) read from and write to - confirm.
set_cache_path("./example_cache_rww" )
```
## Create a TVB-O Simulation Experiment
```{python}
#| echo: true
#| output: false
# Each part of the experiment is spelled out separately and then assembled.

# Local dynamics and their starting parameter values.
model_spec = {
    "name": "ReducedWongWang",
    "parameters": {
        "w": {"name": "w", "value": 0.5},
        "I_o": {"name": "I_o", "value": 0.32},
    },
}

# Desikan-Killiany parcellation; the conduction speed is set to infinity.
connectivity_spec = {
    "parcellation": {"atlas": {"name": "DesikanKilliany"}},
    "conduction_speed": {"name": "cs", "value": np.array([np.inf])},
}

# Linear coupling with a single parameter `a`.
coupling_spec = {
    "name": "Linear",
    "parameters": {"a": {"name": "a", "value": 0.5}},
}

# Heun integration with noise; duration is 2 * 60_000
# (presumably milliseconds, i.e. two minutes - TODO confirm units).
integration_spec = {
    "method": "Heun",
    "step_size": 4.0,
    "noise": {"parameters": {"sigma": {"value": 0.00283}}},
    "duration": 2 * 60_000,
}

# Two monitors, in this order: raw output first, BOLD second.
monitor_spec = {
    "Raw": {"name": "Raw"},
    "Bold": {"name": "Bold", "period": 1000.0},
}

experiment = SimulationExperiment(
    model=model_spec,
    connectivity=connectivity_spec,
    coupling=coupling_spec,
    integration=integration_spec,
    monitors=monitor_spec,
)
```
```{python}
experiment.configure()

# Both .mat files are read with identical loader options.
mat_options = dict(squeeze_me=True, mat_dtype=False, chars_as_strings=True)
data = io.loadmat("../data/avgSC_DK.mat", **mat_options)
fcdata = io.loadmat("../data/avgFC_DK.mat", **mat_options)

# Empirical functional connectivity: the fitting target used by the loss.
fc_target = fcdata["avgFC"]

# Structural connectivity (weights and tract lengths) as float64 arrays.
weights = np.array(data["SC_avg_weights"], dtype=np.float64)
lengths = np.array(data["SC_avg_dists"], dtype=np.float64)

# Replace the experiment's connectome with the loaded matrices, then normalize.
connectome = experiment.connectivity
connectome.lengths = lengths
connectome.weights = weights
connectome.metadata.number_of_regions = weights.shape[1]
connectome.normalize_weights()
# Alternative kept for reference:
# tvbo_datamodel.Connectome(weights = weights, lengths = lengths, conduction_speed = np.array([np.inf]))
```
```{python}
# Side-by-side view of the connectome: weights on the left, tract lengths on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), sharey=True)

# Weights: color scale capped at 0.5, so the colorbar is extended at the top.
im1 = ax1.imshow(experiment.connectivity.weights, cmap="cividis", vmax=.5)
ax1.set_title("Structural Connectivity")
ax1.set_xlabel("Region")
ax1.set_ylabel("Region")
cbar1 = fig.colorbar(
    im1,
    ax=ax1,
    shrink=0.74,
    label="Connection Strength [a.u.]",
    extend="max",
)

# Tract lengths: the y-axis is shared, so only the x label is set here.
im2 = ax2.imshow(experiment.connectivity.lengths, cmap="cividis")
ax2.set_title("Tract Lengths")
ax2.set_xlabel("Region")
cbar2 = fig.colorbar(im2, ax=ax2, shrink=0.74, label="Tract Length [mm]")

fig.set_dpi(200)
```
## Model Functions
```{python}
#| echo: true
# (Cell options must be the first lines of the cell; a regular comment above
# them stops Quarto from treating `#|` lines as options.)
# Get model and state - scalar_pre is used to improve performance as we have no delay
model, state = jaxify(experiment, scalar_pre=True)
```
::: {.callout-note collapse="true"}
## Rendered JAX Code
```{python}
#| echo: true
# Render the experiment as JAX source (same scalar_pre setting as the jaxify
# call) and hand it to numbered_print for display.
jax_source = experiment.render_code(format="jax", scalar_pre=True)
numbered_print(jax_source)
```
:::
## Run Initial Simulation - Update Initial Conditions
```{python}
#| echo: true
# First pass: simulate from the state returned by jaxify.
result = model(state)

# Reuse the first output of that run as the starting point of the next one.
first_output = result[0]
state.initial_conditions = first_output
# Seed the "stock" of the second monitor (BOLD) with the last 5000 samples.
state.monitor_parameters[1]["stock"] = first_output.data[-5000:]

# Second pass, started from the updated state.
result2 = model(state)
```
```{python}
from matplotlib.colors import Normalize
def _plot_traces_by_mean(ax, t, traces, label, linewidth):
    """Draw every column of `traces` over `t`, colored by the column's mean.

    Colors come from the cividis colormap, normalized between the smallest
    and largest column mean. `label` is placed in a white box in the
    upper-right corner of the axes instead of a regular title.
    """
    column_means = np.mean(traces, axis=0)
    scale = Normalize(vmin=np.min(column_means), vmax=np.max(column_means))
    palette = plt.cm.cividis
    for column in range(traces.shape[1]):
        ax.plot(
            t,
            traces[:, column],
            color=palette(scale(column_means[column])),
            linewidth=linewidth,
        )
    ax.text(
        0.95, 0.95, label,
        transform=ax.transAxes, fontsize=12, ha="right", va="top",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
    )
    ax.set_xlabel("Time [ms]")


# Time series of the second simulation: raw output (left) and BOLD (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 2))

# Left panel: first 1000 samples of the first monitor.
t_max = 1000
raw_output = result2[0]
_plot_traces_by_mean(
    ax1,
    raw_output.time[0:t_max],
    raw_output.data[0:t_max, 0, :, 0],
    "Raw",
    0.5,
)
ax1.set_ylabel("S [a.u.]")

# Right panel: first 60 samples of the second monitor (no y label).
bold_output = result2[1]
_plot_traces_by_mean(
    ax2,
    bold_output.time[0:60],
    bold_output.data[0:60, 0, :, 0],
    "Bold",
    0.8,
)

fig.set_dpi(500)
plt.show()
```
## Define Observations and Loss
```{python}
#| echo: true
def observation(state):
    """Simulate `state` and return the functional connectivity of its BOLD output.

    The BOLD series is the second element of the model result; it is passed
    to `obs.fc` with `skip_t=20`.
    NOTE(review): `skip_t` presumably drops the first 20 samples as a
    transient - confirm against `obs.fc`.
    """
    return obs.fc(model(state)[1], skip_t=20)
def loss(state):
    """Return `obs.rmse` between the simulated FC for `state` and `fc_target`."""
    # Alternative objective kept for reference:
    # return 1 - obs.fc_corr(fc, fc_target)
    return obs.rmse(observation(state), fc_target)
```
```{python}
# FC of the second simulation, computed with the same settings as `observation`.
fc_ = np.array(obs.fc(result2[1], skip_t=20))
observed_title = "Observed FC \n r = {:.3f} ".format(obs.fc_corr(fc_, fc_target))

# Target on the left, simulated FC on the right.
fig, (target_ax, observed_ax) = plt.subplots(1, 2, figsize=(6, 2))
panels = (
    (target_ax, fc_target, "Target FC"),
    (observed_ax, fc_, observed_title),
)

title_box = dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.9)
for panel_ax, matrix, title in panels:
    # Work on a copy so that blanking the diagonal leaves the data untouched.
    shown = np.copy(matrix)
    np.fill_diagonal(shown, np.nan)
    im = panel_ax.imshow(shown, cmap='cividis', vmax=0.9)

    # No ticks or axis labels.
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])
    panel_ax.set_xlabel('')
    panel_ax.set_ylabel('')

    # The title is drawn as a boxed annotation inside the axes.
    panel_ax.annotate(
        title,
        xy=(0.3, 0.95),
        xycoords='axes fraction',
        ha='left',
        va='top',
        fontsize=10,
        fontweight='bold',
        color='black',
        bbox=title_box,
    )

plt.tight_layout()
fig.set_dpi(300)
```
```{python}
#| output: false
#
import matplotlib.transforms as transforms
import matplotlib.colors as mcolors

# Standalone figure of the target FC showing only the strict lower triangle.
fig, ax = plt.subplots(figsize=(3, 3))

# Boolean mask selecting everything strictly above the diagonal.
mask = np.triu(np.ones(fc_target.shape, dtype=bool), k=1)

# Blank the upper triangle and the diagonal on a copy of the target.
fc_matrix = np.copy(fc_target)
fc_matrix[mask] = np.nan
np.fill_diagonal(fc_matrix, np.nan)

# Three-stop custom colormap: purple -> cyan -> blue.
jax_colors = [
    '#9C27B0',  # Purple/Violet
    '#00BCD4',  # Cyan/Teal (green-blue)
    '#2196F3',  # Blue
]
jax_cmap = mcolors.LinearSegmentedColormap.from_list('jax', jax_colors, N=256)

im = ax.imshow(fc_matrix, cmap=jax_cmap, vmax=0.9)

# No ticks or axis labels.
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('')
ax.set_ylabel('')

# Matrix size, kept for positioning text in the (empty) upper triangle.
n = fc_target.shape[0]

# Hide the frame around the image.
for frame_edge in ax.spines.values():
    frame_edge.set_visible(False)

fig.set_dpi(300)
plt.tight_layout()
```
## Parameter Exploration
```{python}
#| echo: true
#| output: false
# Free the two explored parameters and give both the same bounds.
for explored in (state.parameters.model.w, state.parameters.coupling.a):
    explored.free = True
    explored.low = 0.001
    explored.high = 0.7
show_free_parameters(state)

# Grid resolution handed to GridSpace (a coarser n = 32 was used before).
n = 64
params_set = GridSpace(state, n=n)
@cache("explore", redo=False)
def explore():
    """Evaluate `loss` on every point of `params_set` and return the results.

    Runs through `ParallelExecution` with `n_pmap=8`;
    `SequentialExecution(loss, params_set)` is the sequential alternative.
    The result is cached under the key "explore".
    """
    executor = ParallelExecution(loss, params_set, n_pmap=8)
    return executor.run()


results = explore()
```
```{python}
# Parameter values of all grid points.
pc = params_set.collect()
a = pc.parameters.coupling.a.value
b = pc.parameters.model.w.value

# Image extent: coupling `a` along x, model `w` along y.
a_min, a_max = min(a), max(a)
b_min, b_max = min(b), max(b)
grid_extent = [a_min, a_max, b_min, b_max]

# Losses arranged as an n x n image.
# NOTE(review): the transpose assumes the grid enumerates `w` fastest - confirm.
loss_grid = jnp.stack(results).reshape(n, n).T

fig, ax = plt.subplots(figsize=(5, 2.5))
im = ax.imshow(
    loss_grid,
    cmap='cividis_r',
    extent=grid_extent,
    origin='lower',
    aspect='auto',
    interpolation='none',  # 'hamming' is a smoother alternative
)

cbar = plt.colorbar(im, label="Loss")
ax.set_xlabel('G')
ax.set_ylabel('w')

fig.set_dpi(300)
plt.tight_layout()
```
```{python}
# opt = OptaxOptimizer(loss, optax.adagrad(0.01))
# optimized_state = opt.run(params, max_steps=100)
```
## Run Optimization
```{python}
#| echo: true
#| output: false
# Callbacks used during optimization:
# - print progress every 10 steps,
# - store the second callback argument (the updated state) under "state".
progress_printer = DefaultPrintCallback(every=10)
state_recorder = SavingCallback(key="state", save_fun=lambda *args: args[1])
cb = MultiCallback([progress_printer, state_recorder])
@cache("optimize", redo=False)
def optimize():
    """Fit the free parameters of `state` with Adam and return the outcome.

    Uses `optax.adam(0.01, b2=0.9999)` for at most 400 steps with the shared
    callback `cb`. Returns `(fitted_state, fitting_data)`; the result is
    cached under the key "optimize".
    """
    adam = optax.adam(0.01, b2=0.9999)
    optimizer = OptaxOptimizer(loss, adam, callback=cb)
    final_state, history = optimizer.run(state, max_steps=400)
    return final_state, history


fitted_state, fitting_data = optimize()
```
```{python}
import matplotlib.patheffects as path_effects
# Parameter values of all grid points and the resulting image extent.
pc = params_set.collect()
a = pc.parameters.coupling.a.value
b = pc.parameters.model.w.value
a_min, a_max = min(a), max(a)
b_min, b_max = min(b), max(b)
b_span = b_max - b_min

# Loss landscape from the exploration as background.
fig, ax = plt.subplots(figsize=(5, 2.5))
loss_grid = jnp.stack(results).reshape(n, n).T
im = ax.imshow(
    loss_grid,
    cmap='cividis_r',
    extent=[a_min, a_max, b_min, b_max],
    origin='lower',
    aspect='auto',
    interpolation='none',  # 'hamming' is a smoother alternative
)

# Start and end point of the optimization.
a_init = state.parameters.coupling.a.value
b_init = state.parameters.model.w.value
a_fit = fitted_state.parameters.coupling.a.value
b_fit = fitted_state.parameters.model.w.value

# Each marker gets a label shifted vertically by a fraction of the w range:
# the initial value is labelled above, the optimized value below.
markers = (
    ('Initial Value', a_init, b_init, 0.05),
    ('Optimized Value', a_fit, b_fit, -0.1),
)
for text, x, y, shift in markers:
    ax.scatter(x, y, color='white', s=100, marker='o',
               edgecolors='k', linewidths=1, zorder=5)
    ax.annotate(
        text,
        xy=(x, y),
        xytext=(x, y + shift * b_span),
        color='white',
        fontweight='bold',
        ha='center',
        zorder=5,
        path_effects=[path_effects.withStroke(linewidth=3, foreground='black')],
    )

# Intermediate states recorded by the saving callback; every second one is drawn.
saved_states = fitting_data["state"].save
a_route = np.array([ds.parameters.coupling.a.value for ds in saved_states])
b_route = np.array([ds.parameters.model.w.value for ds in saved_states])
ax.scatter(a_route[::2], b_route[::2], color='white', s=15, marker='o',
           linewidths=1, zorder=4, edgecolors='k')

# No ticks or axis labels.
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('')
ax.set_ylabel('')

plt.tight_layout()
fig.set_dpi(500)
```
## Refine Optimization by setting Regional Parameters
```{python}
#| echo: true
# Copy already optimized state and turn parameters regional
_fitted_state = copy.deepcopy(fitted_state)

# One value per region. The region count is taken from the connectivity
# matrix (the same expression that set `number_of_regions` when the data was
# loaded) instead of being hard-coded, so other parcellations work unchanged.
n_regions = weights.shape[1]
regional_shape = (n_regions, 1)

# `w` and `I_o` become free, region-specific parameters, starting from the
# values found by the global fit.
_fitted_state.parameters.model.w.free = True
_fitted_state.parameters.model.w.value = jnp.broadcast_to(
    _fitted_state.parameters.model.w.value, regional_shape
)
_fitted_state.parameters.model.I_o.free = True
_fitted_state.parameters.model.I_o.value = jnp.broadcast_to(
    _fitted_state.parameters.model.I_o.value, regional_shape
)

# The global coupling strength is held fixed during the refinement.
_fitted_state.parameters.coupling.a.free = False
```
```{python}
#| echo: true
@cache("optimize_het", redo=False)
def optimize():
    """Refine the regional parameters of `_fitted_state` with Adam.

    Uses `optax.adam(0.004, b2=0.999)` for at most 200 steps with the shared
    callback `cb`. Returns `(fitted_state, fitting_data)`; the result is
    cached under the key "optimize_het".
    """
    adam = optax.adam(0.004, b2=0.999)
    optimizer = OptaxOptimizer(loss, adam, callback=cb)
    final_state, history = optimizer.run(_fitted_state, max_steps=200)
    return final_state, history


fitted_state_het, fitting_data_het = optimize()
```
```{python}
#| output: false
# plt.imshow(observation(fitted_state_het))
# Per-row count of connections with weight above 0.001.
is_connected = fitted_state_het.connectivity.weights.value > 0.001
degree = np.sum(is_connected, axis=1)
# plt.scatter(degree, fitted_state_het.parameters.model.w)

# Fitted regional values: I_o on the x-axis against w on the y-axis.
regional_model = fitted_state_het.parameters.model
plt.scatter(regional_model.I_o.value, regional_model.w.value)
```
```{python}
# optimized_state.parameters.model.w.broadcast_to((84,))
# opt = OptaxOptimizer(loss, optax.adam(0.0002))
# optimized_state_regional = opt.run(optimized_state, max_steps=200)
```
```{python}
#| output: false
# Set `nt` to 75000 on both fitted states before evaluating them.
for final_state in (fitted_state, fitted_state_het):
    final_state.nt = 75000

# FC of the global fit first, then of the regional fit.
fc_hom = np.array(observation(fitted_state))
fc_het = np.array(observation(fitted_state_het))
```
```{python}
#| output: true
# Simulated FC after the global fit (left) and after the regional fit (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 2.5))
panels = (
    (ax1, fc_hom, "Global Parameters"),
    (ax2, fc_het, "Regional Parameters"),
)

title_box = dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.9)
for panel_ax, simulated_fc, title_prefix in panels:
    # Correlation with the target, computed on the unmodified matrix.
    corr_value = obs.fc_corr(simulated_fc, fc_target)

    # Blank the diagonal on a copy before drawing.
    shown = np.copy(simulated_fc)
    np.fill_diagonal(shown, np.nan)
    im = panel_ax.imshow(shown, cmap='cividis', vmax=1.0)

    # No ticks or axis labels.
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])
    panel_ax.set_xlabel('')
    panel_ax.set_ylabel('')

    # The title, including the correlation, is a boxed annotation inside the axes.
    title = f" { title_prefix} \n r = { corr_value:.3f} "
    panel_ax.annotate(
        title,
        xy=(0.23, 0.95),
        xycoords='axes fraction',
        ha='left',
        va='top',
        fontsize=10,
        fontweight='bold',
        color='black',
        bbox=title_box,
    )

fig.set_dpi(300)
plt.tight_layout()
```