SparseGraph

experimental.network_dynamics.graph.SparseGraph(
    weights,
    region_labels=None,
    threshold=0.0,
)

Sparse graph representation using JAX BCOO format.

Stores only non-zero weights for memory efficiency. Suitable for large networks with sparse connectivity (as a rule of thumb, below about 30% density).
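A back-of-the-envelope memory estimate shows why storing only non-zero weights helps, and roughly where a density cut-off like the one above comes from. This sketch assumes float32 weights and int32 indices, with one (row, col) index pair per stored entry. The dtypes SparseGraph actually uses may differ.

```python
def dense_bytes(n_nodes: int, value_bytes: int = 4) -> int:
    """Bytes for a full n_nodes x n_nodes weight matrix."""
    return n_nodes * n_nodes * value_bytes


def coo_bytes(nnz: int, value_bytes: int = 4, index_bytes: int = 4) -> int:
    """Bytes for COO storage: one value plus a (row, col) index pair per entry."""
    return nnz * (value_bytes + 2 * index_bytes)


def break_even_density(value_bytes: int = 4, index_bytes: int = 4) -> float:
    """Density above which COO storage uses more memory than a dense matrix."""
    return value_bytes / (value_bytes + 2 * index_bytes)


n = 10_000
nnz = int(0.05 * n * n)  # a 5%-dense network
print(dense_bytes(n) / 1e6, "MB dense")  # 400.0 MB dense
print(coo_bytes(nnz) / 1e6, "MB COO")    # 60.0 MB COO
print(round(break_even_density(), 3))    # 0.333
```

Under these assumptions COO only wins below about one third density, which is why sparse storage pays off most for networks well under that.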

Args:
    weights: Sparse weight matrix (BCOO) or dense array (will be sparsified).
    region_labels: Optional sequence of region labels (list, tuple, or array).
        If None, defaults to ['Region_0', 'Region_1', …].
    threshold: Values with absolute value below this are treated as zero.

Example:
    >>> import jax.numpy as jnp
    >>>
    >>> # From dense
    >>> dense_weights = jnp.array([[0, 0.5, 0], [0.3, 0, 0], [0, 0.2, 0]])
    >>> graph = SparseGraph(dense_weights)
    >>>
    >>> # From COO format
    >>> data = jnp.array([0.5, 0.3, 0.2])
    >>> row = jnp.array([0, 1, 2])
    >>> col = jnp.array([1, 0, 1])
    >>> graph = SparseGraph.from_coo(data, row, col, shape=(3, 3))
    >>>
    >>> # From dense graph
    >>> from network_dynamics.graph.base import DenseGraph
    >>> dense_graph = DenseGraph(dense_weights)
    >>> sparse_graph = SparseGraph.from_dense(dense_graph, threshold=1e-10)

Attributes

Name           Description
n_nodes        Number of nodes in the network.
nnz            Number of non-zero elements in the weight matrix.
region_labels  Labels for each node/region in the network.
sparsity       Fraction of non-zero connections (excluding the diagonal).
symmetric      Whether the graph is symmetric (undirected).
weights        Sparse weight matrix in BCOO format.
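The derived attributes nnz, sparsity and symmetric can all be computed from COO triples. A plain-Python sketch, assuming sparsity is the number of stored off-diagonal entries divided by the n_nodes * (n_nodes - 1) possible off-diagonal slots, and that symmetric compares each weight with its transposed partner. SparseGraph itself works on the BCOO arrays and may count explicit zeros or duplicates differently.

```python
import math


def coo_stats(data, row, col, n_nodes, atol=1e-8):
    """Return (nnz, off_diagonal_density, is_symmetric) for COO triples.

    Assumes no duplicate (row, col) pairs and no explicitly stored zeros.
    """
    entries = {(r, c): w for w, r, c in zip(data, row, col)}
    nnz = len(entries)
    # Fraction of the n*(n-1) off-diagonal slots that hold a connection.
    off_diag = sum(1 for (r, c) in entries if r != c)
    possible = n_nodes * (n_nodes - 1)
    density = off_diag / possible if possible else 0.0
    # Symmetric: every (r, c) weight matches the (c, r) weight (missing = 0).
    symmetric = all(
        math.isclose(w, entries.get((c, r), 0.0), abs_tol=atol)
        for (r, c), w in entries.items()
    )
    return nnz, density, symmetric


# COO data from the class example above (directed, so not symmetric).
print(coo_stats([0.5, 0.3, 0.2], [0, 1, 2], [1, 0, 1], n_nodes=3))
# (3, 0.5, False)
```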

Methods

Name            Description
from_coo        Create sparse graph from COO format.
from_dense      Convert dense graph to sparse.
plot            Plot sparse connectivity matrix and weight distribution.
random          Create a random sparse graph with brain-like connectivity.
todense         Convert sparse graph to dense array.
tree_flatten    Flatten for JAX PyTree.
tree_unflatten  Unflatten from JAX PyTree.
verify          Verify graph structure and properties.

from_coo

experimental.network_dynamics.graph.SparseGraph.from_coo(data, row, col, shape)

Create sparse graph from COO format.

Args:
    data: Non-zero weight values [nnz].
    row: Row indices [nnz].
    col: Column indices [nnz].
    shape: Matrix shape (n_nodes, n_nodes).

Returns:
    SparseGraph with the specified connectivity.

Example:
    >>> # Triangle graph: 0->1, 1->2, 2->0
    >>> data = jnp.array([0.5, 0.3, 0.2])
    >>> row = jnp.array([0, 1, 2])
    >>> col = jnp.array([1, 2, 0])
    >>> graph = SparseGraph.from_coo(data, row, col, shape=(3, 3))
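To make the index convention concrete (row is the source node, col the target, as in the triangle example), the sketch below scatters COO triples into a dense matrix in plain Python. It assumes duplicate (row, col) pairs are summed, which is the standard COO convention; it illustrates the format and is not SparseGraph's code.

```python
def coo_to_dense(data, row, col, shape):
    """Scatter COO triples into a dense list-of-lists matrix.

    Duplicate (row, col) pairs are summed, following the usual COO convention.
    """
    n_rows, n_cols = shape
    if not (len(data) == len(row) == len(col)):
        raise ValueError("data, row and col must all have length nnz")
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for w, r, c in zip(data, row, col):
        if not (0 <= r < n_rows and 0 <= c < n_cols):
            raise IndexError(f"index ({r}, {c}) out of bounds for shape {shape}")
        dense[r][c] += w
    return dense


# Triangle graph from the example: 0->1, 1->2, 2->0
for line in coo_to_dense([0.5, 0.3, 0.2], [0, 1, 2], [1, 2, 0], (3, 3)):
    print(line)
# [0.0, 0.5, 0.0]
# [0.0, 0.0, 0.3]
# [0.2, 0.0, 0.0]
```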

from_dense

experimental.network_dynamics.graph.SparseGraph.from_dense(
    graph,
    threshold=1e-10,
)

Convert dense graph to sparse.

Args:
    graph: Dense graph to convert.
    threshold: Values with |weight| < threshold are set to zero.

Returns:
    SparseGraph with the same connectivity (entries below threshold zeroed).
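The thresholding rule can be sketched in plain Python. This assumes exact zeros are never stored and entries with |weight| < threshold are dropped, matching the description above; how SparseGraph applies the threshold to its arrays may differ in detail.

```python
def dense_to_coo(dense, threshold=1e-10):
    """Collect entries with |w| >= threshold as COO lists (data, row, col).

    Entries with |w| < threshold are treated as zero and dropped.
    Exact zeros are always dropped, even when threshold == 0.
    """
    data, row, col = [], [], []
    for i, line in enumerate(dense):
        for j, w in enumerate(line):
            if w != 0 and abs(w) >= threshold:
                data.append(w)
                row.append(i)
                col.append(j)
    return data, row, col


weights = [
    [0.0, 0.5, 1e-12],  # 1e-12 falls below the default threshold of 1e-10
    [0.3, 0.0, 0.0],
    [0.0, 0.2, 0.0],
]
print(dense_to_coo(weights))
# ([0.5, 0.3, 0.2], [0, 1, 2], [1, 0, 1])
```

The result matches the COO arrays in the class example, with the near-zero entry removed.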

plot

experimental.network_dynamics.graph.SparseGraph.plot(
    log_scale_weights=False,
    figsize=(12, 5),
)

Plot sparse connectivity matrix and weight distribution.

Args:
    log_scale_weights: If True, log-transform weights before plotting. This helps
        reveal structure when weights span several orders of magnitude.
    figsize: Figure size (width, height).

Returns:
    fig, axes: Matplotlib figure and axes.

Note:
    Converts the sparse matrix to dense for visualization. Zeros are shown as white (background).
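One common way to get zeros drawn as background with Matplotlib is to replace them with NaN before calling `imshow`: NaN cells are painted in the colormap's "bad" color, which is transparent by default, so the white axes background shows through. The plain-Python sketch below prepares a matrix that way. The log10(|w|) transform and the helper name `prepare_for_plot` are assumptions for illustration; the transform SparseGraph.plot actually applies is not documented here.

```python
import math


def prepare_for_plot(dense, log_scale_weights=False):
    """Hypothetical helper: make a dense matrix ready for imshow-style plotting.

    Zeros become NaN so they render as background; non-zero weights are
    optionally replaced by log10(|w|) (an assumed transform).
    """
    out = []
    for line in dense:
        row = []
        for w in line:
            if w == 0:
                row.append(math.nan)
            elif log_scale_weights:
                row.append(math.log10(abs(w)))
            else:
                row.append(w)
        out.append(row)
    return out


prepared = prepare_for_plot([[0.0, 1000.0], [0.5, 0.0]], log_scale_weights=True)
print(prepared)
```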

random

experimental.network_dynamics.graph.SparseGraph.random(
    n_nodes,
    sparsity=0.7,
    symmetric=True,
    weight_dist='lognormal',
    allow_self_loops=False,
    key=None,
)

Create a random sparse graph with brain-like connectivity.

Args:
    n_nodes: Number of nodes in the network.
    sparsity: Fraction of possible connections that are present (0.7 means 70% of
        connections are present, i.e. the graph is 70% dense).
    symmetric: Whether to create undirected (symmetric) connectivity.
    weight_dist: Weight distribution ('lognormal', 'uniform', or 'binary').
    allow_self_loops: Whether to allow self-connections (diagonal entries).
    key: JAX random key (if None, one is created with seed 0).

Returns:
    SparseGraph with random connectivity.

Example:
    >>> import jax
    >>> key = jax.random.key(42)
    >>> graph = SparseGraph.random(n_nodes=100, sparsity=0.3, key=key)
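The generator behind "brain-like" connectivity is not specified above. As an illustration only, here is a stdlib sketch of one plausible scheme: each candidate connection is kept with probability `sparsity`, weights are lognormal (mu=0, sigma=1 are assumed values), and for `symmetric=True` only the upper triangle is sampled and then mirrored. It uses `random.Random(seed)` in place of a JAX key and is not SparseGraph.random's actual algorithm.

```python
import random


def random_sparse(n_nodes, sparsity=0.7, symmetric=True,
                  allow_self_loops=False, seed=0):
    """Random graph as COO lists (data, row, col); stdlib-only sketch.

    Each candidate connection is kept with probability `sparsity`, so
    `sparsity` is the expected fraction of connections present.
    Weights are lognormal with assumed parameters mu=0, sigma=1.
    """
    rng = random.Random(seed)
    data, row, col = [], [], []

    def add(i, j, w):
        data.append(w)
        row.append(i)
        col.append(j)

    for i in range(n_nodes):
        # Symmetric graphs: sample only j >= i, then mirror across the diagonal.
        start = i if symmetric else 0
        for j in range(start, n_nodes):
            if i == j and not allow_self_loops:
                continue
            if rng.random() < sparsity:
                w = rng.lognormvariate(0.0, 1.0)
                add(i, j, w)
                if symmetric and i != j:
                    add(j, i, w)
    return data, row, col


data, row, col = random_sparse(n_nodes=100, sparsity=0.3, seed=42)
print(len(data), "stored entries")
```

Sampling only the upper triangle and mirroring guarantees exact symmetry, rather than relying on two independent draws happening to agree.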

todense

experimental.network_dynamics.graph.SparseGraph.todense()

Convert sparse graph to dense array.

Warning:
    This creates a full n_nodes x n_nodes array. Use sparingly for large sparse graphs.

Returns:
    Dense weight matrix [n_nodes, n_nodes].

tree_flatten

experimental.network_dynamics.graph.SparseGraph.tree_flatten()

Flatten for JAX PyTree.

tree_unflatten

experimental.network_dynamics.graph.SparseGraph.tree_unflatten(
    aux_data,
    children,
)

Unflatten from JAX PyTree.
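tree_flatten and tree_unflatten follow JAX's custom-PyTree protocol: `tree_flatten()` returns `(children, aux_data)`, and the classmethod `tree_unflatten(aux_data, children)` rebuilds the object. Decorating such a class with `jax.tree_util.register_pytree_node_class` lets it pass through `jit` and `vmap`. The hypothetical `ToyGraph` below (not part of the library) exercises the round trip without JAX; which SparseGraph fields go into children versus aux_data is an assumption here. A common split keeps array data as children and static metadata such as labels as hashable aux data.

```python
class ToyGraph:
    """Hypothetical stand-in showing the tree_flatten / tree_unflatten contract."""

    def __init__(self, weights, region_labels=None):
        self.weights = weights
        self.region_labels = region_labels

    def tree_flatten(self):
        # Children: the array leaves that JAX traces and transforms.
        children = (self.weights,)
        # Aux data: static metadata; must be hashable, so labels become a tuple.
        labels = None if self.region_labels is None else tuple(self.region_labels)
        aux_data = (labels,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (weights,) = children
        (region_labels,) = aux_data
        return cls(weights, region_labels=region_labels)


g = ToyGraph([1.0, 2.0], ["A", "B"])
children, aux_data = g.tree_flatten()
rebuilt = ToyGraph.tree_unflatten(aux_data, children)
print(rebuilt.weights, rebuilt.region_labels)
# [1.0, 2.0] ('A', 'B')
```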

verify

experimental.network_dynamics.graph.SparseGraph.verify(verbose=True)

Verify graph structure and properties.

Args:
    verbose: Whether to print verification details.

Returns:
    True if verification passes, False otherwise.
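The checks that verify() performs are not listed here. The sketch below only shows the general shape of such a routine: named boolean checks, optional printing, and a single pass/fail result. The function name `verify_coo` and every individual check are hypothetical, not SparseGraph.verify's actual list.

```python
import math


def verify_coo(data, row, col, n_nodes, verbose=True):
    """Hypothetical structural checks on COO triples; True if all pass."""
    checks = {
        "matching lengths": len(data) == len(row) == len(col),
        "indices in range": all(0 <= i < n_nodes for i in list(row) + list(col)),
        "finite weights": all(math.isfinite(w) for w in data),
        "no duplicate entries": len(set(zip(row, col))) == len(row),
    }
    if verbose:
        for name, ok in checks.items():
            print(f"[{'PASS' if ok else 'FAIL'}] {name}")
    return all(checks.values())


verify_coo([0.5, 0.3, 0.2], [0, 1, 2], [1, 0, 1], n_nodes=3)
```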