from svg_sampler import sample_from_svg
from matplotlib import pyplot as plt
import jax
import jax.numpy as jnp
from eventax.evnn import FFEvNN
from eventax.neuron_models import LIF, QIF
import equinox as eqx
import optax
import numpy as np
import matplotlib.pyplot as plt
import optimistix as opx
import math
# Turn on 64-bit floating point in JAX. The network below is created with
# dtype=jnp.float64, which requires JAX's x64 mode to take effect.
jax.config.update("jax_enable_x64", True)
Classification example
This example implements the Yin–Yang experiment from the paper "Event-Based Backpropagation Can Compute Exact Gradients for Spiking Neural Networks" by Wunderlich and Pehle.
The Yin–Yang dataset is a 2D classification problem where, based on an x and y coordinate, one of three classes must be predicted.
Data points are randomly sampled from the Yin–Yang structure shown below:
Following the paper, we temporally encode the dataset using spikes as follows.
Each sample produces five input spikes, one per input channel, with spike times:
$$ \begin{aligned} t_0 &= X \cdot t_\text{in max}\\ t_1 &= Y \cdot t_\text{in max}\\ t_2 &= (1 - X) \cdot t_\text{in max} \\ t_3 &= (1 - Y) \cdot t_\text{in max} \\ t_4 &= 0 \end{aligned} $$
where $t_\text{in max}$ is the maximum spike time for an input spike that determines the temporal scaling. We will also only simulate until $t_\text{in max}$.
Loss Function
We use the same cross entropy loss as in the paper
$$ \mathcal{L} = -\frac{1}{N_{\text{batch}}} \sum_{i=1}^{N_{\text{batch}}} \bigg[ \log \bigg[ \frac{\exp\!\big(-t^{\text{post}}_{i,l(i)} / \tau_0\big)} {\sum_{k=1}^{3} \exp\!\big(-t^{\text{post}}_{i,k} / \tau_0\big)} \bigg] - \alpha \bigg[ \exp\!\bigg(\frac{t^{\text{post}}_{i,l(i)}}{\tau_1}\bigg) - 1 \bigg] \bigg], $$
where $t_{i,k}^{\text{post}}$ is the first spike time of output neuron $k$ for the $i$th sample, $l(i)$ is the index of the correct label for the $i$th sample, $N_\text{batch}$ is the number of samples in a given batch, and $\tau_0$ and $\tau_1$ are hyperparameters of the loss function. The first term is a cross-entropy loss over the softmax of the negative spike times. The times are negated because the predicted class is the neuron with the smallest spike time. This term encourages a larger spike-time gap between the label neuron and all other neurons. Because the first term depends only on relative spike times, the second term is added as a regularizer that encourages early spiking of the label neuron. Output neurons that do not spike within the simulated window are assigned the spike time $t_\text{in max}$.
tau0, tau1, alpha = 0.5, 6.4, 3e-3


def loss_fn(model, spike_times_b, target_classes_b, *, no_spike_time=None):
    """Mean Yin-Yang loss: cross-entropy on first-spike times plus a regularizer.

    For each sample, the output first-spike times ``t`` come from ``model.ttfs``
    and the per-sample loss is

        -log_softmax(-t / tau0)[label] + alpha * (exp(t[label] / tau1) - 1)

    The first term rewards the label neuron spiking before the other output
    neurons. The second term grows with the label neuron's spike time, so it
    penalizes a late label spike.

    Note: the previous implementation computed
    ``-log(softmax + alpha * (exp(t/tau1) - 1))``. With the regularizer inside
    the log, a later label spike *lowered* the loss, which is the opposite of
    the intended effect.

    Args:
        model: object exposing ``ttfs(spike_times) -> output spike times`` for a
            single sample.
        spike_times_b: batch of input spike-time vectors; vmapped over axis 0.
        target_classes_b: integer class label for each sample.
        no_spike_time: time substituted for output neurons that report ``inf``
            (no spike). Defaults to the module-level ``max_in_time``, the end of
            the simulated window.

    Returns:
        Scalar mean of the per-sample losses over the batch.
    """
    if no_spike_time is None:
        no_spike_time = max_in_time

    def per_example(spike_times_single, target_class):
        out_times = model.ttfs(spike_times_single)
        # A neuron that never spikes reports inf; treat it as spiking at
        # `no_spike_time` so that the loss stays finite.
        out_times = jnp.where(jnp.isinf(out_times), no_spike_time, out_times)
        time_of_target = out_times[target_class]
        # log_softmax is the numerically stable form of log(exp(a) / sum(exp(.))).
        cross_entropy = -jax.nn.log_softmax(-out_times / tau0)[target_class]
        regularizer = jnp.expm1(time_of_target / tau1)  # exp(x) - 1
        return cross_entropy + alpha * regularizer

    per_ex_losses = jax.vmap(per_example)(spike_times_b, target_classes_b)
    return jnp.mean(per_ex_losses)
# Loss value and gradient with respect to the array leaves of the model
# (the first argument of loss_fn).
v_and_grad = eqx.filter_value_and_grad(loss_fn)


@eqx.filter_jit
def train_step(model, opt_state, in_times_b, target_classes_b):
    """Apply one optimizer update computed from a single mini-batch.

    Args:
        model: the Equinox network being trained.
        opt_state: current state of the module-level ``optimizer``.
        in_times_b: batch of input spike-time vectors.
        target_classes_b: class label for each sample in the batch.

    Returns:
        ``(new_model, new_opt_state, batch_loss)``, where ``batch_loss`` is the
        loss evaluated before the update was applied.
    """
    batch_loss, grads = v_and_grad(model, in_times_b, target_classes_b)
    trainable = eqx.filter(model, eqx.is_array)
    updates, new_opt_state = optimizer.update(grads, opt_state, params=trainable)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state, batch_loss
@eqx.filter_jit
def predict_batch(model, in_times_b):
    """Return the output first-spike times for every sample in a batch.

    ``model.ttfs`` processes a single sample, so it is vmapped over the
    leading axis of ``in_times_b``.
    """
    batched_ttfs = jax.vmap(model.ttfs)
    return batched_ttfs(in_times_b)
EvNN creation
We use a feed-forward EvNN of QIF neurons with one hidden layer of 20 neurons and an output layer of 3 neurons (one per class). As the optimizer we use Adam with a learning rate of $0.05$.
key = jax.random.PRNGKey(1234)

# Feed-forward event-based network: 5 input channels -> layer sizes [20, 3],
# with QIF neurons.
snn = FFEvNN(
    dtype=jnp.float64,
    layers=[20, 3],
    in_size=5,
    neuron_model=QIF,
    # Only simulate up to the latest possible input spike time. max_in_time is
    # defined in a data-preparation cell not shown here.
    max_solver_time=max_in_time,
    key=key,
    solver_stepsize=0.05,
    tsyn=5.0,
    tmem=20.0,
    blim=1.0,
    wlim=40.0,
    wmean=20.0,
    max_event_steps=200,
)

optimizer = optax.adam(0.05)
# Initialise the optimizer state on the array leaves only. This matches the
# `eqx.filter(model, eqx.is_array)` params passed to `optimizer.update` in
# train_step and the filtered gradient tree from eqx.filter_value_and_grad.
# Calling init on the raw module can fail on non-array leaves, or produce a
# state tree whose structure differs from the gradients.
optim_state = optimizer.init(eqx.filter(snn, eqx.is_array))
Training Loop
We train the network over 100 epochs with a batch size of 256.
batch_size = 256
num_epochs = 100
loss_hist = []  # mean training loss per epoch
acc_hist = []  # validation accuracy per epoch

N = data.shape[0]
steps_per_epoch = math.ceil(N / batch_size)
rng = key
figs = make_train_figures(n_outputs=3)

for epoch in range(1, num_epochs + 1):
    # Draw a fresh subkey and reshuffle the training set for this epoch.
    rng, sk = jax.random.split(rng)
    perm = jax.random.permutation(sk, N)
    X_shuf, y_shuf = data[perm], targets[perm]

    epoch_loss = 0.0
    for start in range(0, N, batch_size):
        stop = start + batch_size
        xb, yb = X_shuf[start:stop], y_shuf[start:stop]
        snn, optim_state, batch_loss = train_step(snn, optim_state, xb, yb)
        epoch_loss += float(batch_loss)
    loss_hist.append(epoch_loss / steps_per_epoch)

    # Validation: the predicted class is the output neuron with the
    # smallest first-spike time.
    out = predict_batch(snn, data_val)
    pred = jnp.argmin(out, axis=-1)
    accuracy = float(jnp.mean(pred == targets_val))
    acc_hist.append(accuracy)
    update_train_figures(figs, loss_hist, acc_hist, data_val[:, :2], out)