Event-based Gated Recurrent Unit (EGRU)

Implementation of the Event-based Gated Recurrent Unit (EGRU) model.
This neuron model combines exponential relaxation dynamics with gated recurrent interactions triggered by discrete events (spikes).
It generalizes recurrent neural units to the event-driven setting and provides continuous-time dynamics between spikes.


Free Dynamics

Let

\[ u = \sigma(a_u), \qquad r = \sigma(a_r), \qquad z = \tanh(a_z), \]

where \(\sigma(\cdot)\) is the logistic sigmoid.
Each neuron maintains a continuous cell state \(c\) and three exponentially relaxing pre-activation variables \(a_u, a_r, a_z\):

\[ \begin{aligned} \tau_\text{mem}\,\frac{\partial c}{\partial t} &= u\,(z - c), \\[4pt] \tau_\text{syn}\,\frac{\partial a_u}{\partial t} &= -a_u + b_u,\\ \tau_\text{syn}\,\frac{\partial a_r}{\partial t} &= -a_r + b_r,\\ \tau_\text{syn}\,\frac{\partial a_z}{\partial t} &= -a_z + b_z. \end{aligned} \]

These equations describe exponentially decaying internal activations and a gated relaxation of the state \(c\) toward the modulation signal \(z\).
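As a rough illustration (not the library's event-driven integrator), the free dynamics can be stepped with forward Euler; parameter values here are arbitrary examples, and the real model draws per-neuron biases:

```python
import jax
import jax.numpy as jnp

# Example constants; illustrative only.
tau_mem, tau_syn = 20.0, 5.0
b_u, b_r, b_z = -2.0, 1.0, 0.0

def free_step(state, dt=0.5):
    """One forward-Euler step of the free (between-spike) dynamics."""
    c, a_u, a_r, a_z = state
    u = jax.nn.sigmoid(a_u)            # update gate
    z = jnp.tanh(a_z)                  # modulation signal
    c = c + dt * u * (z - c) / tau_mem
    a_u = a_u + dt * (-a_u + b_u) / tau_syn
    a_r = a_r + dt * (-a_r + b_r) / tau_syn
    a_z = a_z + dt * (-a_z + b_z) / tau_syn
    return jnp.stack([c, a_u, a_r, a_z])

state = jnp.zeros(4)
for _ in range(500):                   # t = 250, many multiples of tau_syn
    state = free_step(state)
# each pre-activation relaxes to its bias attractor: a_u -> b_u, etc.
```

In the absence of spikes, the pre-activations simply settle at their biases, and \(c\) relaxes toward \(\tanh(b_z)\) at a rate set by the update gate.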

Transition Condition

A spike (event) is emitted whenever the continuous state \(c_n\) of neuron \(n\) reaches threshold:

\[ \begin{aligned} c_n(t_s^-) - \vartheta &= 0, \\ \dot{c}_n(t_s^-) &\neq 0, \\ \text{for any neuron } n. \end{aligned} \]
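In a discrete-time view, such a crossing can be located by watching the sign of \(c - \vartheta\) between successive samples; a hypothetical sketch (not the library's event handling, which a root-finder would refine):

```python
import jax.numpy as jnp

thresh = 0.5
c_traj = jnp.array([0.10, 0.30, 0.45, 0.55, 0.60])  # toy trajectory of one neuron
g = c_traj - thresh                                  # event function c - theta
crossed = (g[:-1] < 0) & (g[1:] >= 0)                # upward sign change
first = int(jnp.argmax(crossed))                     # index of the crossing interval
# first -> 2: the event lies between samples 2 and 3
```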

Jumps at Transition

When neuron \(n\) emits a spike at time \(t_s\), its postsynaptic targets \(m\) experience instantaneous jumps:

\[ \begin{aligned} a_{u,m}\big(t_s^+\big) &= a_{u,m}\big(t_s^-\big) + W^{(u)}_{nm}\, c_n(t_s^-), \\[4pt] a_{r,m}\big(t_s^+\big) &= a_{r,m}\big(t_s^-\big) + W^{(r)}_{nm}\, c_n(t_s^-), \\[4pt] a_{z,m}\big(t_s^+\big) &= a_{z,m}\big(t_s^-\big) + W^{(z)}_{nm}\, r_m(t_s^-)\, c_n(t_s^-). \end{aligned} \]

Here \(r_m = \sigma(a_{r,m})\) is the reset gate of the target neuron \(m\); it controls how strongly incoming spikes drive the target's \(a_z\) channel.

The spiking neuron \(n\) itself resets by threshold subtraction:

\[ c_n(t_s^+) = c_n(t_s^-) - \vartheta, \]

which equals zero at an exact threshold crossing.

Times immediately before a jump are marked with \((-)\), and immediately after with \((+)\).
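A toy sketch of the jump rule for a single spike of neuron \(n\), in dense form (variable names are illustrative; the reset gate is evaluated at the target neurons, at their pre-jump values):

```python
import jax
import jax.numpy as jnp

N = 4
key = jax.random.PRNGKey(0)
W_u, W_r, W_z = 0.1 * jax.random.normal(key, (3, N, N))

def apply_spike(a_u, a_r, a_z, c, n):
    """Jumps experienced by all targets when neuron n spikes."""
    r = jax.nn.sigmoid(a_r)            # reset gates of the targets, pre-jump
    a_u = a_u + W_u[n] * c[n]
    a_r = a_r + W_r[n] * c[n]
    a_z = a_z + W_z[n] * r * c[n]      # modulation channel is reset-gated
    return a_u, a_r, a_z

a_u = a_r = a_z = jnp.zeros(N)
c = jnp.full((N,), 0.5)
a_u2, a_r2, a_z2 = apply_spike(a_u, a_r, a_z, c, n=0)
```

With zero pre-activations the gates sit at \(\sigma(0) = 0.5\), so the \(a_z\) jumps are half the size of the corresponding \(a_u\) jumps.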

Where

  • \(c \in \mathbb{R}^N\) — continuous cell state of the neurons,
  • \(a_u, a_r, a_z \in \mathbb{R}^N\) — pre-activations for the update, reset, and modulation gates,
  • \(u=\sigma(a_u)\), \(r=\sigma(a_r)\), \(z=\tanh(a_z)\) — corresponding gate activations,
  • \(b_u, b_r, b_z \in \mathbb{R}^N\) — bias levels (exponential attractors) for the pre-activations,
  • \(\tau_\text{mem}\), \(\tau_\text{syn}\) — membrane and synaptic time constants,
  • \(W^{(u)}, W^{(r)}, W^{(z)} \in \mathbb{R}^{(N+K)\times N}\) — synaptic weight matrices for the three channels,
  • \(\vartheta\) — firing threshold applied to \(c\),
  • \(N\) — number of neurons, \(K\) — number of input channels.

eventax.neuron_models.EGRU

Bases: NeuronModel

Event-based gated recurrent unit with continuous-time dynamics.

Four state channels: cell state \(c\) and three pre-activations \(a_u, a_r, a_z\).

Source code in eventax/neuron_models/egru.py
class EGRU(NeuronModel):
    """Event-based gated recurrent unit with continuous-time dynamics.

    Four state channels: cell state $c$ and three pre-activations $a_u, a_r, a_z$.
    """

    W_u: Float[Array, "in_plus_neurons neurons"]
    """Weight matrix for the update gate."""

    W_r: Float[Array, "in_plus_neurons neurons"]
    """Weight matrix for the reset gate."""

    W_z: Float[Array, "in_plus_neurons neurons"]
    """Weight matrix for the modulation gate."""

    b_u: Float[Array, "neurons"]
    """Bias attractor for the update pre-activation $a_u$."""

    b_r: Float[Array, "neurons"]
    """Bias attractor for the reset pre-activation $a_r$."""

    b_z: Float[Array, "neurons"]
    """Bias attractor for the modulation pre-activation $a_z$."""

    tsyn: StaticArray = eqx.field(static=True)
    """Synaptic time constant $\\tau_\\mathrm{syn}$."""

    tmem: StaticArray = eqx.field(static=True)
    """Membrane time constant $\\tau_\\mathrm{mem}$."""

    thresh: StaticArray = eqx.field(static=True)
    """Spike threshold $\\vartheta$ applied to cell state $c$."""

    epsilon: float = eqx.field(static=True)
    """Machine epsilon for the chosen dtype."""

    def __init__(
        self,
        key: PRNGKeyArray,
        n_neurons: int,
        in_size: int,
        wmask: Float[Array, "in_plus_neurons neurons"],
        tsyn: Union[int, float, jnp.ndarray] = 5.0,
        tmem: Union[int, float, jnp.ndarray] = 20.0,
        thresh: Union[int, float, jnp.ndarray] = 0.5,
        bias_u_mean: float = -2.0,
        bias_r_mean: float = 1.0,
        bias_z_mean: float = 0.0,
        bias_scale: float = 0.1,
        weight_scale: float = 1.0,
        weight_mean: float = 0.0,
        dtype=jnp.float32,
    ):
        super().__init__(dtype=dtype)

        k1, k2, k3, k4, k5, k6 = jax.random.split(key, 6)

        self.b_u = bias_scale * jax.random.normal(k1, (n_neurons,), dtype=dtype) + bias_u_mean
        self.b_r = bias_scale * jax.random.normal(k2, (n_neurons,), dtype=dtype) + bias_r_mean
        self.b_z = bias_scale * jax.random.normal(k3, (n_neurons,), dtype=dtype) + bias_z_mean

        self.W_u = weight_scale * jax.random.normal(k4, (n_neurons + in_size, n_neurons), dtype=dtype) + weight_mean
        self.W_r = weight_scale * jax.random.normal(k5, (n_neurons + in_size, n_neurons), dtype=dtype) + weight_mean
        self.W_z = weight_scale * jax.random.normal(k6, (n_neurons + in_size, n_neurons), dtype=dtype) + weight_mean

        self.tsyn = StaticArray(jnp.asarray(tsyn, dtype=dtype))
        self.tmem = StaticArray(jnp.asarray(tmem, dtype=dtype))
        self.thresh = StaticArray(jnp.asarray(thresh, dtype=dtype))
        self.epsilon = jnp.finfo(dtype).eps.item()

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 4"]:
        """Return zero-initialised state of shape `(n_neurons, 4)`."""
        return jnp.zeros((n_neurons, 4), dtype=self.dtype)

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 4"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 4"]:
        """Compute the EGRU ODE derivatives for cell state and pre-activations."""
        c = y[:, 0]
        a_u = y[:, 1]
        a_r = y[:, 2]
        a_z = y[:, 3]

        da_u = (-a_u + self.b_u) / self.tsyn.value
        da_r = (-a_r + self.b_r) / self.tsyn.value
        da_z = (-a_z + self.b_z) / self.tsyn.value

        u = jax.nn.sigmoid(a_u)
        z = jnp.tanh(a_z)
        dc = (u * (z - c)) / self.tmem.value

        return jnp.stack([dc, da_u, da_r, da_z], axis=1)

    def spike_condition(
        self,
        t: float,
        y: Float[Array, "neurons 4"],
        **kwargs: Dict[str, Any],
    ) -> Float[Array, "neurons"]:
        """Return `c - thresh`; sign change triggers a spike."""
        return y[:, 0] - self.thresh.value

    def input_spike(
        self,
        y: Float[Array, "neurons 4"],
        from_idx: Union[int, Int[Array, ""]],
        to_idx: Int[Array, "targets"],
        valid_mask: Bool[Array, "targets"],
    ) -> Float[Array, "neurons 4"]:
        """Add gated weight contributions to the pre-activation channels of target neurons."""
        res = jax.nn.sigmoid(y[to_idx, 2])
        du = self.W_u[from_idx, to_idx] * valid_mask
        dr = self.W_r[from_idx, to_idx] * valid_mask
        dz = self.W_z[from_idx, to_idx] * res * valid_mask

        y = y.at[to_idx, 1].add(du)
        y = y.at[to_idx, 2].add(dr)
        y = y.at[to_idx, 3].add(dz)
        return y

    def reset_spiked(
        self,
        y: Float[Array, "neurons 4"],
        spiked_mask: Bool[Array, "neurons"],
    ) -> Float[Array, "neurons 4"]:
        """Reset cell state of spiked neurons via subtraction and clip to prevent re-triggering."""
        c, a_u, a_r, a_z = y[:, 0], y[:, 1], y[:, 2], y[:, 3]
        c = c - spiked_mask.astype(self.dtype) * self.thresh.value
        c = clip_with_identity_grad(c, self.thresh.value - self.epsilon)
        return jnp.stack([c, a_u, a_r, a_z], axis=1)

Parameters

| Parameter | Description | Default |
| --- | --- | --- |
| key | PRNG key for weight and bias initialisation. | required |
| n_neurons | Number of neurons in the layer. | required |
| in_size | Number of input channels. | required |
| wmask | Binary mask for the weight matrices, shape \((N+K) \times N\). | required |
| tsyn | Synaptic time constant \(\tau_\text{syn}\). | 5.0 |
| tmem | Membrane time constant \(\tau_\text{mem}\). | 20.0 |
| thresh | Spike threshold \(\vartheta\) on cell state \(c\). | 0.5 |
| dtype | JAX dtype for all computations. | jnp.float32 |

Initialisation

Weights and biases are drawn from normal distributions controlled by scale and mean parameters.

| Parameter | Description | Default |
| --- | --- | --- |
| weight_scale | Standard deviation for weight initialisation. | 1.0 |
| weight_mean | Mean for weight initialisation. | 0.0 |
| bias_scale | Standard deviation for bias initialisation. | 0.1 |
| bias_u_mean | Mean of the update-gate bias \(b_u\). | -2.0 |
| bias_r_mean | Mean of the reset-gate bias \(b_r\). | 1.0 |
| bias_z_mean | Mean of the modulation-gate bias \(b_z\). | 0.0 |
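The scheme can be reproduced in a few lines of plain JAX (shapes and default values taken from the constructor; the variable names here are local, not the class fields):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
n_neurons, in_size = 64, 32

# Biases: tight normal around a per-gate mean (update gate shown).
b_u = 0.1 * jax.random.normal(k1, (n_neurons,)) + (-2.0)

# Weights: unit-scale normal, zero mean, shape (N + K, N).
W_u = 1.0 * jax.random.normal(k2, (n_neurons + in_size, n_neurons)) + 0.0
```

The strongly negative update-gate bias keeps \(u = \sigma(a_u)\) small at rest, which slows the cell-state dynamics and keeps activity sparse early in training.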

State layout

The state array y has shape (n_neurons, 4):

| Channel | Index | Description |
| --- | --- | --- |
| \(c\) | y[:, 0] | Cell state |
| \(a_u\) | y[:, 1] | Update gate pre-activation |
| \(a_r\) | y[:, 2] | Reset gate pre-activation |
| \(a_z\) | y[:, 3] | Modulation gate pre-activation |

Initial state is all zeros.


Trainable fields

| Field | Shape | Description |
| --- | --- | --- |
| W_u | \((N+K) \times N\) | Weight matrix \(W^{(u)}\) for the update gate |
| W_r | \((N+K) \times N\) | Weight matrix \(W^{(r)}\) for the reset gate |
| W_z | \((N+K) \times N\) | Weight matrix \(W^{(z)}\) for the modulation gate |
| b_u | \((N,)\) | Bias attractor \(b_u\) for the update pre-activation |
| b_r | \((N,)\) | Bias attractor \(b_r\) for the reset pre-activation |
| b_z | \((N,)\) | Bias attractor \(b_z\) for the modulation pre-activation |

All other fields are static (not optimised by gradient-based training).


Methods

init_state

Return zero-initialised state of shape (n_neurons, 4).

Source code in eventax/neuron_models/egru.py
def init_state(self, n_neurons: int) -> Float[Array, "neurons 4"]:
    """Return zero-initialised state of shape `(n_neurons, 4)`."""
    return jnp.zeros((n_neurons, 4), dtype=self.dtype)

Returns zeros of shape (n_neurons, 4).


dynamics

Compute the EGRU ODE derivatives for cell state and pre-activations.

Source code in eventax/neuron_models/egru.py
def dynamics(
    self,
    t: float,
    y: Float[Array, "neurons 4"],
    args: Dict[str, Any],
) -> Float[Array, "neurons 4"]:
    """Compute the EGRU ODE derivatives for cell state and pre-activations."""
    c = y[:, 0]
    a_u = y[:, 1]
    a_r = y[:, 2]
    a_z = y[:, 3]

    da_u = (-a_u + self.b_u) / self.tsyn.value
    da_r = (-a_r + self.b_r) / self.tsyn.value
    da_z = (-a_z + self.b_z) / self.tsyn.value

    u = jax.nn.sigmoid(a_u)
    z = jnp.tanh(a_z)
    dc = (u * (z - c)) / self.tmem.value

    return jnp.stack([dc, da_u, da_r, da_z], axis=1)

Implements the free dynamics equations above. The pre-activations \(a_u, a_r, a_z\) relax exponentially toward their bias attractors \(b_u, b_r, b_z\) with time constant \(\tau_\text{syn}\). The cell state \(c\) evolves via gated relaxation toward \(z = \tanh(a_z)\), modulated by the update gate \(u = \sigma(a_u)\), with time constant \(\tau_\text{mem}\).
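Because the pre-activation equations are linear, they admit a closed-form solution between events, which makes the exponential relaxation explicit:

\[ a_u(t) = b_u + \big(a_u(t_0) - b_u\big)\, e^{-(t - t_0)/\tau_\text{syn}}, \]

and analogously for \(a_r\) and \(a_z\). The cell-state equation has no such closed form in general, since its effective gain \(u(t)\) varies in time.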


spike_condition

Return c - thresh; sign change triggers a spike.

Source code in eventax/neuron_models/egru.py
def spike_condition(
    self,
    t: float,
    y: Float[Array, "neurons 4"],
    **kwargs: Dict[str, Any],
) -> Float[Array, "neurons"]:
    """Return `c - thresh`; sign change triggers a spike."""
    return y[:, 0] - self.thresh.value

Returns \(c - \vartheta\). A spike is triggered when this changes sign (crosses zero from below).


input_spike

Add gated weight contributions to the pre-activation channels of target neurons.

Source code in eventax/neuron_models/egru.py
def input_spike(
    self,
    y: Float[Array, "neurons 4"],
    from_idx: Union[int, Int[Array, ""]],
    to_idx: Int[Array, "targets"],
    valid_mask: Bool[Array, "targets"],
) -> Float[Array, "neurons 4"]:
    """Add gated weight contributions to the pre-activation channels of target neurons."""
    res = jax.nn.sigmoid(y[to_idx, 2])
    du = self.W_u[from_idx, to_idx] * valid_mask
    dr = self.W_r[from_idx, to_idx] * valid_mask
    dz = self.W_z[from_idx, to_idx] * res * valid_mask

    y = y.at[to_idx, 1].add(du)
    y = y.at[to_idx, 2].add(dr)
    y = y.at[to_idx, 3].add(dz)
    return y

Updates the pre-activation channels of target neurons when a spike arrives. The three channels receive different contributions:

  • \(a_u\): receives \(W^{(u)}_{n,m}\)
  • \(a_r\): receives \(W^{(r)}_{n,m}\)
  • \(a_z\): receives \(W^{(z)}_{n,m} \cdot r_m\), gated by the target neuron's reset gate \(r_m = \sigma(a_{r,m})\)

Only entries where valid_mask is True are applied.
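The masked scatter-add pattern can be seen in isolation with a toy example; repeated target indices accumulate, and masked entries contribute zero rather than being skipped, which keeps shapes static for jit:

```python
import jax.numpy as jnp

a_u = jnp.zeros(4)
to_idx = jnp.array([1, 3, 3])            # target 3 appears twice
contrib = jnp.array([0.5, 0.2, 0.3])
valid = jnp.array([True, True, False])   # last entry is padding

# Masked entries are zeroed, then contributions accumulate per index.
a_u = a_u.at[to_idx].add(contrib * valid)
# a_u -> [0.0, 0.5, 0.0, 0.2]
```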


reset_spiked

Reset cell state of spiked neurons via subtraction and clip to prevent re-triggering.

Source code in eventax/neuron_models/egru.py
def reset_spiked(
    self,
    y: Float[Array, "neurons 4"],
    spiked_mask: Bool[Array, "neurons"],
) -> Float[Array, "neurons 4"]:
    """Reset cell state of spiked neurons via subtraction and clip to prevent re-triggering."""
    c, a_u, a_r, a_z = y[:, 0], y[:, 1], y[:, 2], y[:, 3]
    c = c - spiked_mask.astype(self.dtype) * self.thresh.value
    c = clip_with_identity_grad(c, self.thresh.value - self.epsilon)
    return jnp.stack([c, a_u, a_r, a_z], axis=1)

Resets the cell state of spiked neurons via subtraction, \(c \leftarrow c - \vartheta\), which preserves the gradient through the reset. The result is then clipped to at most \(\vartheta - \epsilon\) (with an identity gradient) so the spike condition cannot re-trigger immediately. The pre-activation channels \(a_u, a_r, a_z\) are left unchanged.
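A toy illustration of the subtraction reset (names are illustrative, and the clip is omitted): overshoot above the threshold is carried over rather than discarded:

```python
import jax.numpy as jnp

thresh = 0.5
c = jnp.array([0.52, 0.30, 0.61])        # neurons 0 and 2 crossed threshold
spiked = c >= thresh
c = c - spiked.astype(c.dtype) * thresh
# c -> [0.02, 0.30, 0.11]: overshoot survives the reset
```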