Skip to content

Leaky Integrate and Fire (LIF)

Implementation of the Leaky Integrate and Fire (LIF) neuron model.
The model is governed by the following internal dynamics:

Free Dynamics

\[ \begin{aligned} \tau_\text{mem} \frac{\partial V}{\partial t} &= -V + I + I_c, \\ \tau_\text{syn} \frac{\partial I}{\partial t} &= -I. \end{aligned} \]

Transition Condition

\[ \begin{aligned} V_n(t_s^-) - \vartheta &= 0, \\ \dot{V}_n(t_s^-) &\neq 0, \\ \text{for any neuron } n. \end{aligned} \]

Jumps at Transition

\[ \begin{aligned} V_n(t_s^+) &= V_\text{reset}, \\ I_m\big(t_s^+\big) &= I_m\big(t_s^-\big) + W_{mn}, \quad \forall m. \end{aligned} \]

Where

  • \(I \in \mathbb{R}^N\) and \(V \in \mathbb{R}^N\) are the synaptic current and membrane potential of the neurons,
  • \(I_c \in \mathbb{R}^N\) is the bias current,
  • \(\tau_\text{syn}\) and \(\tau_\text{mem}\) are the time constants of those channels respectively,
  • \(\vartheta\) is the threshold potential,
  • \(W \in \mathbb{R}^{(N+K) \times N}\) is the synaptic weight matrix with number of neurons \(N\) and input size \(K\),

Times immediately before a jump are marked with \(-\), and immediately after with \(+\).

eventax.neuron_models.LIF

Bases: NeuronModel

Leaky integrate-and-fire neuron with current-based synapse.

Two state channels: membrane voltage \(v\) and synaptic current \(i\).

Source code in eventax/neuron_models/lif.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
class LIF(NeuronModel):
    """Leaky integrate-and-fire neuron with current-based synapse.

    Two state channels: membrane voltage $v$ and synaptic current $i$.
    """

    # Static (non-trainable) hyperparameters. They are marked static so the
    # optimiser only sees `weights` and `ic` as trainable leaves.
    thresh: StaticArray = eqx.field(static=True)
    """Spike threshold $v_\\mathrm{th}$. Scalar or per-neuron."""

    tmem: StaticArray = eqx.field(static=True)
    """Membrane time constant $\\tau_\\mathrm{mem}$. Scalar or per-neuron."""

    tsyn: StaticArray = eqx.field(static=True)
    """Synaptic time constant $\\tau_\\mathrm{syn}$. Scalar or per-neuron."""

    vreset: StaticArray = eqx.field(static=True)
    """Reset voltage $v_\\mathrm{reset}$. Scalar or per-neuron."""

    epsilon: float = eqx.field(static=True)
    """Machine epsilon for the chosen dtype."""

    reset_grad_preserve: bool = eqx.field(static=True)
    """If `True`, reset uses subtraction to preserve gradients."""

    weights: Float[Array, "in_plus_neurons neurons"]
    """Connection weight matrix."""

    ic: Float[Array, "neurons"]
    """Learnable bias current $i_c$."""

    def __init__(
        self,
        key: PRNGKeyArray,
        n_neurons: int,
        in_size: int,
        wmask: Float[Array, "in_plus_neurons neurons"],
        thresh: Union[int, float, jnp.ndarray],
        tsyn: Union[int, float, jnp.ndarray],
        tmem: Union[int, float, jnp.ndarray],
        vreset: Union[int, float, jnp.ndarray] = 0,
        blim: Optional[float] = None,
        bmean: Union[int, float, jnp.ndarray] = 0.0,
        init_bias: Optional[Union[int, float, jnp.ndarray]] = None,
        wlim: Optional[float] = None,
        wmean: Union[int, float, jnp.ndarray] = 0.0,
        init_weights: Optional[Union[int, float, jnp.ndarray]] = None,
        fan_in_mode: Optional[str] = None,
        dtype=jnp.float32,
        reset_grad_preserve: bool = True,
    ):
        """Initialise weights/bias and validate the neuron hyperparameters.

        Args:
            key: PRNG key used for random weight/bias initialisation.
            n_neurons: Number of neurons in the layer.
            in_size: Number of external input connections.
            wmask: Binary connectivity mask for the weight matrix.
            thresh: Spike threshold; scalar or shape ``(n_neurons,)``.
            tsyn: Synaptic time constant; scalar or shape ``(n_neurons,)``.
            tmem: Membrane time constant; scalar or shape ``(n_neurons,)``.
            vreset: Reset voltage; scalar or shape ``(n_neurons,)``.
            blim: Uniform half-range for random bias initialisation.
            bmean: Mean for random bias initialisation.
            init_bias: If given, used directly as the bias current.
            wlim: Uniform half-range for random weight initialisation.
            wmean: Mean for random weight initialisation.
            init_weights: If given, used directly as the weight matrix.
            fan_in_mode: Fan-in scaling mode passed to the initialiser.
            dtype: JAX dtype used for all parameters and state.
            reset_grad_preserve: Selects subtractive (gradient-preserving)
                vs. hard (`jnp.where`) reset in `reset_spiked`.

        Raises:
            ValueError: If any of `thresh`, `tmem`, `tsyn`, `vreset` is
                neither a scalar nor of shape ``(n_neurons,)``.
        """
        super().__init__(dtype=dtype)

        # Trainable leaves: connection weights W and bias current I_c.
        self.weights, self.ic = init_weights_and_bias(
            key,
            n_neurons=n_neurons,
            in_size=in_size,
            wlim=wlim,
            wmean=wmean,
            init_weights=init_weights,
            blim=blim,
            bmean=bmean,
            init_bias=init_bias,
            dtype=dtype,
            wmask=wmask,
            fan_in_mode=fan_in_mode,
        )

        # Each hyperparameter below may be a scalar (shared by all neurons)
        # or a per-neuron array of shape (n_neurons,); anything else raises.
        self.thresh = StaticArray(jnp.asarray(thresh, dtype=dtype))
        if self.thresh.value.shape not in ((), (n_neurons,)):
            raise ValueError(f"`thresh` must be scalar or shape ({n_neurons},); got {self.thresh.value.shape}")

        self.tmem = StaticArray(jnp.asarray(tmem, dtype=dtype))
        if self.tmem.value.shape not in ((), (n_neurons,)):
            raise ValueError(f"`tmem` must be scalar or shape ({n_neurons},); got {self.tmem.value.shape}")

        self.tsyn = StaticArray(jnp.asarray(tsyn, dtype=dtype))
        if self.tsyn.value.shape not in ((), (n_neurons,)):
            raise ValueError(f"`tsyn` must be scalar or shape ({n_neurons},); got {self.tsyn.value.shape}")

        self.vreset = StaticArray(jnp.asarray(vreset, dtype=dtype))
        if self.vreset.value.shape not in ((), (n_neurons,)):
            raise ValueError(f"`vreset` must be scalar or shape ({n_neurons},); got {self.vreset.value.shape}")

        # .item() stores epsilon as a plain Python float so the static field
        # holds a hashable value rather than a traced array. It is used in
        # reset_spiked to keep the clipped voltage strictly below threshold.
        self.epsilon = jnp.finfo(dtype).eps.item()
        self.reset_grad_preserve = reset_grad_preserve

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 2"]:
        """Return zero-initialised state of shape `(n_neurons, 2)`."""
        # Column 0 is the membrane voltage v, column 1 the synaptic current i.
        return jnp.zeros((n_neurons, 2), dtype=self.dtype)

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 2"]:
        """Compute the LIF ODE derivatives for voltage and synaptic current."""
        v = y[:, 0]
        i = y[:, 1]
        # tau_mem * dV/dt = -V + I + I_c   (leaky integration of the current)
        dv_dt = (-v + i + self.ic) / self.tmem.value
        # tau_syn * dI/dt = -I             (exponential synaptic decay)
        di_dt = -i / self.tsyn.value
        return jnp.stack([dv_dt, di_dt], axis=1)

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, "neurons"]:
        """Return `v - thresh`; sign change triggers a spike."""
        return y[:, 0] - self.thresh.value

    def input_spike(
        self,
        y: Float[Array, "neurons 2"],
        from_idx: Int[Array, ""],
        to_idx: Int[Array, "targets"],
        valid_mask: Bool[Array, "targets"],
    ) -> Float[Array, "neurons 2"]:
        """Add connection weights to the synaptic current of target neurons."""
        # Multiplying by the boolean mask zeroes out padded/invalid targets.
        delta_i = self.weights[from_idx, to_idx] * valid_mask
        return y.at[to_idx, 1].add(delta_i)

    def reset_spiked(
        self,
        y: Float[Array, "neurons 2"],
        spike_mask: Bool[Array, "neurons"],
    ) -> Float[Array, "neurons 2"]:
        """Reset voltage of spiked neurons and clip to prevent re-triggering."""
        v, i = y[:, 0], y[:, 1]
        delta_v = self.thresh.value - self.vreset.value

        if self.reset_grad_preserve:
            # Subtractive reset: v -= (thresh - vreset) keeps the gradient
            # path through v intact across the spike.
            v = v - spike_mask.astype(self.dtype) * delta_v
        else:
            # Hard reset via where(); blocks the gradient through the reset.
            v = jnp.where(spike_mask, self.vreset.value, v)

        # Clip just below threshold so the event cannot immediately re-fire.
        v = clip_with_identity_grad(v, self.thresh.value - self.epsilon)
        return jnp.stack([v, i], axis=1)

Parameters

All neuron-level parameters (thresh, tmem, tsyn, vreset) accept either a scalar (shared across all neurons) or an array of shape (n_neurons,) for per-neuron values.

Parameter Description Default
n_neurons Number of neurons in the layer.
in_size Number of input connections.
thresh Spike threshold \(\vartheta\).
tmem Membrane time constant \(\tau_\text{mem}\).
tsyn Synaptic time constant \(\tau_\text{syn}\).
vreset Reset voltage \(V_\text{reset}\). 0
wmask Binary mask for the weight matrix, shape \((N+K) \times N\).
dtype JAX dtype for all computations. jnp.float32

Weight initialisation

Weights and biases are initialised via [init_weights_and_bias][eventax.neuron_models.initializations.init_weights_and_bias].

Parameter Description Default
init_weights If given, used directly as \(W\). None
wmean Mean for random weight initialisation. 0.0
wlim Uniform half-range for random weights. None
fan_in_mode Fan-in scaling mode. None
init_bias If given, used directly as \(I_c\). None
bmean Mean for random bias initialisation. 0.0
blim Uniform half-range for random biases. None

Gradient behaviour

Parameter Description Default
reset_grad_preserve If True, the reset is implemented as \(V = V - (\vartheta - V_\text{reset})\), which preserves gradients through the reset. If False, uses jnp.where which blocks the gradient. True

State layout

The state array y has shape (n_neurons, 2):

Channel Index Description
\(V\) y[:, 0] Membrane voltage
\(I\) y[:, 1] Synaptic current

Initial state is all zeros.


Trainable fields

Field Shape Description
weights \((N+K) \times N\) Connection weight matrix \(W\)
ic \((N,)\) Bias current \(I_c\)

All other fields are static (not optimised by gradient-based training).


Methods

init_state

Return zero-initialised state of shape (n_neurons, 2).

Source code in eventax/neuron_models/lif.py
97
98
99
def init_state(self, n_neurons: int) -> Float[Array, "neurons 2"]:
    """Build the all-zero initial state: one (voltage, current) row per neuron."""
    state_shape = (n_neurons, 2)
    return jnp.zeros(state_shape, dtype=self.dtype)

Returns zeros of shape (n_neurons, 2).


dynamics

Compute the LIF ODE derivatives for voltage and synaptic current.

Source code in eventax/neuron_models/lif.py
101
102
103
104
105
106
107
108
109
110
111
112
def dynamics(
    self,
    t: float,
    y: Float[Array, "neurons 2"],
    args: Dict[str, Any],
) -> Float[Array, "neurons 2"]:
    """Evaluate the free LIF dynamics: leaky voltage and decaying current."""
    voltage, current = y[:, 0], y[:, 1]
    # tau_mem * dV/dt = -V + I + I_c
    voltage_rate = (current + self.ic - voltage) / self.tmem.value
    # tau_syn * dI/dt = -I
    current_rate = -current / self.tsyn.value
    return jnp.stack([voltage_rate, current_rate], axis=1)

Implements the free dynamics ODE above.


spike_condition

Return v - thresh; sign change triggers a spike.

Source code in eventax/neuron_models/lif.py
114
115
116
117
118
119
120
121
def spike_condition(
    self,
    t: float,
    y: Float[Array, "neurons 2"],
    **kwargs: Dict[str, Any],
) -> Float[Array, "neurons"]:
    """Signed distance of each membrane voltage from its threshold."""
    voltage = y[:, 0]
    # A zero crossing of this quantity (from below) marks a spike event.
    return voltage - self.thresh.value

Returns \(V - \vartheta\). A spike is triggered when this changes sign (crosses zero from below).


input_spike

Add connection weights to the synaptic current of target neurons.

Source code in eventax/neuron_models/lif.py
123
124
125
126
127
128
129
130
131
132
def input_spike(
    self,
    y: Float[Array, "neurons 2"],
    from_idx: Int[Array, ""],
    to_idx: Int[Array, "targets"],
    valid_mask: Bool[Array, "targets"],
) -> Float[Array, "neurons 2"]:
    """Deliver a presynaptic spike: bump target currents by their weights."""
    # Boolean multiply zeroes contributions for padded/invalid targets.
    increments = self.weights[from_idx, to_idx] * valid_mask
    updated = y.at[to_idx, 1].add(increments)
    return updated

Adds the connection weights \(W_{mn}\) to the synaptic current channel of the target neurons. Only entries where valid_mask is True are applied.


reset_spiked

Reset voltage of spiked neurons and clip to prevent re-triggering.

Source code in eventax/neuron_models/lif.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def reset_spiked(
    self,
    y: Float[Array, "neurons 2"],
    spike_mask: Bool[Array, "neurons"],
) -> Float[Array, "neurons 2"]:
    """Apply the post-spike jump: fired voltages are pulled back to reset."""
    voltage = y[:, 0]
    current = y[:, 1]
    drop = self.thresh.value - self.vreset.value

    if self.reset_grad_preserve:
        # Subtractive reset keeps the gradient path through v intact.
        voltage = voltage - spike_mask.astype(self.dtype) * drop
    else:
        # Hard reset; where() blocks the gradient through the reset.
        voltage = jnp.where(spike_mask, self.vreset.value, voltage)

    # Keep v strictly below threshold so the event cannot re-trigger.
    voltage = clip_with_identity_grad(voltage, self.thresh.value - self.epsilon)
    return jnp.stack([voltage, current], axis=1)

Resets the voltage of spiked neurons. The behaviour depends on reset_grad_preserve:

  • True (default): \(V = V - (\vartheta - V_\text{reset})\). Preserves the gradient.
  • False: \(V = V_\text{reset}\) via jnp.where, which blocks the gradient through the reset.