Skip to content

Getting started

Usage example

eventax.neuron_models.base_model

StaticArray dataclass

A wrapper around JAX arrays that should not be optimized.

Wraps a jnp.array in a frozen dataclass to avoid the warning raised by eqx.field(static=True) on unhashable JAX arrays. The wrapper itself is hashable and comparable: both equality and hashing are based on the identity of the wrapped array object, not on its contents.

Source code in eventax/neuron_models/base_model.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
@dataclass(frozen=True, eq=False)
class StaticArray:
    """A wrapper around JAX arrays that should not be optimized.

    Wraps a `jnp.array` in a frozen dataclass to avoid the warning raised by
    `eqx.field(static=True)` on unhashable JAX arrays. Equality and hashing
    are both based on the identity of the wrapped array object (not on its
    contents), so two wrappers compare equal only when they hold the very
    same array object.
    """

    value: Array
    """The underlying JAX array."""

    # `eq=False` above stops the dataclass machinery from generating a
    # field-wise `__eq__`; the identity-based pair below is used instead.

    def __eq__(self, other):
        """Return whether `other` wraps the very same array object.

        For operands that are not `StaticArray` instances, `NotImplemented`
        is returned so Python can try the reflected comparison; `==` still
        evaluates to `False` when neither operand handles the pair.
        """
        if not isinstance(other, StaticArray):
            return NotImplemented
        return self.value is other.value

    def __hash__(self):
        """Hash by the identity of the wrapped array, matching `__eq__`."""
        return hash(id(self.value))

value instance-attribute

value: Array

The underlying JAX array.

NeuronModel

Bases: Module

Abstract base class for all neuron models.

Subclasses must implement init_state, dynamics, spike_condition, input_spike, and reset_spiked.

Source code in eventax/neuron_models/base_model.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class NeuronModel(eqx.Module):
    """Common interface shared by all neuron models.

    Concrete models are expected to override
    [`init_state`][eventax.neuron_models.base_model.NeuronModel.init_state],
    [`dynamics`][eventax.neuron_models.base_model.NeuronModel.dynamics],
    [`spike_condition`][eventax.neuron_models.base_model.NeuronModel.spike_condition],
    [`input_spike`][eventax.neuron_models.base_model.NeuronModel.input_spike],
    and [`reset_spiked`][eventax.neuron_models.base_model.NeuronModel.reset_spiked];
    the versions defined here only raise `NotImplementedError`.
    [`observe`][eventax.neuron_models.base_model.NeuronModel.observe] has a
    default that returns the state unchanged.
    """

    dtype: jnp.dtype = eqx.field(static=True)
    """The JAX dtype used for all internal computations."""

    def __init__(self, dtype: jnp.dtype):
        self.dtype = dtype

    def __call__(self, t, y, args):
        """Make the model callable by forwarding to `dynamics`."""
        derivative = self.dynamics(t, y, args)
        return derivative

    def init_state(self, n_neurons: int) -> Any:
        """Build the initial state covering every neuron in the network.

        Placeholder: always raises `NotImplementedError`.
        """
        raise NotImplementedError()

    def dynamics(self, t: float, y: Any, args: Dict[str, Any]) -> Any:
        """Return the time derivative of the neuron state.

        Placeholder: always raises `NotImplementedError`.
        """
        raise NotImplementedError()

    def spike_condition(
        self, t: float, y: Any, args: Dict[str, Any]
    ) -> Float[Array, "neurons"]:
        """Evaluate the per-neuron spike condition.

        Placeholder: always raises `NotImplementedError`.
        """
        raise NotImplementedError()

    def input_spike(
        self,
        y: Any,
        from_idx: Int[Array, ""],
        to_idx: Int[Array, "targets"],
        valid_mask: Bool[Array, "targets"],
    ) -> Any:
        """Apply the effect of an incoming spike to the state.

        Placeholder: always raises `NotImplementedError`.
        """
        raise NotImplementedError()

    def reset_spiked(self, y: Any, spiked_mask: Bool[Array, "neurons"]) -> Any:
        """Reset the state of the neurons that have just spiked.

        Placeholder: always raises `NotImplementedError`.
        """
        raise NotImplementedError()

    def observe(self, y: Any) -> Float[Array, "neurons obs_channels"]:
        """Extract the observable channels from the state.

        The default implementation returns `y` unchanged.
        """
        return y

dtype class-attribute instance-attribute

dtype: jnp.dtype = eqx.field(static=True)

The JAX dtype used for all internal computations.

init_state

init_state(n_neurons: int) -> Any

Return the initial state for all neurons in the network.

Source code in eventax/neuron_models/base_model.py
48
49
50
def init_state(
    self,
    n_neurons: int,
) -> Any:
    """Build the initial state covering every neuron in the network.

    Placeholder: always raises `NotImplementedError`.
    """
    raise NotImplementedError()

dynamics

dynamics(t: float, y: Any, args: Dict[str, Any]) -> Any

Compute the time derivative of the neuron state.

Source code in eventax/neuron_models/base_model.py
52
53
54
55
56
57
58
59
def dynamics(self, t: float, y: Any, args: Dict[str, Any]) -> Any:
    """Return the time derivative of the neuron state.

    Placeholder: always raises `NotImplementedError`.
    """
    raise NotImplementedError()

spike_condition

spike_condition(t: float, y: Any, args: Dict[str, Any]) -> Float[Array, neurons]

Evaluate the spike condition for each neuron.

Source code in eventax/neuron_models/base_model.py
61
62
63
64
65
66
67
68
def spike_condition(
    self, t: float, y: Any, args: Dict[str, Any]
) -> Float[Array, "neurons"]:
    """Evaluate the per-neuron spike condition.

    Placeholder: always raises `NotImplementedError`.
    """
    raise NotImplementedError()

input_spike

input_spike(y: Any, from_idx: Int[Array, ''], to_idx: Int[Array, targets], valid_mask: Bool[Array, targets]) -> Any

Update state in response to an incoming spike.

Source code in eventax/neuron_models/base_model.py
70
71
72
73
74
75
76
77
78
def input_spike(
    self,
    y: Any,
    from_idx: Int[Array, ""],
    to_idx: Int[Array, "targets"],
    valid_mask: Bool[Array, "targets"],
) -> Any:
    """Apply the effect of an incoming spike to the state.

    Placeholder: always raises `NotImplementedError`.
    """
    raise NotImplementedError()

reset_spiked

reset_spiked(y: Any, spiked_mask: Bool[Array, neurons]) -> Any

Reset the state of neurons that just spiked.

Source code in eventax/neuron_models/base_model.py
80
81
82
83
84
85
86
def reset_spiked(self, y: Any, spiked_mask: Bool[Array, "neurons"]) -> Any:
    """Reset the state of the neurons that have just spiked.

    Placeholder: always raises `NotImplementedError`.
    """
    raise NotImplementedError()

observe

observe(y: Any) -> Float[Array, 'neurons obs_channels']

Extract observable channels from the state.

Source code in eventax/neuron_models/base_model.py
88
89
90
91
92
93
def observe(self, y: Any) -> Float[Array, "neurons obs_channels"]:
    """Extract the observable channels from the state.

    The default implementation returns `y` unchanged.
    """
    return y