# Leaky Integrate-and-Fire (LIF)
Implementation of the Leaky Integrate and Fire (LIF) neuron model.
The model's internal dynamics are:

**Free dynamics**

\[
\tau_\text{mem}\,\frac{dV}{dt} = -V + I + I_c, \qquad
\tau_\text{syn}\,\frac{dI}{dt} = -I
\]

**Transition condition**

\[
V^- - \vartheta = 0
\]

**Jumps at transition**

\[
V^+ = V^- - (\vartheta - V_\text{reset}), \qquad
I^+_m = I^-_m + W_{n,m} \quad \text{for targets } m \text{ of the spiking neuron } n
\]

where
- \(I \in \mathbb{R}^N\) and \(V \in \mathbb{R}^N\) are the synaptic current and membrane potential of the neurons,
- \(I_c \in \mathbb{R}^N\) is the bias current,
- \(\tau_\text{syn}\) and \(\tau_\text{mem}\) are the synaptic and membrane time constants, respectively,
- \(\vartheta\) is the threshold potential,
- \(W \in \mathbb{R}^{(N+K) \times N}\) is the synaptic weight matrix, with number of neurons \(N\) and input size \(K\).
Times immediately before a jump are marked with \(-\), and immediately after with \(+\).
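For intuition, the free dynamics can be sketched as a pure JAX function (illustrative only — `free_dynamics` is not part of the eventax API, and placing the bias current \(I_c\) in the voltage equation is an assumption):

```python
import jax.numpy as jnp

def free_dynamics(y, tmem, tsyn, ic):
    """Time derivatives of the state between events.

    y has shape (n_neurons, 2): column 0 is V, column 1 is I.
    """
    v, i = y[:, 0], y[:, 1]
    dv = (-v + i + ic) / tmem  # membrane relaxes toward synaptic + bias current
    di = -i / tsyn             # synaptic current decays exponentially
    return jnp.stack([dv, di], axis=1)
```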
## eventax.neuron_models.LIF
Bases: `NeuronModel`
Leaky integrate-and-fire neuron with current-based synapse.
Two state channels: membrane voltage \(v\) and synaptic current \(i\).
Source code in eventax/neuron_models/lif.py
### Parameters

All neuron-level parameters (`thresh`, `tmem`, `tsyn`, `vreset`) accept either a scalar (shared across all neurons) or an array of shape `(n_neurons,)` for per-neuron values.

| Parameter | Description | Default |
|---|---|---|
| `n_neurons` | Number of neurons in the layer. | — |
| `in_size` | Number of input connections. | — |
| `thresh` | Spike threshold \(\vartheta\). | — |
| `tmem` | Membrane time constant \(\tau_\text{mem}\). | — |
| `tsyn` | Synaptic time constant \(\tau_\text{syn}\). | — |
| `vreset` | Reset voltage \(V_\text{reset}\). | `0` |
| `wmask` | Binary mask for the weight matrix, shape \((N+K) \times N\). | — |
| `dtype` | JAX dtype for all computations. | `jnp.float32` |
### Weight initialisation

Weights and biases are initialised via [init_weights_and_bias][eventax.neuron_models.initializations.init_weights_and_bias].

| Parameter | Description | Default |
|---|---|---|
| `init_weights` | If given, used directly as \(W\). | `None` |
| `wmean` | Mean for random weight initialisation. | `0.0` |
| `wlim` | Uniform half-range for random weights. | `None` |
| `fan_in_mode` | Fan-in scaling mode. | `None` |
| `init_bias` | If given, used directly as \(I_c\). | `None` |
| `bmean` | Mean for random bias initialisation. | `0.0` |
| `blim` | Uniform half-range for random biases. | `None` |
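As an illustration of this style of initialisation, the following sketch draws uniform weights around a mean with an assumed \(1/\sqrt{\text{fan-in}}\) scaling. The function name and the scaling rule are hypothetical; the exact behaviour of `init_weights_and_bias` is documented on its own page.

```python
import jax
import jax.numpy as jnp

def init_uniform(key, shape, mean=0.0, lim=1.0, fan_in=None):
    """Uniform values in [mean - lim, mean + lim].

    If fan_in is given, the half-range is scaled by 1/sqrt(fan_in)
    (an assumed fan-in scaling rule, for illustration only).
    """
    if fan_in is not None:
        lim = lim / jnp.sqrt(fan_in)
    return mean + lim * jax.random.uniform(key, shape, minval=-1.0, maxval=1.0)
```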
### Gradient behaviour

| Parameter | Description | Default |
|---|---|---|
| `reset_grad_preserve` | If `True`, the reset is implemented as \(V = V - (\vartheta - V_\text{reset})\), which preserves gradients through the reset. If `False`, uses `jnp.where`, which blocks the gradient. | `True` |
### State layout

The state array `y` has shape `(n_neurons, 2)`:

| Channel | Index | Description |
|---|---|---|
| \(V\) | `y[:, 0]` | Membrane voltage |
| \(I\) | `y[:, 1]` | Synaptic current |
Initial state is all zeros.
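For example, reading the two channels out of a state array:

```python
import jax.numpy as jnp

n_neurons = 5
y = jnp.zeros((n_neurons, 2))  # initial state: all zeros
v = y[:, 0]                    # membrane voltage channel
i = y[:, 1]                    # synaptic current channel
```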
### Trainable fields

| Field | Shape | Description |
|---|---|---|
| `weights` | \((N+K) \times N\) | Connection weight matrix \(W\) |
| `ic` | \((N,)\) | Bias current \(I_c\) |
All other fields are static (not optimised by gradient-based training).
### Methods

#### init_state

Return zero-initialised state of shape `(n_neurons, 2)`.
Source code in eventax/neuron_models/lif.py
Returns zeros of shape (n_neurons, 2).
#### dynamics
Compute the LIF ODE derivatives for voltage and synaptic current.
Source code in eventax/neuron_models/lif.py
Implements the free dynamics ODE above.
#### spike_condition

Return `v - thresh`; a sign change triggers a spike.
Source code in eventax/neuron_models/lif.py
Returns \(V - \vartheta\). A spike is triggered when this changes sign (crosses zero from below).
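As an illustration, a zero crossing of `v - thresh` between two integration points can be localised by linear interpolation (a sketch with hypothetical names; eventax's actual event localisation may differ):

```python
def crossing_time(t0, t1, v0, v1, thresh):
    """Linearly interpolate the threshold-crossing time on [t0, t1],
    given v0 < thresh <= v1 (i.e. spike_condition changes sign)."""
    f0 = v0 - thresh  # spike_condition value at t0 (negative)
    f1 = v1 - thresh  # spike_condition value at t1 (non-negative)
    return t0 + (t1 - t0) * f0 / (f0 - f1)
```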
#### input_spike
Add connection weights to the synaptic current of target neurons.
Source code in eventax/neuron_models/lif.py
Adds the connection weights \(W_{n,m}\) to the synaptic current channel of the target neurons. Only entries where valid_mask is True are applied.
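A minimal sketch of this update in plain JAX (hypothetical names; it assumes `valid_mask` gates whole events, e.g. to skip padding):

```python
import jax.numpy as jnp

def apply_input_spike(i, weights, source_idx, valid_mask):
    """Add row source_idx of the weight matrix W to the synaptic currents.

    i:          (n_neurons,) synaptic current channel
    weights:    (N + K, n_neurons) weight matrix W
    source_idx: index of the spiking input or recurrent unit
    valid_mask: boolean; masked-out events leave i unchanged
    """
    return i + jnp.where(valid_mask, weights[source_idx], 0.0)
```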
#### reset_spiked
Reset voltage of spiked neurons and clip to prevent re-triggering.
Source code in eventax/neuron_models/lif.py
Resets the voltage of spiked neurons. The behaviour depends on `reset_grad_preserve`:

- `True` (default): \(V = V - (\vartheta - V_\text{reset})\), which preserves the gradient through the reset.
- `False`: \(V = V_\text{reset}\) via `jnp.where`, which blocks the gradient through the reset.
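The difference between the two modes can be demonstrated with a small self-contained sketch (illustrative only, not the eventax implementation):

```python
import jax
import jax.numpy as jnp

thresh, vreset = 1.0, 0.0

def reset_preserve(v):
    # subtraction reset: the selected branch is v - const,
    # so dV+/dV- is 1 where a spike occurred
    return jnp.where(v >= thresh, v - (thresh - vreset), v)

def reset_blocking(v):
    # hard reset: the selected branch is a constant,
    # so dV+/dV- is 0 where a spike occurred
    return jnp.where(v >= thresh, vreset, v)

g_preserve = jax.grad(reset_preserve)(1.2)  # gradient flows through the reset
g_blocking = jax.grad(reset_blocking)(1.2)  # gradient is blocked
```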