
Neuron Models

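A NeuronModel defines the behaviour of a neuron type within the event-driven simulation. It specifies:

  • the initial state of every neuron (init_state),
  • the continuous-time dynamics of the state between events (dynamics),
  • the condition under which a neuron emits a spike (spike_condition),
  • how a neuron's state changes when it receives a spike (input_spike),
  • how a neuron's state is reset after it fires (reset_spiked),
  • which channels of the state are recorded (observe).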

All neuron models inherit from the abstract base class below.


eventax.neuron_models.base_model.NeuronModel

Bases: Module

Abstract base class for all neuron models.

Subclasses must implement init_state, dynamics, spike_condition, input_spike, and reset_spiked.

Source code in eventax/neuron_models/base_model.py
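from typing import Any, Dict

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float, Int


class NeuronModel(eqx.Module):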
    """Abstract base class for all neuron models.

    Subclasses must implement [`init_state`][eventax.neuron_models.base_model.NeuronModel.init_state],
    [`dynamics`][eventax.neuron_models.base_model.NeuronModel.dynamics],
    [`spike_condition`][eventax.neuron_models.base_model.NeuronModel.spike_condition],
    [`input_spike`][eventax.neuron_models.base_model.NeuronModel.input_spike],
    and [`reset_spiked`][eventax.neuron_models.base_model.NeuronModel.reset_spiked].
    """

    dtype: jnp.dtype = eqx.field(static=True)
    """The JAX dtype used for all internal computations."""

    def __init__(self, dtype: jnp.dtype):
        self.dtype = dtype

    def __call__(self, t, y, args):
        return self.dynamics(t, y, args)

    def init_state(self, n_neurons: int) -> Any:
        """Return the initial state for all neurons in the network."""
        raise NotImplementedError

    def dynamics(
        self,
        t: float,
        y: Any,
        args: Dict[str, Any],
    ) -> Any:
        """Compute the time derivative of the neuron state."""
        raise NotImplementedError

    def spike_condition(
        self,
        t: float,
        y: Any,
        args: Dict[str, Any],
    ) -> Float[Array, "neurons"]:
        """Evaluate the spike condition for each neuron."""
        raise NotImplementedError

    def input_spike(
        self,
        y: Any,
        from_idx: Int[Array, ""],
        to_idx: Int[Array, "targets"],
        valid_mask: Bool[Array, "targets"],
    ) -> Any:
        """Update state in response to an incoming spike."""
        raise NotImplementedError

    def reset_spiked(
        self,
        y: Any,
        spiked_mask: Bool[Array, "neurons"],
    ) -> Any:
        """Reset the state of neurons that just spiked."""
        raise NotImplementedError

    def observe(
        self,
        y: Any,
    ) -> Float[Array, "neurons obs_channels"]:
        """Extract observable channels from the state."""
        return y

Methods

All of the following methods except observe, which has a default implementation, must be implemented by any concrete NeuronModel subclass. See Examples: Neuron Models for a walkthrough of creating your own neuron model.


init_state

Return the initial state for all neurons in the network.

Source code in eventax/neuron_models/base_model.py
def init_state(self, n_neurons: int) -> Any:
    """Return the initial state for all neurons in the network."""
    raise NotImplementedError

Returns the initial state for all neurons and their channels. The returned array has shape (n_neurons, n_channels), where n_channels is determined by the specific neuron model (e.g. 2 for a LIF or QIF neuron).
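As a minimal sketch, assume a hypothetical two-channel LIF model whose channel 0 is the membrane potential and channel 1 the synaptic current. init_state could then start every neuron at rest with no current. It is written as a free function for brevity; inside a subclass it would be a method and would take the dtype from self.dtype.

```python
import jax.numpy as jnp


def init_state(n_neurons: int, dtype=jnp.float32) -> jnp.ndarray:
    # Hypothetical layout: column 0 = membrane potential v, column 1 = synaptic current i.
    # Every neuron starts at v = 0 (rest) with zero synaptic current.
    return jnp.zeros((n_neurons, 2), dtype=dtype)


y0 = init_state(4)
```

Here y0 has shape (4, 2), i.e. (n_neurons, n_channels).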


dynamics

Compute the time derivative of the neuron state.

Source code in eventax/neuron_models/base_model.py
def dynamics(
    self,
    t: float,
    y: Any,
    args: Dict[str, Any],
) -> Any:
    """Compute the time derivative of the neuron state."""
    raise NotImplementedError

Defines the continuous-time dynamics of all neuron channels as an ODE. Given the current time \(t\) and full state \(y\) over all neurons, returns the derivative \(\frac{\mathrm{d}y}{\mathrm{d}t}\).

Arguments:

  • t: the current simulation time.
  • y: the state of all neurons, shape (n_neurons, n_channels).
  • args: a dictionary of additional arguments (e.g. external input currents).

Returns:

The derivative of y, same shape as y.
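For illustration, a common current-based LIF form is dv/dt = (-v + i)/tau_mem and di/dt = -i/tau_syn. The sketch below assumes the same hypothetical channel layout as above (v in column 0, i in column 1) and hypothetical constants TAU_MEM and TAU_SYN; eventax's own LIF model may use different equations or parameters. It shows the key contract: the returned derivative has exactly the shape of y.

```python
import jax.numpy as jnp

# Hypothetical time constants; a real model would store these as fields.
TAU_MEM = 10.0
TAU_SYN = 5.0


def dynamics(t, y, args):
    # y has shape (n_neurons, 2): column 0 = v, column 1 = i (assumed layout).
    # t and args are unused in this autonomous sketch.
    v, i = y[:, 0], y[:, 1]
    dv = (-v + i) / TAU_MEM  # leaky integration of the synaptic current
    di = -i / TAU_SYN        # exponential decay of the synaptic current
    # Stack back into the same (n_neurons, 2) layout as y.
    return jnp.stack([dv, di], axis=1)


y = jnp.array([[0.0, 1.0], [0.5, 0.0]])
dy = dynamics(0.0, y, {})
```

Because NeuronModel.__call__ forwards to dynamics, a model instance itself can be passed wherever an ODE vector field with signature (t, y, args) is expected.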


spike_condition

Evaluate the spike condition for each neuron.

Source code in eventax/neuron_models/base_model.py
def spike_condition(
    self,
    t: float,
    y: Any,
    args: Dict[str, Any],
) -> Float[Array, "neurons"]:
    """Evaluate the spike condition for each neuron."""
    raise NotImplementedError

Defines the condition under which a neuron emits an event. Returns a vector with one entry per neuron. An event for neuron \(m\) is triggered when entry \(m\) changes sign from negative to positive between consecutive ODE solver steps; a root finder then locates the exact event time within that step.

Example

For a simple threshold crossing at \(V_\mathrm{th}\), this would return y[:, 0] - V_th.

Arguments:

  • t: the current simulation time.
  • y: the state of all neurons, shape (n_neurons, n_channels).
  • args: a dictionary of additional arguments.

Returns:

A vector of shape (n_neurons,).
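The sketch below illustrates the threshold example together with the detection logic described above: a per-neuron negative-to-positive sign-change test between two solver steps, followed by a root search for the crossing time. It assumes a hypothetical threshold V_TH and a hypothetical linear trajectory y_at; plain bisection is swapped in as the simplest root finder and is not necessarily what eventax uses internally.

```python
import jax.numpy as jnp

V_TH = 1.0  # hypothetical firing threshold


def spike_condition(t, y, args):
    # Negative below threshold, positive above; the sign flips at V_TH.
    return y[:, 0] - V_TH


def crossed(cond_prev, cond_next):
    # Per-neuron test for a negative-to-positive sign change between two steps.
    return (cond_prev < 0) & (cond_next >= 0)


def bisect_event_time(cond_fn, t0, t1, n_iter=50):
    # Illustrative bisection; assumes cond_fn(t0) < 0 <= cond_fn(t1).
    for _ in range(n_iter):
        tm = 0.5 * (t0 + t1)
        if cond_fn(tm) < 0:
            t0 = tm
        else:
            t1 = tm
    return 0.5 * (t0 + t1)


def y_at(t):
    # Hypothetical trajectory: neuron 0 charges as v = 2t, neuron 1 as v = 0.2t.
    return jnp.array([[2.0 * t, 0.0], [0.2 * t, 0.0]])


# Condition values at the start and end of one solver step from t = 0 to t = 1.
c0 = spike_condition(0.0, y_at(0.0), {})
c1 = spike_condition(1.0, y_at(1.0), {})
fired = crossed(c0, c1)

# Locate the crossing time of neuron 0 inside the step.
t_event = bisect_event_time(
    lambda t: float(spike_condition(t, y_at(t), {})[0]), 0.0, 1.0
)
```

Only neuron 0 crosses threshold in this step, and its crossing lies at t = 0.5, where 2t equals V_TH.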


input_spike

Update state in response to an incoming spike.

Source code in eventax/neuron_models/base_model.py
def input_spike(
    self,
    y: Any,
    from_idx: Int[Array, ""],
    to_idx: Int[Array, "targets"],
    valid_mask: Bool[Array, "targets"],
) -> Any:
    """Update state in response to an incoming spike."""
    raise NotImplementedError

Defines how the internal state is updated when a spike is received. Must handle the case where multiple target neurons receive the spike simultaneously (i.e. to_idx may contain multiple indices).

Arguments:

  • y: the current state of all neurons.
  • from_idx: scalar index of the neuron that fired.
  • to_idx: indices of the target neurons, shape (targets,).
  • valid_mask: boolean mask indicating which entries in to_idx are valid, shape (targets,).

Returns:

The updated state y.
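A minimal sketch of input_spike, assuming a hypothetical dense weight matrix W indexed as W[from, to] and the same assumed channel layout (synaptic current in column 1). Invalid entries of to_idx are given zero weight via valid_mask, and the update uses JAX's functional .at[...].add scatter, which accumulates correctly even if to_idx lists the same target more than once.

```python
import jax.numpy as jnp

# Hypothetical dense weight matrix W[from, to]; a real model would hold this as a field.
W = jnp.array([
    [0.0, 0.5, 0.2],
    [0.1, 0.0, 0.3],
    [0.4, 0.6, 0.0],
])


def input_spike(y, from_idx, to_idx, valid_mask):
    # Weight of each synapse from the firing neuron to every listed target.
    w = W[from_idx, to_idx]
    # Entries flagged invalid contribute nothing.
    w = jnp.where(valid_mask, w, 0.0)
    # Scatter-add into the synaptic-current channel (column 1, assumed layout).
    return y.at[to_idx, 1].add(w)


y = jnp.zeros((3, 2))
from_idx = jnp.array(0)
to_idx = jnp.array([1, 2, 0])                 # last entry is treated as padding
valid_mask = jnp.array([True, True, False])
y_new = input_spike(y, from_idx, to_idx, valid_mask)
```

After the call, neurons 1 and 2 have received currents of 0.5 and 0.2, while the masked entry leaves neuron 0 unchanged.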


reset_spiked

Reset the state of neurons that just spiked.

Source code in eventax/neuron_models/base_model.py
def reset_spiked(
    self,
    y: Any,
    spiked_mask: Bool[Array, "neurons"],
) -> Any:
    """Reset the state of neurons that just spiked."""
    raise NotImplementedError

Defines how the state of a neuron is reset after it fires. Receives the current state and a one-hot mask indicating which neuron spiked.

Arguments:

  • y: the current state of all neurons.
  • spiked_mask: boolean array of shape (n_neurons,), True for the neuron that spiked.

Returns:

The updated state y.
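A minimal sketch of reset_spiked for the same assumed layout, using a hypothetical reset potential V_RESET. jnp.where selects the reset value only where spiked_mask is True, so the other neurons and the synaptic-current channel are left untouched.

```python
import jax.numpy as jnp

V_RESET = 0.0  # hypothetical reset potential


def reset_spiked(y, spiked_mask):
    # Reset the membrane potential (column 0, assumed layout) of spiking neurons;
    # all other neurons and all other channels keep their values.
    v = jnp.where(spiked_mask, V_RESET, y[:, 0])
    return y.at[:, 0].set(v)


y = jnp.array([[1.0, 0.3], [0.4, 0.1], [0.2, 0.0]])
spiked_mask = jnp.array([True, False, False])
y_reset = reset_spiked(y, spiked_mask)
```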


observe

Extract observable channels from the state.

Source code in eventax/neuron_models/base_model.py
def observe(
    self,
    y: Any,
) -> Float[Array, "neurons obs_channels"]:
    """Extract observable channels from the state."""
    return y

Extracts the observable channels from the full internal state. By default returns y unchanged. Override this in models where the internal state contains channels that should not be recorded (e.g. auxiliary solver variables).

Returns:

Array of shape (n_neurons, obs_channels).
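As an illustration of overriding observe, the sketch below assumes that only channel 0 (the membrane potential in the hypothetical layout used above) should be recorded. Slicing with 0:1 rather than indexing with 0 keeps the trailing obs_channels axis, so the result still has shape (n_neurons, obs_channels).

```python
import jax.numpy as jnp


def observe(y):
    # Keep only column 0 (assumed membrane potential); the 0:1 slice
    # preserves the obs_channels axis, giving shape (n_neurons, 1).
    return y[:, 0:1]


y = jnp.array([[0.2, 1.0], [0.7, -0.5], [0.1, 0.0]])
obs = observe(y)
```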