
Quadratic Integrate and Fire (QIF)

Implementation of the Quadratic Integrate and Fire (QIF) neuron model in phase representation. The QIF extends the LIF and inherits its spike handling (spike_condition, input_spike, reset_spiked) but replaces the voltage dynamics with quadratic ones.

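In voltage space, the QIF dynamics are:

\[ \tau_\text{mem} \frac{\partial V}{\partial t} = V^2 - V + I + I_c, \qquad \tau_\text{syn} \frac{\partial I}{\partial t} = -I. \]

Since \(V^2 - V = (V - \tfrac{1}{2})^2 - \tfrac{1}{4}\), this is the standard quadratic form with \(V\) shifted by \(\tfrac{1}{2}\) and the drive shifted by \(-\tfrac{1}{4}\). The linear term \(-V\) is what produces the \(\frac{1}{\pi}\cos(\pi\varphi)\sin(\pi\varphi)\) term in the phase dynamics.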

To avoid the divergence \(V \to \infty\) at spike time, the model operates in phase space, using the coordinate transform of Klos and Memmesheimer:

\[ \varphi = \frac{1}{\pi}\arctan\!\left(\frac{V}{\pi}\right) + \frac{1}{2}, \qquad \varphi \in [0, 1]. \]
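A minimal sketch of the transform and its inverse (the inverse is the map used by observe), written with Python's standard `math` module instead of `jax.numpy` so it runs without JAX. The function names `phi_from_v` and `v_from_phi` are illustrative and not part of eventax. It shows that finite voltages map to the open interval \((0, 1)\), with \(V = 0\) at \(\varphi = \tfrac{1}{2}\).

```python
import math


def phi_from_v(v: float) -> float:
    """Phase from voltage: phi = arctan(V / pi) / pi + 1/2."""
    return math.atan(v / math.pi) / math.pi + 0.5


def v_from_phi(phi: float) -> float:
    """Voltage from phase (inverse transform): V = pi * tan(pi * (phi - 1/2))."""
    return math.pi * math.tan(math.pi * (phi - 0.5))


print(phi_from_v(0.0))  # 0.5: V = 0 sits in the middle of the phase interval

# Large |V| is squeezed towards the ends: phi -> 1 as V -> +inf, phi -> 0 as V -> -inf.
print(phi_from_v(1e6), phi_from_v(-1e6))

# For finite V the round trip is the identity up to floating-point error.
for v in (-10.0, -1.0, 0.0, 2.5, 100.0):
    assert abs(v_from_phi(phi_from_v(v)) - v) < 1e-9 * max(1.0, abs(v))
```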

Free Dynamics (phase representation)

\[ \begin{aligned} \tau_\text{mem} \frac{\partial \varphi}{\partial t} &= \cos(\pi\varphi)\left[\cos(\pi\varphi) + \frac{1}{\pi}\sin(\pi\varphi)\right] + \frac{1}{\pi^2}\sin^2(\pi\varphi)\,(I + I_c), \\ \tau_\text{syn} \frac{\partial I}{\partial t} &= -I. \end{aligned} \]
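The phase equation follows from the voltage dynamics by the chain rule. Differentiating the transform gives \(d\varphi/dV = 1/(V^2 + \pi^2)\), and with \(V = \pi\tan\bigl(\pi(\varphi - \tfrac{1}{2})\bigr) = -\pi\cot(\pi\varphi)\) one has \(1/(V^2 + \pi^2) = \sin^2(\pi\varphi)/\pi^2\). Substituting \(\tau_\text{mem}\,\partial V/\partial t = V^2 - V + I + I_c\) reproduces the phase equation term by term; the cross term \(\frac{1}{\pi}\cos(\pi\varphi)\sin(\pi\varphi)\) is the image of the linear \(-V\) term and would be absent for \(V^2 + I + I_c\).

\[ \begin{aligned} \tau_\text{mem} \frac{\partial \varphi}{\partial t} &= \frac{d\varphi}{dV}\,\tau_\text{mem}\frac{\partial V}{\partial t} = \frac{V^2 - V + I + I_c}{V^2 + \pi^2} \\ &= \frac{\sin^2(\pi\varphi)}{\pi^2}\Bigl[\pi^2\cot^2(\pi\varphi) + \pi\cot(\pi\varphi) + I + I_c\Bigr] \\ &= \cos^2(\pi\varphi) + \frac{1}{\pi}\cos(\pi\varphi)\sin(\pi\varphi) + \frac{1}{\pi^2}\sin^2(\pi\varphi)\,(I + I_c). \end{aligned} \]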

Transition Condition

\[ \varphi_n(t_s^-) - 1 = 0, \qquad \text{for any neuron } n. \]

Jumps at Transition

\[ \begin{aligned} \varphi_n(t_s^+) &= 0, \\ I_m\big(t_s^+\big) &= I_m\big(t_s^-\big) + W_{nm}, \quad \forall m. \end{aligned} \]

Where

  • \(\varphi \in \mathbb{R}^N\) is the phase variable (bounded in \([0, 1]\)),
  • \(I \in \mathbb{R}^N\) is the synaptic current,
  • \(I_c \in \mathbb{R}^N\) is the bias current,
  • \(\tau_\text{syn}\) and \(\tau_\text{mem}\) are the synaptic and membrane time constants,
  • \(W \in \mathbb{R}^{(N+K) \times N}\) is the synaptic weight matrix with number of neurons \(N\) and input size \(K\).
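The transition condition and jumps can be sketched in plain Python. This is a hypothetical illustration, not eventax's spike_condition, input_spike or reset_spiked: the discrete `phi >= 1` check stands in for the event detection an ODE solver performs on \(\varphi_n - 1 = 0\), weights are stored as rows `w[source][target]` to match the \((N+K) \times N\) shape, and the caller passes the row `src` of the spiking neuron because the ordering of input and neuron rows used here is an assumption, not taken from the library.

```python
from typing import List


def spiked(phi: List[float], thresh: float = 1.0) -> List[bool]:
    """Transition condition: phase has reached the threshold (theta = 1)."""
    return [p >= thresh for p in phi]


def apply_spike(phi: List[float], current: List[float], w: List[List[float]],
                n: int, src: int, reset: float = 0.0) -> None:
    """Apply the jumps for a spike of neuron n, modifying phi and current in place.

    src is the row of w that holds neuron n's outgoing weights (layout assumed).
    """
    phi[n] = reset                    # phi_n(t_s^+) = 0
    for m in range(len(current)):     # I_m(t_s^+) = I_m(t_s^-) + weight from n to m
        current[m] += w[src][m]


# Hypothetical layout: one input row first, then one row per neuron (N = 2, K = 1).
w = [[0.5, 0.5],     # input
     [0.0, 0.5],     # neuron 0
     [-0.25, 0.0]]   # neuron 1
phi = [1.0, 0.5]
current = [0.0, 0.25]
for n, fired in enumerate(spiked(phi)):
    if fired:
        apply_spike(phi, current, w, n=n, src=1 + n)
print(phi, current)  # [0.0, 0.5] [0.0, 0.75]
```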

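The threshold and reset are fixed at \(\vartheta = 1\) and \(\varphi_\text{reset} = 0\), respectively. In voltage terms these correspond to the limits \(V \to +\infty\) and \(V \to -\infty\), the natural spike and reset points of the QIF.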
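A quick numerical check of the full reset-to-threshold cycle, as a sketch in plain Python. With \(I = 0\) and a constant bias \(I_c > \tfrac{1}{4}\), the right-hand side of the phase equation equals \(\bigl((V - \tfrac{1}{2})^2 + I_c - \tfrac{1}{4}\bigr)/\bigl(\tau_\text{mem}(V^2 + \pi^2)\bigr)\) with \(V = \pi\tan\bigl(\pi(\varphi - \tfrac{1}{2})\bigr)\), which is strictly positive. Changing variables from \(\varphi\) to \(V\) turns the time from \(\varphi = 0\) to \(\varphi = 1\) into \(\tau_\text{mem}\int_{-\infty}^{\infty} dV / \bigl((V - \tfrac{1}{2})^2 + I_c - \tfrac{1}{4}\bigr)\), i.e. a firing period \(T = \pi\tau_\text{mem}/\sqrt{I_c - 1/4}\). The code integrates the phase equation with a fixed-step RK4 scheme, chosen for simplicity as a stand-in for the library's own ODE solver; `phase_rhs` and `period_rk4` are illustrative names.

```python
import math


def phase_rhs(phi: float, i_syn: float, ic: float, tmem: float) -> float:
    """Right-hand side of the QIF phase equation (same formula as `dynamics`)."""
    c = math.cos(math.pi * phi)
    s = math.sin(math.pi * phi)
    return (c * (c + s / math.pi) + (s * s) * (i_syn + ic) / math.pi**2) / tmem


def period_rk4(ic: float, tmem: float = 1.0, dt: float = 1e-3) -> float:
    """Time for phi to go from the reset (0) to the threshold (1) with I = 0."""
    t, phi = 0.0, 0.0
    while True:
        k1 = phase_rhs(phi, 0.0, ic, tmem)
        k2 = phase_rhs(phi + 0.5 * dt * k1, 0.0, ic, tmem)
        k3 = phase_rhs(phi + 0.5 * dt * k2, 0.0, ic, tmem)
        k4 = phase_rhs(phi + dt * k3, 0.0, ic, tmem)
        phi_next = phi + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        if phi_next >= 1.0:
            # Linear interpolation inside the last step to locate the crossing.
            return t + dt * (1.0 - phi) / (phi_next - phi)
        t, phi = t + dt, phi_next


ic, tmem = 1.25, 2.0
t_num = period_rk4(ic, tmem)
t_exact = math.pi * tmem / math.sqrt(ic - 0.25)
print(t_num, t_exact)
```

The standard form \(\tau_\text{mem}\,\partial V/\partial t = V^2 + I_c\) would instead give \(\pi\tau_\text{mem}/\sqrt{I_c}\); the numerical period matches the shifted form.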

eventax.neuron_models.QIF

Bases: LIF

Quadratic integrate-and-fire neuron in phase representation.

Extends LIF with quadratic voltage dynamics reformulated via the phase variable \(\varphi = \frac{1}{\pi}\arctan\!\left(\frac{V}{\pi}\right) + \frac{1}{2}\).

Source code in eventax/neuron_models/qif.py
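from typing import Any, Dict, Optional, Union

import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray

from eventax.neuron_models import LIF  # import path assumed from the LIF cross-reference


class QIF(LIF):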
    """Quadratic integrate-and-fire neuron in phase representation.

    Extends [`LIF`][eventax.neuron_models.LIF] with quadratic voltage dynamics
    reformulated via the phase variable $\\varphi =
    \\frac{1}{\\pi}\\arctan\\!\\left(\\frac{V}{\\pi}\\right) + \\frac{1}{2}$.
    """

    def __init__(
        self,
        key: PRNGKeyArray,
        n_neurons: int,
        in_size: int,
        wmask: Float[Array, "in_plus_neurons neurons"],
        tmem: Union[int, float, jnp.ndarray],
        tsyn: Union[int, float, jnp.ndarray],
        blim: Optional[float] = None,
        bmean: Union[int, float, jnp.ndarray] = 0.0,
        init_bias: Optional[Union[int, float, jnp.ndarray]] = None,
        wlim: Optional[float] = None,
        wmean: Union[int, float, jnp.ndarray] = 0.0,
        init_weights: Optional[Union[int, float, jnp.ndarray]] = None,
        fan_in_mode: Optional[str] = None,
        dtype=jnp.float32,
    ):
        super().__init__(
            key=key,
            n_neurons=n_neurons,
            in_size=in_size,
            wmask=wmask,
            tmem=tmem,
            tsyn=tsyn,
            blim=blim,
            bmean=bmean,
            init_bias=init_bias,
            wlim=wlim,
            wmean=wmean,
            init_weights=init_weights,
            thresh=1.0,
            vreset=0.0,
            fan_in_mode=fan_in_mode,
            dtype=dtype,
            reset_grad_preserve=False,
        )

    def init_state(self, n_neurons: int) -> Float[Array, "neurons 2"]:
        """Return initial state with $\\varphi = 0.5$ and $I = 0$."""
        phi0: Float[Array, "neurons"] = jnp.full((n_neurons,), 0.5, dtype=self.dtype)
        i0: Float[Array, "neurons"] = jnp.zeros((n_neurons,), dtype=self.dtype)
        return jnp.stack([phi0, i0], axis=1)

    def dynamics(
        self,
        t: float,
        y: Float[Array, "neurons 2"],
        args: Dict[str, Any],
    ) -> Float[Array, "neurons 2"]:
        """Compute the QIF phase-space ODE derivatives."""
        phi: Float[Array, "neurons"] = y[:, 0]
        i: Float[Array, "neurons"] = y[:, 1]
        c = jnp.cos(jnp.pi * phi)
        s = jnp.sin(jnp.pi * phi)
        term1 = c * (c + (1.0 / jnp.pi) * s)
        term2 = (1.0 / (jnp.pi**2)) * (s**2) * (i + self.ic)
        dphi = (term1 + term2) / self.tmem.value
        di = -i / self.tsyn.value
        return jnp.stack([dphi, di], axis=1)

    def observe(
        self,
        y: Float[Array, "neurons 2"],
    ) -> Float[Array, "neurons 2"]:
        """Map phase $\\varphi$ back to voltage via $V = \\pi\\tan\\!\\bigl(\\pi(\\varphi - \\tfrac{1}{2})\\bigr)$."""
        phi: Float[Array, "neurons"] = y[:, 0]
        i: Float[Array, "neurons"] = y[:, 1]
        v: Float[Array, "neurons"] = jnp.pi * jnp.tan(jnp.pi * (phi - 0.5))
        return jnp.stack([v, i], axis=1)

Parameters

The QIF accepts the same parameters as LIF, except that thresh, vreset and reset_grad_preserve are fixed internally (thresh=1.0, vreset=0.0, reset_grad_preserve=False) and are not exposed.

| Parameter | Description | Default |
| --- | --- | --- |
| key | JAX PRNG key. | required |
| n_neurons | Number of neurons in the layer, \(N\). | required |
| in_size | Number of input connections, \(K\). | required |
| tmem | Membrane time constant \(\tau_\text{mem}\). | required |
| tsyn | Synaptic time constant \(\tau_\text{syn}\). | required |
| wmask | Binary mask for the weight matrix, shape \((N+K) \times N\). | required |
| dtype | JAX dtype for all computations. | jnp.float32 |

Weight initialisation

Same as for LIF; see the Weight initialisation section of the LIF documentation.

| Parameter | Description | Default |
| --- | --- | --- |
| init_weights | If given, used directly as \(W\). | None |
| wmean | Mean for random weight initialisation. | 0.0 |
| wlim | Uniform half-range for random weights. | None |
| fan_in_mode | Fan-in scaling mode. | None |
| init_bias | If given, used directly as \(I_c\). | None |
| bmean | Mean for random bias initialisation. | 0.0 |
| blim | Uniform half-range for random biases. | None |

State layout

The state array y has shape (n_neurons, 2):

| Channel | Index | Description |
| --- | --- | --- |
| \(\varphi\) | y[:, 0] | Phase variable (bounded in \([0, 1]\)) |
| \(I\) | y[:, 1] | Synaptic current |

Initial state: \(\varphi = 0.5\), \(I = 0\).

Note

The internal state uses the phase variable \(\varphi\), not the voltage \(V\). Use observe to convert back to voltage space.


Trainable fields

| Field | Shape | Description |
| --- | --- | --- |
| weights | \((N+K) \times N\) | Connection weight matrix \(W\) |
| ic | \((N,)\) | Bias current \(I_c\) |

Methods

init_state

Return initial state with \(\varphi = 0.5\) and \(I = 0\).


Returns state of shape (n_neurons, 2) with \(\varphi = 0.5\) (corresponding to \(V = 0\)) and \(I = 0\).


dynamics

Compute the QIF phase-space ODE derivatives.


Implements the QIF free dynamics in phase space; see the equations at the top of this page. The time t and the args dictionary are accepted but not used.
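A plain-Python mirror of `dynamics`, operating on a list of `[phi, I]` rows in the same layout as `y` (hypothetical helper `qif_dynamics`; `tmem`, `tsyn` and `ic` are plain numbers here, whereas the class reads `self.tmem.value`, `self.tsyn.value` and `self.ic`). The checks illustrate two properties that can be read off the formula: at \(\varphi = \tfrac{1}{2}\) (\(V = 0\)) the phase velocity is \((I + I_c)/(\pi^2\tau_\text{mem})\), and at the reset and threshold points \(\varphi \in \{0, 1\}\) the \(\sin(\pi\varphi)\) factor vanishes, so the phase moves at \(1/\tau_\text{mem}\) whatever the input.

```python
import math
from typing import List


def qif_dynamics(y: List[List[float]], ic: List[float],
                 tmem: float, tsyn: float) -> List[List[float]]:
    """Row-wise phase-space derivatives [dphi, dI] for a state of [phi, I] rows."""
    out = []
    for (phi, i), ic_n in zip(y, ic):
        c = math.cos(math.pi * phi)
        s = math.sin(math.pi * phi)
        term1 = c * (c + s / math.pi)
        term2 = (s * s) * (i + ic_n) / math.pi**2
        out.append([(term1 + term2) / tmem, -i / tsyn])
    return out


y = [[0.5, 0.0], [0.0, 3.0], [1.0, -2.0]]   # rows: [phi, I]
ic = [0.3, 0.0, 1.5]
tmem, tsyn = 10.0, 5.0
dy = qif_dynamics(y, ic, tmem, tsyn)

# phi = 1/2 (V = 0): velocity (I + I_c) / (pi^2 * tmem).
assert abs(dy[0][0] - 0.3 / (math.pi**2 * tmem)) < 1e-12
# phi = 0 and phi = 1: velocity 1 / tmem regardless of I and I_c.
assert abs(dy[1][0] - 1.0 / tmem) < 1e-12
assert abs(dy[2][0] - 1.0 / tmem) < 1e-12
```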


observe

Map phase \(\varphi\) back to voltage via \(V = \pi\tan\!\bigl(\pi(\varphi - \tfrac{1}{2})\bigr)\).


Maps the phase variable back to voltage space via the inverse transform:

\[ V = \pi\tan\!\bigl(\pi(\varphi - \tfrac{1}{2})\bigr) \]

Returns an array of shape (n_neurons, 2) where channel 0 is now the voltage \(V\) instead of the phase \(\varphi\), and channel 1 is the synaptic current \(I\) (unchanged).
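A plain-Python mirror of `observe` over `[phi, I]` rows (the name `observe_rows` is hypothetical). It also shows a floating-point caveat: at \(\varphi = 0\) or \(\varphi = 1\) the exact voltage is \(\mp\infty\), but `tan` evaluated at the rounded value of \(\pm\pi/2\) returns a large finite number, so observed voltages right after a reset (or at a spike) only indicate "far below" or "far above" rather than a meaningful value.

```python
import math
from typing import List


def observe_rows(y: List[List[float]]) -> List[List[float]]:
    """Replace channel 0 (phase) by the voltage; channel 1 (current) is unchanged."""
    return [[math.pi * math.tan(math.pi * (phi - 0.5)), i] for phi, i in y]


y = [[0.5, 0.2], [0.75, 0.0], [0.0, -1.0]]   # rows: [phi, I]
out = observe_rows(y)
print(out[0])     # [0.0, 0.2]
print(out[1][0])  # close to pi, since tan(pi/4) = 1
print(out[2][0])  # huge negative number instead of -inf (phi = 0 is the reset point)
```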


Inherited methods

The following methods are inherited from LIF without modification:

  • spike_condition: triggers when \(\varphi\) crosses \(\vartheta = 1\).
  • input_spike: adds \(W_{nm}\) to the synaptic current \(I_m\) of each target neuron \(m\).
  • reset_spiked: resets \(\varphi\) to \(0\) via jnp.where (reset_grad_preserve=False).