Getting started
Usage example
eventax.neuron_models.base_model
StaticArray (dataclass)
A wrapper for JAX arrays that should be excluded from optimization.
Wraps a jnp.array in a frozen dataclass to avoid the warning raised by
eqx.field(static=True) on unhashable JAX arrays, while preserving
hashability via object identity.
Source code in eventax/neuron_models/base_model.py
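The identity-hashing behaviour described above can be illustrated with a minimal sketch. This is not the library's implementation: the class name `StaticArraySketch` and the field name `value` are assumptions made for illustration. The sketch relies only on the standard behaviour of `dataclass(frozen=True, eq=False)`, which leaves the identity-based `__eq__` and `__hash__` of `object` in place.

```python
from dataclasses import dataclass

import jax.numpy as jnp


# Hypothetical stand-in for StaticArray; the field name `value` is an assumption.
@dataclass(frozen=True, eq=False)
class StaticArraySketch:
    value: jnp.ndarray


a = StaticArraySketch(jnp.arange(3))
b = StaticArraySketch(jnp.arange(3))

# eq=False keeps object's identity-based __eq__ and __hash__, so each wrapper
# is hashable even though the wrapped JAX array is not.
lookup = {a: "first", b: "second"}
```

Two wrappers holding equal arrays are still distinct dictionary keys, because equality and hashing follow object identity rather than array contents.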
NeuronModel
Bases: Module
Abstract base class for all neuron models.
Subclasses must implement init_state, dynamics, spike_condition, input_spike, and reset_spiked.
dtype (class-attribute, instance-attribute)
dtype: dtype = dtype
The JAX dtype used for all internal computations.
init_state
init_state(n_neurons: int) -> Any
Return the initial state for all neurons in the network.
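As a rough illustration of what an init_state implementation might return, the sketch below builds a state for a leaky integrate-and-fire style model. The state layout (a dict with keys `"v"` and `"i"`) and the extra `dtype` parameter are assumptions for this example; the documented signature takes only `n_neurons`, and the return type is `Any`, so other layouts are equally valid.

```python
import jax.numpy as jnp


def init_state(n_neurons: int, dtype=jnp.float32):
    # Hypothetical state layout: one membrane potential "v" and one synaptic
    # current "i" per neuron, both starting at zero.
    return {
        "v": jnp.zeros((n_neurons,), dtype=dtype),
        "i": jnp.zeros((n_neurons,), dtype=dtype),
    }


state = init_state(4)
```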
dynamics
dynamics(t: float, y: Any, args: Dict[str, Any]) -> Any
Compute the time derivative of the neuron state.
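A possible dynamics implementation, sketched for a leaky integrate-and-fire style neuron. The state keys `"v"` and `"i"` and the argument keys `"tau_m"` and `"tau_s"` are hypothetical; the documentation only fixes the `(t, y, args)` signature and that the result is the time derivative of the state.

```python
import jax.numpy as jnp


def dynamics(t, y, args):
    # Hypothetical leaky integrate-and-fire equations:
    #   dv/dt = (-v + i) / tau_m
    #   di/dt = -i / tau_s
    dv = (-y["v"] + y["i"]) / args["tau_m"]
    di = -y["i"] / args["tau_s"]
    # The derivative mirrors the structure of the state y.
    return {"v": dv, "i": di}


y = {"v": jnp.array([0.0, 1.0]), "i": jnp.array([1.0, 1.0])}
dy = dynamics(0.0, y, {"tau_m": 2.0, "tau_s": 1.0})
```

Returning the derivative in the same structure as the state keeps the function usable with solvers that operate on arbitrary nested containers of arrays.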
spike_condition
spike_condition(t: float, y: Any, args: Dict[str, Any]) -> Float[Array, 'neurons']
Evaluate the spike condition for each neuron.
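One way to read the real-valued return type is as a per-neuron quantity whose zero crossing marks a spike. The sketch below follows that reading; the sign convention, the state key `"v"`, and the argument key `"v_th"` are all assumptions rather than documented behaviour.

```python
import jax.numpy as jnp


def spike_condition(t, y, args):
    # Hypothetical convention: the value is negative below threshold and
    # reaches zero when the membrane potential hits the threshold.
    return y["v"] - args["v_th"]


y = {"v": jnp.array([0.2, 1.0, 1.5])}
cond = spike_condition(0.0, y, {"v_th": 1.0})
```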
input_spike
input_spike(y: Any, from_idx: Int[Array, ''], to_idx: Int[Array, 'targets'], valid_mask: Bool[Array, 'targets']) -> Any
Update state in response to an incoming spike.
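The sketch below shows one plausible input_spike update, assuming that `to_idx` is a fixed-size list of target neurons and that `valid_mask` marks which entries are real targets rather than padding. The `weights` argument is hypothetical and is passed explicitly here only to keep the example standalone; a real model would more likely hold its weights as an attribute.

```python
import jax.numpy as jnp


def input_spike(y, from_idx, to_idx, valid_mask, weights):
    # Look up the (hypothetical) weight from the source neuron to each target.
    w = weights[from_idx, to_idx]
    # Entries that are not valid contribute nothing.
    delta = jnp.where(valid_mask, w, 0.0)
    # Add the contributions to the synaptic current of the target neurons.
    return {**y, "i": y["i"].at[to_idx].add(delta)}


weights = jnp.arange(9.0).reshape(3, 3)
y = {"v": jnp.zeros(3), "i": jnp.zeros(3)}
y_new = input_spike(
    y,
    from_idx=jnp.array(1),
    to_idx=jnp.array([0, 2, 0]),
    valid_mask=jnp.array([True, True, False]),
    weights=weights,
)
```

Masking with `jnp.where` instead of filtering keeps array shapes fixed, which is what JAX tracing requires.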
reset_spiked
reset_spiked(y: Any, spiked_mask: Bool[Array, 'neurons']) -> Any
Reset the state of neurons that just spiked.
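A reset of this kind can be sketched as a masked select over the state. The state key `"v"` and the `v_reset` parameter are assumptions for illustration; the documented signature takes only `y` and `spiked_mask`.

```python
import jax.numpy as jnp


def reset_spiked(y, spiked_mask, v_reset=0.0):
    # Neurons flagged in spiked_mask get the (hypothetical) reset potential;
    # all other neurons keep their current value.
    return {**y, "v": jnp.where(spiked_mask, v_reset, y["v"])}


y = {"v": jnp.array([1.2, 0.3, 1.0]), "i": jnp.array([0.5, 0.5, 0.5])}
y_reset = reset_spiked(y, jnp.array([True, False, True]))
```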
observe
observe(y: Any) -> Float[Array, 'neurons obs_channels']
Extract observable channels from the state.
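Given the documented return shape of `'neurons obs_channels'`, an observe implementation might stack selected state variables along a trailing channel axis. The state keys and the choice of two channels below are assumptions for this sketch.

```python
import jax.numpy as jnp


def observe(y):
    # Stack the (hypothetical) membrane potential and synaptic current so that
    # the result has shape (neurons, obs_channels) with obs_channels == 2.
    return jnp.stack([y["v"], y["i"]], axis=-1)


y = {"v": jnp.array([0.1, 0.2, 0.3]), "i": jnp.array([1.0, 2.0, 3.0])}
obs = observe(y)
```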