Izhikevich Neuron

Implementation of the Izhikevich spiking neuron model with a current-based exponential synapse.
The model combines the two-dimensional neuron dynamics of Izhikevich (2003) with an exponentially decaying synaptic current channel, providing smooth temporal integration of incoming spikes.


Free Dynamics

\[ \begin{aligned} \frac{\mathrm{d}V}{\mathrm{d}t} &= 0.04\,V^2 + 5\,V + 140 - U + I_c + I, \\[4pt] \frac{\mathrm{d}U}{\mathrm{d}t} &= a\,(b\,V - U), \\[4pt] \tau_\text{syn}\,\frac{\mathrm{d}I}{\mathrm{d}t} &= -I. \end{aligned} \]

The voltage equation captures the spike-generating dynamics via a quadratic nonlinearity.
The recovery variable \(U\) provides slow negative feedback (e.g. K\(^+\) channel activation).
The synaptic current \(I\) integrates incoming spikes and decays exponentially with time constant \(\tau_\text{syn}\).
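As an illustration (a pure-Python sketch with hypothetical names, not library code), the free dynamics can be stepped with forward Euler; the library itself integrates them with an event-based ODE solver:

```python
# Forward-Euler sketch of the free dynamics for a single neuron
# (illustrative only; function names here are hypothetical).
def free_dynamics(v, u, i, i_c, a=0.02, b=0.2, tau_syn=5.0):
    """Right-hand sides of the Izhikevich + exponential-synapse ODEs."""
    dv = 0.04 * v**2 + 5.0 * v + 140.0 - u + i_c + i
    du = a * (b * v - u)
    di = -i / tau_syn
    return dv, du, di

def euler_step(v, u, i, i_c, dt=0.1):
    dv, du, di = free_dynamics(v, u, i, i_c)
    return v + dt * dv, u + dt * du, i + dt * di

# Start from the reset point V = c, U = b * c with no synaptic input.
v, u, i = -51.0, 0.2 * -51.0, 0.0
v, u, i = euler_step(v, u, i, i_c=0.0)
```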

Transition Condition

\[ \begin{aligned} V_n(t_s^-) - \vartheta &= 0, \\ \dot{V}_n(t_s^-) &\neq 0, \\ \text{for any neuron } n. \end{aligned} \]

Jumps at Transition

When neuron \(n\) emits a spike at time \(t_s\), its postsynaptic targets \(m\) receive a current impulse:

\[ I_m\big(t_s^+\big) = I_m\big(t_s^-\big) + W_{nm}, \quad \forall m. \]

The spiking neuron \(n\) itself undergoes a reset:

\[ \begin{aligned} V_n(t_s^+) &= c, \\ U_n(t_s^+) &= U_n(t_s^-) + d. \end{aligned} \]

Times immediately before a jump are marked with \((-)\), and immediately after with \((+)\).
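The two jumps can be sketched together in pure Python (a hypothetical helper, not the library API): when neuron \(n\) fires, every synaptic current \(I_m\) gains \(W_{nm}\), while only neuron \(n\) itself is reset:

```python
# Sketch of the jumps at a spike of neuron n (hypothetical helper).
def apply_spike(V, U, I, W, n, c=-51.0, d=2.0):
    V, U, I = list(V), list(U), list(I)
    for m in range(len(I)):        # postsynaptic jump: I_m += W_{nm}
        I[m] += W[n][m]
    V[n] = c                       # reset: V_n(t+) = c
    U[n] += d                      # recovery kick: U_n(t+) = U_n(t-) + d
    return V, U, I

W = [[0.0, 1.5],                   # W[n][m]: weight from neuron n to m
     [0.7, 0.0]]
V, U, I = apply_spike([31.0, -60.0], [5.0, -12.0], [0.0, 0.0], W, n=0)
```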

Where

  • \(V \in \mathbb{R}^N\) — membrane potential of the neurons,
  • \(U \in \mathbb{R}^N\) — recovery variable,
  • \(I \in \mathbb{R}^N\) — synaptic current,
  • \(I_c \in \mathbb{R}^N\) — bias current,
  • \(a, b, c, d\) — dimensionless parameters controlling the neuron type (see Izhikevich, 2003),
  • \(\tau_\text{syn}\) — synaptic time constant,
  • \(\vartheta\) — firing threshold applied to \(V\),
  • \(W \in \mathbb{R}^{(N+K) \times N}\) — synaptic weight matrix with number of neurons \(N\) and input size \(K\).

eventax.neuron_models.Izhikevich

Bases: NeuronModel

Source code in eventax/neuron_models/izhikevich.py
class Izhikevich(NeuronModel):
    a: StaticArray = eqx.field(static=True)
    b: StaticArray = eqx.field(static=True)
    c: StaticArray = eqx.field(static=True)
    d: StaticArray = eqx.field(static=True)
    thresh: StaticArray = eqx.field(static=True)
    tau_syn: StaticArray = eqx.field(static=True)
    epsilon: float = eqx.field(static=True)
    reset_grad_preserve: bool = eqx.field(static=True)

    weights: Float[Array, "in_plus_neurons neurons"]
    ic: Float[Array, "neurons"]

    def __init__(
        self,
        key: PRNGKeyArray,
        n_neurons: int,
        in_size: int,
        wmask: Float[Array, "in_plus_neurons neurons"],
        *,
        a: Union[int, float, jnp.ndarray] = 0.02,
        b: Union[int, float, jnp.ndarray] = 0.2,
        c: Union[int, float, jnp.ndarray] = -51.0,
        d: Union[int, float, jnp.ndarray] = 2.0,
        v_thresh: Union[int, float, jnp.ndarray] = 30.0,
        tau_syn: Union[int, float, jnp.ndarray] = 5.0,
        wlim: Optional[float] = None,
        wmean: Union[int, float, jnp.ndarray] = 0.0,
        init_weights: Optional[Union[int, float, jnp.ndarray]] = None,
        blim: Optional[float] = None,
        bmean: Union[int, float, jnp.ndarray] = 0.0,
        init_bias: Optional[Union[int, float, jnp.ndarray]] = None,
        fan_in_mode: Optional[str] = None,
        dtype=jnp.float32,
        reset_grad_preserve: bool = True,
    ):
        super().__init__(dtype=dtype)

        self.weights, self.ic = init_weights_and_bias(
            key,
            n_neurons=n_neurons,
            in_size=in_size,
            wmask=wmask,
            wlim=wlim,
            wmean=wmean,
            init_weights=init_weights,
            blim=blim,
            bmean=bmean,
            init_bias=init_bias,
            dtype=dtype,
            fan_in_mode=fan_in_mode,
        )

        self.a = StaticArray(jnp.asarray(a, dtype=dtype))
        self.b = StaticArray(jnp.asarray(b, dtype=dtype))
        self.c = StaticArray(jnp.asarray(c, dtype=dtype))
        self.d = StaticArray(jnp.asarray(d, dtype=dtype))
        self.thresh = StaticArray(jnp.asarray(v_thresh, dtype=dtype))
        self.tau_syn = StaticArray(jnp.asarray(tau_syn, dtype=dtype))
        self.epsilon = jnp.finfo(dtype).eps.item()
        self.reset_grad_preserve = reset_grad_preserve

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 3"]:
        v0 = jnp.full((n_neurons,), self.c.value, dtype=self.ic.dtype)
        u0 = self.b.value * v0
        i0 = jnp.zeros((n_neurons,), dtype=self.ic.dtype)
        return jnp.stack([v0, u0, i0], axis=1)

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 3"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 3"]:
        v = y[:, 0]
        u = y[:, 1]
        i = y[:, 2]

        dv = 0.04 * v**2 + 5.0 * v + 140.0 - u + self.ic + i
        du = self.a.value * (self.b.value * v - u)
        di = -i / self.tau_syn.value

        return jnp.stack([dv, du, di], axis=1)

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 3"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, "neurons"]:
        return y[:, 0] - self.thresh.value

    def input_spike(
        self,
        y: Float[Array, "neurons 3"],
        from_idx: Int[Array, ""],
        to_idx: Int[Array, "targets"],
        valid_mask: Bool[Array, "targets"],
    ) -> Float[Array, "neurons 3"]:
        v = y[:, 0]
        u = y[:, 1]
        i = y[:, 2]

        delta_i = self.weights[from_idx, to_idx] * valid_mask
        i = i.at[to_idx].add(delta_i)

        return jnp.stack([v, u, i], axis=1)

    def reset_spiked(
        self,
        y: Float[Array, "neurons 3"],
        spiked_mask: Bool[Array, "neurons"],
    ) -> Float[Array, "neurons 3"]:
        v = y[:, 0]
        u = y[:, 1]
        i = y[:, 2]

        if self.reset_grad_preserve:
            delta_v = self.thresh.value - (self.c.value - self.epsilon)
            v = v - spiked_mask.astype(self.dtype) * delta_v
            u = u + spiked_mask.astype(self.dtype) * self.d.value
        else:
            v = jnp.where(spiked_mask, self.c.value - self.epsilon, v)
            u = jnp.where(spiked_mask, u + self.d.value, u)

        v = clip_with_identity_grad(v, self.thresh.value - self.epsilon)
        return jnp.stack([v, u, i], axis=1)

Parameters

The Izhikevich parameters (a, b, c, d) accept either a scalar (shared across all neurons) or an array of shape (n_neurons,) for per-neuron values, enabling heterogeneous populations (e.g. mixing regular spiking and fast spiking neurons).
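For example, a heterogeneous population mixing regular-spiking and fast-spiking cells (canonical parameter values from Izhikevich, 2003) could be set up roughly like this; the construction is a sketch, not library code:

```python
# Per-neuron parameter lists for 3 regular-spiking (RS) and 2
# fast-spiking (FS) neurons, using the canonical Izhikevich (2003) values.
n_rs, n_fs = 3, 2
a = [0.02] * n_rs + [0.10] * n_fs   # FS cells recover quickly
b = [0.20] * (n_rs + n_fs)
c = [-65.0] * (n_rs + n_fs)
d = [8.0] * n_rs + [2.0] * n_fs     # RS cells get a large recovery kick
```

Converted to arrays of shape (n_neurons,), these would be passed to the a, b, c, d keyword arguments of the constructor.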

| Parameter | Description | Default |
|---|---|---|
| n_neurons | Number of neurons in the layer. | |
| in_size | Number of input connections. | |
| a | Recovery time scale \(a\). Smaller values yield slower recovery. | 0.02 |
| b | Recovery sensitivity \(b\). Controls coupling of \(U\) to \(V\). | 0.2 |
| c | Post-spike reset voltage \(c\). | -51.0 |
| d | Post-spike recovery increment \(d\). | 2.0 |
| v_thresh | Spike threshold \(\vartheta\). | 30.0 |
| tau_syn | Synaptic time constant \(\tau_\text{syn}\). | 5.0 |
| wmask | Binary mask for the weight matrix, shape \((N+K) \times N\). | |
| dtype | JAX dtype for all computations. | jnp.float32 |

Weight initialisation

Weights and biases are initialised via [init_weights_and_bias][eventax.neuron_models.initializations.init_weights_and_bias].

| Parameter | Description | Default |
|---|---|---|
| init_weights | If given, used directly as \(W\). | None |
| wmean | Mean for random weight initialisation. | 0.0 |
| wlim | Uniform half-range for random weights. | None |
| fan_in_mode | Fan-in scaling mode. | None |
| init_bias | If given, used directly as \(I_c\). | None |
| bmean | Mean for random bias initialisation. | 0.0 |
| blim | Uniform half-range for random biases. | None |

Gradient behaviour

| Parameter | Description | Default |
|---|---|---|
| reset_grad_preserve | If True, the reset is implemented as \(V = V - (\vartheta - c)\), which preserves gradients through the reset. If False, uses jnp.where, which blocks the gradient. | True |
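A small numeric sketch of the difference: at the spike time the solver has \(V = \vartheta\) (ignoring the small epsilon offset used in the implementation), so both forms land on the same reset value, but only the subtractive form keeps \(V\) in the computation graph:

```python
# Both reset modes produce the same value at the spike time; the
# subtractive form expresses it as an offset from the current V.
thresh, c = 30.0, -51.0
v_at_spike = thresh                        # solver stops V at the threshold

v_subtractive = v_at_spike - (thresh - c)  # reset_grad_preserve=True form
v_hard = c                                 # reset_grad_preserve=False form
assert v_subtractive == v_hard == -51.0
```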

State layout

The state array y has shape (n_neurons, 3):

| Channel | Index | Description |
|---|---|---|
| \(V\) | y[:, 0] | Membrane voltage |
| \(U\) | y[:, 1] | Recovery variable |
| \(I\) | y[:, 2] | Synaptic current |

Initial state: \(V_0 = c\), \(U_0 = b \cdot c\), \(I_0 = 0\).
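A pure-Python sketch of this layout, with each row holding [V0, U0, I0] (hypothetical helper name):

```python
# Build the (n_neurons, 3) initial state as nested lists:
# V0 = c, U0 = b * c, I0 = 0 for every neuron.
def make_init_state(n_neurons, b=0.2, c=-51.0):
    return [[c, b * c, 0.0] for _ in range(n_neurons)]

state = make_init_state(3)
```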


Trainable fields

| Field | Shape | Description |
|---|---|---|
| weights | \((N+K) \times N\) | Connection weight matrix \(W\) |
| ic | \((N,)\) | Bias current \(I_c\) |

All other fields are static (not optimised by gradient-based training).


Methods

init_state

Source code in eventax/neuron_models/izhikevich.py
def init_state(self, n_neurons: int) -> Float[Array, "neurons 3"]:
    v0 = jnp.full((n_neurons,), self.c.value, dtype=self.ic.dtype)
    u0 = self.b.value * v0
    i0 = jnp.zeros((n_neurons,), dtype=self.ic.dtype)
    return jnp.stack([v0, u0, i0], axis=1)

Returns state of shape (n_neurons, 3) with \(V = c\), \(U = b \cdot c\), \(I = 0\).


dynamics

Source code in eventax/neuron_models/izhikevich.py
def dynamics(
    self,
    t: float,
    y: Float[Array, "neurons 3"],
    args: Dict[str, Any],
) -> Float[Array, "neurons 3"]:
    v = y[:, 0]
    u = y[:, 1]
    i = y[:, 2]

    dv = 0.04 * v**2 + 5.0 * v + 140.0 - u + self.ic + i
    du = self.a.value * (self.b.value * v - u)
    di = -i / self.tau_syn.value

    return jnp.stack([dv, du, di], axis=1)

Implements the free dynamics equations above. The voltage \(V\) evolves via the characteristic quadratic nonlinearity with recovery feedback from \(U\) and drive from the synaptic current \(I\) and bias \(I_c\). The recovery variable \(U\) relaxes toward \(b \cdot V\) with rate \(a\). The synaptic current \(I\) decays exponentially with time constant \(\tau_\text{syn}\).


spike_condition

Source code in eventax/neuron_models/izhikevich.py
def spike_condition(
    self,
    t: float,
    y: Float[Array, "neurons 3"],
    **kwargs: Dict[str, Any],
) -> Float[Array, "neurons"]:
    return y[:, 0] - self.thresh.value

Returns \(V - \vartheta\). A spike is triggered when this changes sign (crosses zero from below).
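An event-based solver can use this sign change to localize the spike time, e.g. by bisecting between two solver steps. A toy sketch with a stand-in linear trajectory (not the library's actual root finder):

```python
# Bisection on the event function V(t) - thresh (toy illustration).
def bisect_root(f, lo, hi, tol=1e-10):
    """Assumes f(lo) < 0 < f(hi); returns t with f(t) ~= 0."""
    for _ in range(200):
        mid = 0.5 * (lo + hi)
        if hi - lo < tol:
            break
        if f(mid) < 0.0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

thresh = 30.0
event = lambda t: (10.0 * t - 20.0) - thresh  # stand-in V(t) = 10 t - 20
t_spike = bisect_root(event, 0.0, 10.0)       # V crosses thresh at t = 5
```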


input_spike

Source code in eventax/neuron_models/izhikevich.py
def input_spike(
    self,
    y: Float[Array, "neurons 3"],
    from_idx: Int[Array, ""],
    to_idx: Int[Array, "targets"],
    valid_mask: Bool[Array, "targets"],
) -> Float[Array, "neurons 3"]:
    v = y[:, 0]
    u = y[:, 1]
    i = y[:, 2]

    delta_i = self.weights[from_idx, to_idx] * valid_mask
    i = i.at[to_idx].add(delta_i)

    return jnp.stack([v, u, i], axis=1)

Adds the connection weights \(W_{nm}\) to the synaptic current channel of the target neurons \(m\). Only entries where valid_mask is True are applied.
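The masked scatter-add can be sketched in pure Python (a hypothetical helper mirroring the method's logic):

```python
# Masked scatter-add: i[t] += W[from_idx][t] for each valid target t.
def add_input_spike(i, weights, from_idx, to_idx, valid_mask):
    i = list(i)
    for t, valid in zip(to_idx, valid_mask):
        if valid:
            i[t] += weights[from_idx][t]
    return i

weights = [[0.0, 0.5, 1.0],
           [0.2, 0.0, 0.3]]
i = add_input_spike([0.0, 0.0, 0.0], weights,
                    from_idx=0, to_idx=[1, 2], valid_mask=[True, False])
```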


reset_spiked

Source code in eventax/neuron_models/izhikevich.py
def reset_spiked(
    self,
    y: Float[Array, "neurons 3"],
    spiked_mask: Bool[Array, "neurons"],
) -> Float[Array, "neurons 3"]:
    v = y[:, 0]
    u = y[:, 1]
    i = y[:, 2]

    if self.reset_grad_preserve:
        delta_v = self.thresh.value - (self.c.value - self.epsilon)
        v = v - spiked_mask.astype(self.dtype) * delta_v
        u = u + spiked_mask.astype(self.dtype) * self.d.value
    else:
        v = jnp.where(spiked_mask, self.c.value - self.epsilon, v)
        u = jnp.where(spiked_mask, u + self.d.value, u)

    v = clip_with_identity_grad(v, self.thresh.value - self.epsilon)
    return jnp.stack([v, u, i], axis=1)

Resets voltage and increments recovery for spiked neurons. The behaviour depends on reset_grad_preserve:

  • True (default): \(V = V - (\vartheta - c)\) and \(U = U + d\). This preserves the gradient through the reset.
  • False: \(V = c\) via jnp.where, which blocks the gradient through the reset; the \(U = U + d\) update is the same in both modes.

The synaptic current \(I\) is left unchanged by the reset.