Skip to content

Event-driven Neural Network (EvNN)

Event-driven neural network core built with JAX, Equinox, and Diffrax.


High-level overview

EvNN simulates a network of spiking neurons using an event-based formulation:

  • An ODE solver integrates the neuron state from the time of the last event to the next event.
  • Event times are stored in a sorted buffer.
  • At every iteration, the neuron states are integrated from the first event in the buffer until either the next buffered event or the moment a neuron emits a new spike, whichever comes first.
  • When a neuron spikes, its state is reset and the events triggered by that spike are enqueued.

FFEvNN (eventax.evnn.FFEvNN) is a convenience subclass for creating feed-forward networks from a list of layer sizes; it wires the input and output neuron masks automatically.

The network delegates all neuron-specific behaviour to a pluggable NeuronModel, which must expose a consistent interface (see NeuronModel).

eventax.evnn.EvNN

Bases: Module

Source code in eventax/evnn.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
class EvNN(eqx.Module):
    """Event-driven spiking neural network.

    Neuron dynamics are delegated to ``neuron_model``. Spikes and bookkeeping
    pseudospikes live in a ``SpikeBuffer`` and are processed one event per
    ``__call__``. Simulated neurons use ids ``0 .. n_neurons - 1``; input
    channels follow at ids ``n_neurons .. n_neurons + in_size - 1``.
    """

    # Row ``i`` lists the post-synaptic targets of neuron/input ``i`` (taken
    # from ``wmask``), padded with the buffer's non-spike id.
    syn_conn: Int[Array, "in_plus_neurons max_syn"]
    max_syn_conn: int = eqx.field(static=True)  # width of ``syn_conn``
    t0: float = eqx.field(static=True)  # time of the initial pseudospike
    dtype: jnp.dtype = eqx.field(static=True)  # floating dtype for times/state
    # Synaptic delays indexed [pre (neuron or input), post neuron]; None if unused.
    delays: Optional[Float[Array, "in_plus_neurons neurons"]]
    # Per-presynaptic axonal delays; None if unused.
    axonal_delays: Optional[Float[Array, "in_plus_neurons"]]
    neuron_model: NeuronModel
    n_neurons: int = eqx.field(static=True)
    buffer_capacity: int = eqx.field(static=True)
    # Step cap for a single ODE solve: ceil(max_solver_time / solver_stepsize) + 1.
    max_solver_steps: int = eqx.field(static=True)
    max_solver_time: float = eqx.field(static=True)  # simulation horizon
    solver_stepsize: float = eqx.field(static=True)  # ``dt0`` for the ODE solver
    max_event_steps: int = eqx.field(static=True)  # bound on event-loop iterations
    in_size: int = eqx.field(static=True)
    output_no_spike_value: float = eqx.field(static=True)  # "did not spike" sentinel
    solver: dfx.AbstractSolver = eqx.field(static=True)
    stepsize_controller: AbstractStepSizeController = eqx.field(static=True)
    output_indices: Int[Array, "n_out"]
    input_indices: Int[Array, "n_in"]
    use_delays: bool = eqx.field(static=True)
    use_axonal_delays: bool = eqx.field(static=True)
    spike_buffer: Type[SpikeBuffer] = eqx.field(static=True)  # buffer class (not an instance)
    # Declared once (it was previously declared twice); kept at the position of
    # the first declaration so the field order is unchanged.
    adjoint: AbstractAdjoint = eqx.field(static=True)
    root_finder: Any = eqx.field(static=True)

    def __init__(
        self,
        neuron_model: NeuronModel,
        n_neurons: int,
        max_solver_time: float,
        in_size: int,
        key: PRNGKeyArray = None,
        t0: float = 0.0,
        wmask: Float[Array, "in_plus_neurons neurons"] = None,
        init_delays: Float[Array, "in_plus_neurons neurons"] = None,
        dlim: float = None,
        init_axonal_delays: Float[Array, "in_plus_neurons"] = None,
        axonal_dlim: float = None,
        output_neurons=None,
        input_neurons=None,
        buffer_capacity: int | None = None,
        max_event_steps: int = 1000,
        solver_stepsize: float = 0.001,
        output_no_spike_value: float = jnp.inf,
        root_finder=None,
        stepsize_controller=None,
        solver=None,
        adjoint=None,
        dtype=jnp.float32,
        **neuron_model_kwargs,
    ) -> None:
        """Build an event-driven network around a pluggable neuron model.

        Args:
            neuron_model: Neuron-model class. It is instantiated here with
                ``key``, ``n_neurons``, ``in_size``, ``wmask`` (exactly as passed,
                possibly ``None``), ``dtype`` and ``**neuron_model_kwargs``.
            n_neurons: Number of simulated neurons.
            max_solver_time: Simulation horizon; integration never goes past it.
            in_size: Number of input channels (ids ``n_neurons .. n_neurons + in_size - 1``).
            key: PRNG key. Required when delays are sampled (``dlim`` /
                ``axonal_dlim`` set). If ``None``, the neuron model receives ``key=None``.
            t0: Start time of the simulation.
            wmask: Connectivity mask of shape ``(n_neurons + in_size, n_neurons)``;
                row = presynaptic id, column = postsynaptic neuron. The default
                connects all neurons recurrently and every input channel to
                every input neuron.
            init_delays: Synaptic delays, either a scalar broadcast to all
                synapses or an array of shape ``(n_neurons + in_size, n_neurons)``.
            dlim: Upper bound for uniformly sampled synaptic delays (used when
                ``init_delays`` is None).
            init_axonal_delays: Axonal delays, either a scalar or an array of
                shape ``(n_neurons + in_size,)``.
            axonal_dlim: Upper bound for uniformly sampled axonal delays.
            output_neurons: Mask of shape ``(n_neurons,)`` selecting output neurons
                (default: all).
            input_neurons: Mask of shape ``(n_neurons,)`` selecting the neurons that
                receive input (default: all).
            buffer_capacity: Spike-buffer capacity. Without any delays it is
                forced to 1; with delays it defaults to 1000.
            max_event_steps: Maximum number of events processed by the simulation loops.
            solver_stepsize: ``dt0`` passed to the ODE solver.
            output_no_spike_value: Sentinel for "no spike" (``None`` means ``inf``).
            root_finder: Event root finder (default ``optx.Newton(1e-2, 1e-2, optx.rms_norm)``).
            stepsize_controller: Default ``dfx.ConstantStepSize()``.
            solver: Default ``dfx.Euler()``.
            adjoint: Default ``dfx.RecursiveCheckpointAdjoint()``.
            dtype: Floating dtype of times, delays and state.

        Raises:
            ValueError: If a random delay initialisation is requested without
                ``key``, or if ``wmask``, ``init_delays``, ``init_axonal_delays``,
                ``input_neurons`` or ``output_neurons`` has the wrong shape.
        """
        n_total = n_neurons + in_size
        delay_shape = (n_total, n_neurons)
        axonal_shape = (n_total,)

        # ---- Synaptic delays ---------------------------------------------
        self.use_delays = False
        if init_delays is None:
            if dlim is None:
                self.delays = None
            else:
                if key is None:
                    raise ValueError(
                        "Must set key to randomly initialize delays because init_delays is None and dlim is set"
                    )
                key, dkey = jax.random.split(key)
                self.delays = jax.random.uniform(
                    dkey, delay_shape, minval=0, maxval=dlim, dtype=dtype
                )
                self.use_delays = True
        elif isinstance(init_delays, (int, float)):
            # A scalar is broadcast to a uniform delay on every synapse.
            self.delays = jnp.full(delay_shape, init_delays, dtype=dtype)
            self.use_delays = True
        else:
            # JAX clamps out-of-bounds gathers, so a wrong shape would silently
            # read wrong delays later; reject it here instead.
            if jnp.shape(init_delays) != delay_shape:
                raise ValueError(
                    f"init_delays must have shape {delay_shape}, but got {jnp.shape(init_delays)}"
                )
            self.delays = init_delays
            self.use_delays = True

        # ---- Axonal delays -----------------------------------------------
        self.use_axonal_delays = False
        if init_axonal_delays is None:
            if axonal_dlim is None:
                self.axonal_delays = None
            else:
                if key is None:
                    raise ValueError(
                        "Must set key to randomly initialize axonal delays because "
                        "init_axonal_delays is None and axonal_dlim is set"
                    )
                key, akey = jax.random.split(key)
                self.axonal_delays = jax.random.uniform(
                    akey, axonal_shape, minval=0, maxval=axonal_dlim, dtype=dtype
                )
                self.use_axonal_delays = True
        elif isinstance(init_axonal_delays, (int, float)):
            self.axonal_delays = jnp.full(axonal_shape, init_axonal_delays, dtype=dtype)
            self.use_axonal_delays = True
        else:
            if jnp.shape(init_axonal_delays) != axonal_shape:
                raise ValueError(
                    f"init_axonal_delays must have shape {axonal_shape}, "
                    f"but got {jnp.shape(init_axonal_delays)}"
                )
            self.axonal_delays = init_axonal_delays
            self.use_axonal_delays = True

        # ---- Buffer capacity ---------------------------------------------
        if not (self.use_delays or self.use_axonal_delays):
            if buffer_capacity is None:
                buffer_capacity = 1
            elif buffer_capacity != 1:
                warnings.warn(
                    "No synaptic or axonal delays are used, so buffer_capacity is forced to 1. "
                    "For simulations without delays, buffer capacity should be 1.",
                    stacklevel=2,
                )
                buffer_capacity = 1
        elif buffer_capacity is None:
            buffer_capacity = 1000

        # ---- Input / output masks (validated before they are used) -------
        output_neurons = (
            jnp.ones((n_neurons,)) if output_neurons is None else jnp.asarray(output_neurons)
        )
        input_neurons = (
            jnp.ones((n_neurons,)) if input_neurons is None else jnp.asarray(input_neurons)
        )
        if input_neurons.shape != (n_neurons,):
            raise ValueError(
                f"input_neurons must have shape {(n_neurons,)}, but got {input_neurons.shape}"
            )
        if output_neurons.shape != (n_neurons,):
            raise ValueError(
                f"output_neurons must have shape {(n_neurons,)}, but got {output_neurons.shape}"
            )

        # ---- Neuron model ------------------------------------------------
        # With key=None we cannot split; let the neuron model decide whether
        # it needs randomness (jax.random.split(None) would just crash).
        if key is None:
            neuron_key = None
        else:
            key, neuron_key = jax.random.split(key)

        self.spike_buffer = SpikeBuffer

        self.neuron_model = neuron_model(
            key=neuron_key,
            n_neurons=n_neurons,
            in_size=in_size,
            wmask=wmask,
            dtype=dtype,
            **neuron_model_kwargs,
        )

        ids_dtype, no_ids_value = self.spike_buffer.calc_dtype_and_non_spike_value(n_total)

        self.output_indices = jnp.array(jnp.where(output_neurons)[0], dtype=ids_dtype)
        self.input_indices = jnp.array(jnp.where(input_neurons)[0], dtype=ids_dtype)

        # ---- Connectivity ------------------------------------------------
        if wmask is None:
            wmask = jnp.ones((n_total, n_neurons))
            input_rows = jnp.arange(n_neurons, n_total)
            wmask = wmask.at[input_rows, :].set(0)
            # Outer-product indexing: connect every input channel to every
            # input neuron. This matches init_buffer's delay path, which fans
            # each input spike out to all of ``input_indices``.
            wmask = wmask.at[input_rows[:, None], self.input_indices[None, :]].set(1)
        wmask = jnp.asarray(wmask)

        expected_wmask_shape = (n_total, n_neurons)
        if wmask.shape != expected_wmask_shape:
            raise ValueError(
                f"wmask must have shape {expected_wmask_shape}, but got {wmask.shape}"
            )

        # syn_conn covers ALL ids (neurons and input channels); rows are padded
        # to a common width with the buffer's non-spike id.
        syn_conn_list = [jnp.where(row)[0] for row in wmask]
        self.max_syn_conn = max((ids.shape[0] for ids in syn_conn_list), default=0)

        def pad1d(arr):
            return jnp.pad(
                arr,
                (0, self.max_syn_conn - arr.shape[0]),
                constant_values=no_ids_value,
            )

        self.syn_conn = jnp.stack([pad1d(ids) for ids in syn_conn_list], axis=0).astype(ids_dtype)

        # ---- Scalars and solver configuration ----------------------------
        self.output_no_spike_value = (
            jnp.inf if output_no_spike_value is None else output_no_spike_value
        )
        self.n_neurons = n_neurons
        self.buffer_capacity = buffer_capacity
        self.solver_stepsize = solver_stepsize
        self.max_event_steps = max_event_steps
        self.max_solver_time = max_solver_time
        self.in_size = in_size
        self.t0 = t0
        self.dtype = dtype

        self.max_solver_steps = ceil(max_solver_time / solver_stepsize) + 1

        self.root_finder = (
            optx.Newton(1e-2, 1e-2, optx.rms_norm) if root_finder is None else root_finder
        )
        self.stepsize_controller = (
            dfx.ConstantStepSize() if stepsize_controller is None else stepsize_controller
        )
        self.solver = dfx.Euler() if solver is None else solver
        self.adjoint = dfx.RecursiveCheckpointAdjoint() if adjoint is None else adjoint

    def _get_axonal_delay(self, neuron_idx):
        """Return the axonal delay of ``neuron_idx``, clamped to be non-negative.

        When axonal delays are disabled the delay is the constant ``0.0``.
        """
        if not self.use_axonal_delays:
            return 0.0
        stored_delay = self.axonal_delays[neuron_idx]
        return jnp.maximum(stored_delay, 0.0)

    def init_state(self) -> Any:
        """Create the neuron model's initial state with float leaves cast to ``self.dtype``.

        Leaves that are not arrays, or arrays of non-floating dtype (e.g.
        integer or boolean fields), are returned unchanged.
        """
        raw_state = self.neuron_model.init_state(self.n_neurons)

        def _to_model_dtype(leaf):
            is_float_array = isinstance(leaf, jnp.ndarray) and jnp.issubdtype(
                leaf.dtype, jnp.floating
            )
            if not is_float_array:
                return leaf
            return leaf.astype(self.dtype)

        return tree_util.tree_map(_to_model_dtype, raw_state)

    def init_buffer(
        self,
        in_spike_times: Optional[Float[Array, "in_size K"]],
        comp_times: Optional[Float[Array, "n_times"]] = None,
    ):
        """Create the initial spike buffer for one simulation run.

        The buffer always starts with an *init pseudospike* at ``self.t0``
        (``from`` = non-spike id, ``to`` = 0) that kicks off integration.
        Input spikes are then enqueued as follows:

        * With synaptic delays: one event per (input channel, spike, input
          neuron). Its time is shifted by the synaptic delay, plus the axonal
          delay when enabled, and it carries explicit ``from``/``to`` ids.
        * Without synaptic delays: one event per (input channel, spike). Its
          time is shifted by the axonal delay when enabled, and ``to`` is the
          non-spike id. ``__call__`` fans it out through ``syn_conn`` when it
          is popped.

        Input events whose time is ``inf`` get the non-spike id as ``from``
        (and as ``to`` in the delay path). Each entry of ``comp_times`` is
        appended as a *state pseudospike*, with both ids equal to the non-spike id.

        Args:
            in_spike_times: ``(in_size, K)`` input spike times, or ``None`` for
                no input. Entries equal to ``inf`` are ignored as described above.
            comp_times: Optional times at which the state should be observed.

        Returns:
            The buffer produced by ``self.spike_buffer.init``.

        Raises:
            ValueError: If ``in_spike_times`` is not two-dimensional with
                ``in_size`` rows.
        """
        # The id space spans neurons AND input channels; use it consistently for
        # the non-spike id and for the buffer itself in every branch.
        n_ids = self.n_neurons + self.in_size
        _, non_spike_idx = self.spike_buffer.calc_dtype_and_non_spike_value(n_ids)

        def _input_events(spikes):
            """Return (times, from_ids, to_ids) for every input spike."""
            n_slots = self.in_size
            n_per_slot = spikes.shape[1]
            from_range = jnp.arange(self.n_neurons, self.n_neurons + n_slots)

            if self.use_delays:
                targets = self.input_indices
                n_targets = targets.shape[0]
                # Ordering (slot, spike, target): target varies fastest.
                from_ids = jnp.repeat(from_range, n_targets * n_per_slot)
                to_ids = jnp.tile(targets, n_slots * n_per_slot)
                t = jnp.repeat(spikes.ravel(), n_targets)
                t = t + jnp.maximum(self.delays[from_ids, to_ids], 0.0)
                if self.use_axonal_delays:
                    t = t + jnp.maximum(self.axonal_delays[from_ids], 0.0)
                inf_mask = jnp.isinf(t)
                to_ids = jnp.where(inf_mask, non_spike_idx, to_ids)
                from_ids = jnp.where(inf_mask, non_spike_idx, from_ids)
            else:
                # Non-delay pseudospikes: only the source is known here.
                from_ids = jnp.repeat(from_range, n_per_slot)
                t = spikes.ravel()
                if self.use_axonal_delays:
                    t = t + jnp.maximum(self.axonal_delays[from_ids], 0.0)
                to_ids = jnp.full_like(from_ids, non_spike_idx)
                from_ids = jnp.where(jnp.isinf(t), non_spike_idx, from_ids)
            return t, from_ids, to_ids

        # Init pseudospike (from = non-spike id, to = 0) starts integration at t0.
        times = jnp.array([self.t0], dtype=self.dtype)
        from_indices = jnp.array([non_spike_idx])
        to_indices = jnp.array([0])

        if in_spike_times is not None:
            in_spike_times = jnp.asarray(in_spike_times, dtype=self.dtype)
            if in_spike_times.ndim != 2 or in_spike_times.shape[0] != self.in_size:
                raise ValueError(
                    f"EvNN expects (input size, K spikes per input) but got {in_spike_times.shape}"
                )
            in_times, in_from, in_to = _input_events(in_spike_times)
            times = jnp.concatenate((times, in_times), axis=0)
            from_indices = jnp.concatenate((from_indices, in_from), axis=0)
            to_indices = jnp.concatenate((to_indices, in_to), axis=0)

        if comp_times is not None:
            # State pseudospikes: both ids are the non-spike id.
            comp_times = jnp.ravel(comp_times).astype(self.dtype)
            n_times = comp_times.shape[0]
            comp_from = jnp.full((n_times,), non_spike_idx, dtype=from_indices.dtype)
            comp_to = jnp.full((n_times,), non_spike_idx, dtype=to_indices.dtype)
            times = jnp.concatenate([times, comp_times], axis=0)
            from_indices = jnp.concatenate([from_indices, comp_from], axis=0)
            to_indices = jnp.concatenate([to_indices, comp_to], axis=0)

        return self.spike_buffer.init(
            self.buffer_capacity,
            n_ids,
            times,
            from_indices,
            to_indices,
            time_dtype=self.dtype,
        )

    def __call__(
        self,
        state: Any,  # PyTree
        buffer,
    ) -> Tuple[Float[Array, ""],
               Bool[Array, "neurons"],
               Any,
               eqx.Module]:
        """Process the head event of ``buffer`` and integrate until the next one.

        Steps:

        1. Pop the head event at time ``t0`` and apply its effect to ``state``.
        2. Integrate the neuron ODE from ``t0`` up to the new head time ``t1``,
           clamped to ``max_solver_time``.
        3. Stop early if the neuron model's spike condition fires. In that case
           the spiking neuron is reset and follow-up events are enqueued.

        Event kinds, keyed by ``(from_idx, to_idx)`` with ``ns`` the buffer's
        non-spike id:

        * ``(ns, 0)``: init/continuation pseudospike; no state change.
        * ``(ns, ns)``: state pseudospike (from ``comp_times``); no state change.
        * ``(i, ns)``: non-delay spike from id ``i``; applied to every target in
          ``syn_conn[i]``.
        * ``(i, j)``: delayed spike from ``i`` delivered to neuron ``j`` only.

        Args:
            state: Neuron-model state PyTree. The loops in this class pass the
                state returned by the previous call.
            buffer: Spike buffer, as created by ``init_buffer``.

        Returns:
            ``(t, spike_mask, new_state, new_buffer)``:

            * ``t`` is where integration stopped: the spike time, the next event
              time, or ``max_solver_time`` if the buffer was exhausted.
            * ``spike_mask`` is an ``(n_neurons,)`` bool mask with at most one
              ``True`` entry.
        """
        # peek at next event time
        t0, _, _ = self.spike_buffer.peek(buffer)

        def no_event(buf):
            # Buffer exhausted (head time is inf): report max_solver_time and
            # return state and buffer unchanged.
            return (
                jnp.minimum(t0, self.max_solver_time),
                jnp.zeros((self.n_neurons,), dtype=bool),
                state,
                buf,
            )

        def handle_event(buf):
            # pop one spike out, call it buf1; the new head time t1 bounds the integration
            t0, from_idx, to_idx, buf1 = self.spike_buffer.pop(buf)
            t1, _, _ = self.spike_buffer.peek(buf1)
            t1 = jnp.minimum(t1, self.max_solver_time)
            # events beyond the horizon collapse onto t1, making the solve a no-op
            t0_clamped = jnp.minimum(t0, t1)

            # Handle different types of spikes/pseudospikes
            def handle_spike_input(args):
                s, f_idx, t_idx = args
                ns = buf.index_non_spike_value

                # Check if this is a non-delay spike (has from_idx, to_idx = non_spike)
                is_non_delay = (t_idx == ns) & (f_idx != ns)

                # Check if this is a state_at_t pseudospike (both indices are non_spike)
                is_state_pseudospike = (f_idx == ns) & (t_idx == ns)

                # Check if this is an init pseudospike (from_idx=non_spike, to_idx=0)
                is_init_pseudospike = (f_idx == ns) & (t_idx == 0)

                def process_non_delay():
                    # For non-delay: get all connections from the neuron and apply them
                    # (syn_conn rows are padded with ns, hence the validity mask)
                    conn_to = self.syn_conn[f_idx]
                    valid = conn_to != ns
                    return self.neuron_model.input_spike(s, f_idx, conn_to, valid)

                def process_regular():
                    # Regular delayed spike: a single (from, to) synapse
                    return self.neuron_model.input_spike(
                        s,
                        f_idx,
                        jnp.array([t_idx]),
                        jnp.array([True])
                    )

                def no_change():
                    return s

                # Chain of conditions to handle different spike types
                return jax.lax.cond(
                    is_state_pseudospike | is_init_pseudospike,
                    no_change,
                    lambda: jax.lax.cond(
                        is_non_delay,
                        process_non_delay,
                        process_regular,
                    ),
                )

            state1 = handle_spike_input((state, from_idx, to_idx))

            def integrate(buf_inner):

                # Event fires on an upward crossing of the largest per-neuron spike condition.
                def spike_cond(t, y, args, **kwargs):
                    return jnp.max(self.neuron_model.spike_condition(t, y)).astype(self.dtype)

                event = dfx.Event(spike_cond, self.root_finder, direction=True)

                # run ODE solve between t0 and t1 on state1
                sol = dfx.diffeqsolve(
                    dfx.ODETerm(self.neuron_model),
                    self.solver,
                    stepsize_controller=self.stepsize_controller,
                    t0=t0_clamped,
                    t1=t1,
                    dt0=self.solver_stepsize,
                    y0=state1,
                    event=event,
                    throw=True,
                    max_steps=self.max_solver_steps,
                    adjoint=self.adjoint,
                    saveat=dfx.SaveAt(t0=False, t1=True, steps=False, dense=False),
                )
                # Only t1 is saved: sol.ts[-1] is the final time of the solve,
                # which is the event time if the event fired.
                t_spike = sol.ts[-1]
                y_spike = tree_util.tree_map(lambda x: x[0], sol.ys) if sol.ts.shape[0] == 1 else sol.ys[-1]

                cond_now = self.neuron_model.spike_condition(t_spike, y_spike)
                # NOTE: only the single neuron with the largest spike condition is
                # marked (argmax), and only if the event actually fired.
                spiked = jnp.argmax(cond_now)
                spike_mask = jnp.zeros_like(cond_now, dtype=bool).at[spiked].set(sol.event_mask)
                state2 = self.neuron_model.reset_spiked(y_spike, spike_mask)

                # if any spikes, add them to buffer
                def add_spikes(op):
                    state_, b = op
                    spiked = jnp.argmax(spike_mask)

                    # NEW: get dtypes from internal_state
                    idx_dtype_from = b.internal_state.from_indices.dtype
                    idx_dtype_to = b.internal_state.to_indices.dtype

                    conn_to = self.syn_conn[spiked].astype(idx_dtype_to)
                    valid = conn_to != b.index_non_spike_value

                    ns = b.index_non_spike_value

                    # Get axonal delay for the spiking neuron (0 if not used)
                    axonal_delay = self._get_axonal_delay(spiked)

                    # Add pseudospike to continue integration at the right time
                    # (from=ns, to=0 is treated as a no-op by handle_spike_input)
                    curr_time = jnp.array([t_spike])
                    curr_from = jnp.array([ns], dtype=idx_dtype_from)
                    curr_to = jnp.array([0], dtype=idx_dtype_to)

                    if self.use_delays:
                        # Synaptic delays + optional axonal delay
                        delayed_times = jnp.where(
                            valid,
                            t_spike + axonal_delay + jnp.maximum(self.delays[spiked, conn_to], 0.0),
                            jnp.inf,
                        )

                        delayed_from = jnp.where(valid, spiked, ns).astype(idx_dtype_from)
                        delayed_to = conn_to

                        # concat current time pseudospike and delayed times
                        all_times = jnp.concatenate((curr_time, delayed_times), axis=0)
                        all_from = jnp.concatenate((curr_from, delayed_from), axis=0)
                        all_to = jnp.concatenate((curr_to, delayed_to), axis=0)

                        # mask out spikes that exceed max_solver_time
                        # (this also masks the continuation pseudospike when
                        # t_spike >= max_solver_time)
                        time_mask = all_times < self.max_solver_time
                        all_times = jnp.where(time_mask, all_times, jnp.inf)
                        all_from = jnp.where(time_mask, all_from, ns)
                        all_to = jnp.where(time_mask, all_to, ns)

                        return state_, self.spike_buffer.add_multiple(b, all_times, all_from, all_to)

                    else:
                        # Non-delay case: add a pseudospike with the spiked neuron as from_idx
                        # Include axonal delay if enabled
                        spike_time = t_spike + axonal_delay
                        non_delay_from = jnp.array([spiked], dtype=idx_dtype_from)
                        non_delay_to = jnp.array([ns], dtype=idx_dtype_to)

                        # If axonal delay is non-zero, we need a continuation pseudospike
                        # at t_spike to keep integration going until the delayed spike arrives
                        if self.use_axonal_delays:
                            all_times = jnp.concatenate((curr_time, jnp.array([spike_time])), axis=0)
                            all_from = jnp.concatenate((curr_from, non_delay_from), axis=0)
                            all_to = jnp.concatenate((curr_to, non_delay_to), axis=0)
                            return state_, self.spike_buffer.add_multiple(b, all_times, all_from, all_to)
                        else:
                            # The spike event at t_spike doubles as the continuation point.
                            return state_, self.spike_buffer.add(
                                b, spike_time, non_delay_from[0], non_delay_to[0]
                            )

                # check if any neuron spiked and if new spikes need to be generated
                any_spike = (jnp.sum(spike_mask.astype(jnp.int32)) > 0)
                state2, new_buf = jax.lax.cond(
                    any_spike,
                    add_spikes,
                    lambda op: op,
                    (state2, buf_inner),
                )
                return t_spike, spike_mask, state2, new_buf

            # choose between integrate or skip if t0 == t1
            return jax.lax.cond(
                t0_clamped < t1,
                integrate,
                lambda b: (t0_clamped, jnp.zeros((self.n_neurons,), bool), state1, b),
                buf1,
            )

        # if there is no spike in the buffer -> do no-op
        t_spike, spike_mask, state_final, buffer_final = jax.lax.cond(
            t0 != jnp.inf,
            handle_event,
            no_event,
            buffer,
        )

        return t_spike, spike_mask, state_final, buffer_final

    def ttfs(self, in_spike_times: Float[Array, "in_size K"]) -> Float[Array, "n_out"]:
        """For each output neuron, return the time it first fired for a given input.

        Events are processed until every output neuron has spiked or the
        simulation time reaches ``max_solver_time``. Whether a neuron has
        already spiked is tracked with an explicit boolean mask rather than by
        comparing against ``output_no_spike_value``. This makes any sentinel
        work, including NaN or values smaller than real spike times.

        Args:
            in_spike_times: ``(in_size, K)`` input spike times.

        Returns:
            ``(n_out,)`` first spike times, ordered like ``output_indices``.
            Neurons that never spiked hold ``output_no_spike_value``.

        Raises:
            An equinox runtime error if ``max_event_steps`` events were
            processed without reaching a termination condition.
        """
        in_spike_times = in_spike_times.astype(self.dtype)

        n_outputs = len(self.output_indices)

        def cond_fn(carry):
            t_curr, _, _, _, _, has_spiked = carry
            all_spiked = jnp.all(has_spiked)
            time_left = t_curr < self.max_solver_time
            return jnp.logical_and(jnp.logical_not(all_spiked), time_left)

        def body_fn(carry):
            _, _, state, spike_buffer, first_spike_times, has_spiked = carry

            t_new, m_new, state_new, spike_buffer_new = self(state, spike_buffer)

            # Only the first spike of each output neuron is recorded.
            newly_spiked = (m_new[self.output_indices] > 0) & jnp.logical_not(has_spiked)
            first_spike_times = jnp.where(newly_spiked, t_new, first_spike_times)

            return (
                t_new,
                m_new,
                state_new,
                spike_buffer_new,
                first_spike_times,
                has_spiked | newly_spiked,
            )

        init_carry = (
            self.t0,
            jnp.zeros((self.n_neurons,), dtype=jnp.bool_),
            self.init_state(),
            self.init_buffer(in_spike_times),
            jnp.full((n_outputs,), self.output_no_spike_value, dtype=self.dtype),
            jnp.zeros((n_outputs,), dtype=jnp.bool_),
        )

        out_carry = eqx.internal.while_loop(
            cond_fn, body_fn, init_carry, max_steps=self.max_event_steps, kind="bounded"
        )

        out_carry = eqx.error_if(
            out_carry, cond_fn(out_carry), "Reached max event steps. Try to increase event_steps."
        )

        first_spike_times = out_carry[4]
        return first_spike_times

    def spikes_until_t(
        self,
        in_spike_times: Float[Array, "in_size K"],
        final_time: float,
        max_spikes: int = 100,
    ) -> Float[Array, "n_out max_spikes"]:
        """Record the output spike trains up to ``final_time``.

        Each output neuron has ``max_spikes`` slots of its own. The loop stops
        when any of these holds:

        * every output row is full;
        * the buffer is empty;
        * the simulation time reaches ``final_time``.

        A global spike total is deliberately not used as a stop criterion,
        because it would cut off neurons that still have free slots.

        Args:
            in_spike_times: ``(in_size, K)`` input spike times.
            final_time: Spikes later than this are not recorded.
            max_spikes: Number of spike slots per output neuron.

        Returns:
            ``(n_out, max_spikes)`` array. Row ``i`` holds the spike times of
            ``output_indices[i]`` in the order they occurred; unused slots hold
            ``output_no_spike_value``.

        Raises:
            An equinox runtime error if ``max_event_steps`` events were
            processed without reaching a termination condition.
        """
        in_spike_times = jnp.asarray(in_spike_times, dtype=self.dtype)
        n_outputs = len(self.output_indices)
        rows = jnp.arange(n_outputs)

        def cond_fn(carry):
            _, last_t, buffer, _, counter = carry
            all_rows_full = jnp.all(counter >= max_spikes)
            empty_buffer = self.spike_buffer.is_empty(buffer)
            final_time_reached = last_t >= final_time
            return ~(all_rows_full | empty_buffer | final_time_reached)

        def body_fn(carry):
            state, _, buffer, out_spikes, counter = carry
            t_spike, m_spike, state_new, buffer_new = self(state, buffer)

            has_room = counter < max_spikes
            write = m_spike[self.output_indices] & (t_spike <= final_time) & has_room
            # Clip the slot so full rows never index out of bounds; their write
            # is disabled by ``write`` anyway.
            slot = jnp.minimum(counter, max_spikes - 1)
            new_vals = jnp.where(write, t_spike, out_spikes[rows, slot])
            out_spikes = out_spikes.at[rows, slot].set(new_vals)
            counter = counter + write.astype(jnp.int32)
            return state_new, t_spike, buffer_new, out_spikes, counter

        init_state = self.init_state()
        init_buffer = self.init_buffer(in_spike_times)
        out_spikes = jnp.full(
            (n_outputs, max_spikes), self.output_no_spike_value, dtype=self.dtype
        )
        init_counter = jnp.zeros((n_outputs,), dtype=jnp.int32)
        init_carry = (init_state, self.t0, init_buffer, out_spikes, init_counter)

        out_carry = eqx.internal.while_loop(
            cond_fn,
            body_fn,
            init_carry,
            max_steps=self.max_event_steps,
            kind="bounded",
        )
        out_carry = eqx.error_if(
            out_carry, cond_fn(out_carry), "Reached max event steps. Try to increase event_steps."
        )

        _, _, _, out_spikes, _ = out_carry

        return out_spikes

    def state_at_t(
        self,
        in_spike_times: Float[Array, "in_size K"],
        comp_times: Float[Array, "n_times"],
    ) -> Float[Array, "n_out n_times obs_channels"]:
        """Observe the output neurons at the requested times.

        Each entry of ``comp_times`` is enqueued as a state pseudospike. When
        such a pseudospike reaches the head of the buffer,
        ``neuron_model.observe`` is recorded for the output neurons.
        Pseudospikes are consumed in time order, so the recorded rows are
        scattered back to the caller's ordering of ``comp_times``. Unsorted
        ``comp_times`` are therefore handled correctly.

        Args:
            in_spike_times: ``(in_size, K)`` input spike times.
            comp_times: Observation times (flattened).

        Returns:
            ``(n_out, n_times, obs_dim)`` observations. ``[:, j]`` belongs to
            ``ravel(comp_times)[j]``. Times never reached remain NaN.

        Raises:
            An equinox runtime error if ``max_event_steps`` events were
            processed while observations were still pending.
        """
        comp_times = jnp.ravel(comp_times)
        n_times = comp_times.shape[0]
        n_out = len(self.output_indices)

        if in_spike_times is not None:
            in_spike_times = jnp.asarray(in_spike_times, dtype=self.dtype)

        init_state = self.init_state()
        sample_obs = self.neuron_model.observe(init_state)
        obs_dim = sample_obs.shape[-1]
        obs_dtype = sample_obs.dtype

        # init buffer with inputs + t0 pseudospike + comp_times pseudospikes
        buf = self.init_buffer(in_spike_times, comp_times)

        acc = jnp.full(
            (n_times, n_out, obs_dim),
            jnp.nan,
            dtype=obs_dtype,
        )

        def cond_fn(carry):
            _, buf, acc, _ = carry
            return jnp.any(jnp.isnan(acc)) & (~self.spike_buffer.is_empty(buf))

        def body_fn(carry):
            state, buf, acc, cnt = carry

            # After each call to ``self`` the state has been integrated up to
            # the head event time, so it can be observed before popping.
            t, i1, i2 = self.spike_buffer.peek(buf)

            is_comp = jnp.logical_and(
                i1 == buf.index_non_spike_value,
                i2 == buf.index_non_spike_value,
            )

            def write_obs(a):
                obs = self.neuron_model.observe(state)
                return a.at[cnt].set(obs[self.output_indices])

            acc = jax.lax.cond(
                is_comp,
                write_obs,
                lambda a: a,
                acc,
            )

            _, _, state, buf = self(state, buf)

            cnt += is_comp

            return (state, buf, acc, cnt)

        init_carry = (init_state, buf, acc, 0)

        out_carry = eqx.internal.while_loop(
            cond_fn,
            body_fn,
            init_carry,
            max_steps=self.max_event_steps,
            kind="bounded",
        )

        out_carry = eqx.error_if(
            out_carry, cond_fn(out_carry), "Reached max event steps. Try to increase event_steps."
        )

        _, _, filled, _ = out_carry

        # Row k of ``filled`` belongs to the k-th smallest comp time. Scatter the
        # rows back to the caller's order. Ties are harmless: state is not
        # advanced between equal-time events, so their observations coincide.
        order = jnp.argsort(comp_times.astype(self.dtype))
        filled = jnp.empty_like(filled).at[order].set(filled)

        # transpose to (n_outputs, n_times, obs_dim)
        return filled.transpose((1, 0, 2))

    def record(self, in_spike_times: Float[Array, "in_size K"]):
        """Run the event loop step by step and log what happened at each step.

        Iteration stops when the spike buffer becomes empty or after
        ``max_event_steps`` steps. Hitting the step limit does not raise an error.

        Args:
            in_spike_times: Input spike times of shape ``(in_size, K)``.

        Returns:
            ``(spike_times, spike_ids, buffer_sizes)``, each of length
            ``max_event_steps``. Entry ``s`` holds the event time and neuron
            index when exactly one neuron spiked at step ``s``. Otherwise the
            entries keep ``output_no_spike_value`` and the buffer's non-spike
            index. ``buffer_sizes[s]`` is the buffer size after step ``s``, or
            ``-1`` if step ``s`` was never executed.
        """
        in_spike_times = in_spike_times.astype(self.dtype)
        n_steps = self.max_event_steps

        start_state = self.init_state()
        start_buffer = self.init_buffer(in_spike_times)
        id_dtype = start_buffer.internal_state.from_indices.dtype

        def keep_going(carry):
            _, buffer, _, _, _, step = carry
            has_events = jnp.logical_not(self.spike_buffer.is_empty(buffer))
            return jnp.logical_and(has_events, step < n_steps)

        def advance(carry):
            state, buffer, times, ids, sizes, step = carry
            t_event, spiked_mask, state, buffer = self(state, buffer)

            # Only log a spike when exactly one neuron is flagged at this step.
            single_spike = jnp.sum(spiked_mask.astype(jnp.int32)) == 1
            neuron_id = jnp.argmax(spiked_mask).astype(id_dtype)

            times = times.at[step].set(jnp.where(single_spike, t_event, times[step]))
            ids = ids.at[step].set(jnp.where(single_spike, neuron_id, ids[step]))
            size_now = jnp.asarray(self.spike_buffer.size(buffer), dtype=jnp.int32)
            sizes = sizes.at[step].set(size_now)

            return state, buffer, times, ids, sizes, step + 1

        start_carry = (
            start_state,
            start_buffer,
            jnp.full((n_steps,), self.output_no_spike_value, dtype=self.dtype),
            jnp.full((n_steps,), start_buffer.index_non_spike_value, dtype=id_dtype),
            jnp.full((n_steps,), -1, dtype=jnp.int32),
            jnp.array(0, dtype=jnp.int32),
        )

        final_carry = eqx.internal.while_loop(
            keep_going, advance, start_carry, max_steps=n_steps, kind="bounded"
        )
        _, _, spike_times, spike_ids, buffer_sizes, _ = final_carry
        return spike_times, spike_ids, buffer_sizes

    def get_wmask(self) -> Float[Array, "in_plus_neurons neurons"]:
        """Rebuild the dense {0, 1} connectivity mask from ``syn_conn``.

        Row ``r`` of ``syn_conn`` lists the target neurons of sender ``r``
        (neurons first, then input slots), padded with the non-spike index.

        Returns:
            Array of shape ``(n_neurons + in_size, n_neurons)`` and dtype
            ``self.dtype`` holding 1 where sender ``r`` connects to neuron ``c``.
        """
        n_senders = self.n_neurons + self.in_size
        _, no_ids_value = self.spike_buffer.calc_dtype_and_non_spike_value(n_senders)

        valid = self.syn_conn != no_ids_value
        # Padding entries are redirected to column 0 but scatter a 0 with `max`,
        # so they never set a bit. A fixed-shape scatter (rather than boolean
        # mask indexing, whose shape is data-dependent) also works under jit.
        cols = jnp.where(valid, self.syn_conn, 0)
        rows = jnp.broadcast_to(jnp.arange(n_senders)[:, None], self.syn_conn.shape)

        wmask = jnp.zeros((n_senders, self.n_neurons), dtype=self.dtype)
        return wmask.at[rows, cols].max(valid.astype(self.dtype))

Parameters

Parameter Description Default
neuron_model Class (not instance) implementing the NeuronModel interface. Constructed internally using key, n_neurons, in_size, dtype, and any extra **neuron_model_kwargs. —
n_neurons Number of neurons in the network. —
max_solver_time Hard stop time for the simulation; spikes beyond this time are discarded. —
in_size Number of input slots. Each slot may carry multiple input spikes. —
key Random key for initialising the neuron model. None
t0 Simulation start time. 0.0
wmask Connectivity mask from all senders (neurons + inputs) to neurons, shape \((N + \text{in\_size}) \times N\). If omitted, defaults to all-to-all connectivity among the neurons plus connections from every input to the neurons selected by input_neurons. None
output_neurons Boolean mask \((N,)\) indicating which neurons are output neurons. None (all)
input_neurons Boolean mask \((N,)\) selecting neurons that receive external input spikes. None
dtype Numeric type for event times and dynamics. jnp.float32
**neuron_model_kwargs Additional keyword arguments passed to the neuron model constructor. —

Solver configuration

Parameter Description Default
solver ODE solver (any diffrax.AbstractSolver). diffrax.Euler()
stepsize_controller Step-size control strategy. diffrax.ConstantStepSize()
solver_stepsize Initial step size for the ODE solver. 1e-3
root_finder Root finder used in diffrax.Event to locate spike times. optimistix.Newton(1e-2, 1e-2, rms_norm)
adjoint Adjoint method for differentiating through diffeqsolve. diffrax.RecursiveCheckpointAdjoint()

Simulation bounds

Parameter Description Default
buffer_capacity Capacity of the internal spike buffer. 1
max_event_steps Upper bound on the number of event-loop iterations. 1000
output_no_spike_value Fill value for "no spike yet" in output arrays. jnp.inf

Methods

init_state

Source code in eventax/evnn.py
268
269
270
271
272
273
274
275
276
def init_state(self) -> Any:
    """Return the neuron model's initial state cast to the network dtype.

    Every array leaf of the state PyTree with a floating dtype (JAX or NumPy)
    is converted to a JAX array of dtype ``self.dtype``. Other leaves (integer
    or boolean arrays, plain Python scalars, ...) are returned unchanged.
    """
    state = self.neuron_model.init_state(self.n_neurons)

    def cast_leaf(x):
        # Accept any array-like with a floating dtype, not only jax.Array, so
        # states initialised with NumPy arrays also end up in the network dtype.
        if hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.floating):
            return jnp.asarray(x, dtype=self.dtype)
        return x

    return tree_util.tree_map(cast_leaf, state)

Returns the initial neuron state from the NeuronModel, with every floating-point array leaf cast to the network's dtype.


init_buffer

Source code in eventax/evnn.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
def init_buffer(
    self,
    in_spike_times: Optional[Float[Array, "in_size K"]],
    comp_times: Optional[Float[Array, "n_times"]] = None,
):
    """Build the initial spike buffer for a simulation.

    The buffer always starts with a pseudospike at ``self.t0`` (from-index =
    non-spike value, to-index = 0) that starts integration.

    Args:
        in_spike_times: Input spike times of shape ``(in_size, K)`` where ``K``
            is the maximum number of spikes per input slot. ``jnp.inf`` marks
            unused entries. May be ``None`` for a run without external input.
        comp_times: Optional times at which ``state_at_t`` pseudospikes
            (from-index = to-index = non-spike value) are enqueued.

    Returns:
        The buffer produced by ``self.spike_buffer.init``.

    Raises:
        ValueError: If ``in_spike_times`` is not of shape ``(in_size, K)``.
    """
    # Sender ids cover the n_neurons neurons followed by the in_size input
    # slots. The non-spike index and the buffer must be sized for the same id
    # space, so both use n_ids (previously the no-input branch used n_neurons).
    n_ids = self.n_neurons + self.in_size
    _, non_spike_idx = self.spike_buffer.calc_dtype_and_non_spike_value(n_ids)

    # Initial pseudospike for starting integration. It is distinct from
    # non-delay pseudospikes: from_idx=non_spike_idx, to_idx=0.
    t0_time = jnp.array([self.t0], dtype=self.dtype)
    t0_from = jnp.array([non_spike_idx])
    t0_to = jnp.array([0])

    if in_spike_times is None:
        times, from_indices, to_indices = t0_time, t0_from, t0_to
    else:
        # Same cast as ttfs/record, so event times live in the network dtype.
        in_spike_times = jnp.asarray(in_spike_times, dtype=self.dtype)
        if in_spike_times.ndim != 2 or in_spike_times.shape[0] != self.in_size:
            raise ValueError(
                f"EvNN expects (input size, K spikes per input) but got {in_spike_times.shape}"
            )

        M = self.in_size  # number of input slots
        K = in_spike_times.shape[1]  # max spikes per slot
        N = self.input_indices.shape[0]  # first layer dimension

        # Input slots use sender ids n_neurons .. n_neurons + in_size - 1.
        from_range = jnp.arange(self.n_neurons, self.n_neurons + M)
        to_range = self.input_indices

        if self.use_delays:
            # One event per (slot, spike, first-layer target), shifted by the
            # synaptic delay and, if enabled, the slot's axonal delay.
            from_indices = jnp.repeat(from_range, N * K)
            to_indices = jnp.tile(to_range, M * K)
            times = jnp.repeat(in_spike_times.ravel(), N)
            times = times + jnp.maximum(self.delays[from_indices, to_indices], 0.0)
            if self.use_axonal_delays:
                times = times + jnp.maximum(self.axonal_delays[from_indices], 0.0)

            # inf times are unused slots: mark both indices as non-spike.
            inf_mask = jnp.isinf(times)
            to_indices = jnp.where(inf_mask, non_spike_idx, to_indices)
            from_indices = jnp.where(inf_mask, non_spike_idx, from_indices)
        else:
            # One non-delay event per (slot, spike). A non-spike to-index tells
            # __call__ to fan the spike out over all of the slot's connections.
            from_indices = jnp.repeat(from_range, K)
            times = in_spike_times.ravel()
            if self.use_axonal_delays:
                times = times + jnp.maximum(self.axonal_delays[from_indices], 0.0)
            to_indices = jnp.full_like(from_indices, non_spike_idx)
            from_indices = jnp.where(jnp.isinf(times), non_spike_idx, from_indices)

        times = jnp.concatenate((t0_time, times), axis=0)
        from_indices = jnp.concatenate((t0_from, from_indices), axis=0)
        to_indices = jnp.concatenate((t0_to, to_indices), axis=0)

    # Append comp_times as state_at_t pseudospikes if given.
    if comp_times is not None:
        comp_times = jnp.ravel(comp_times).astype(self.dtype)
        n_times = comp_times.shape[0]
        comp_from = jnp.full((n_times,), non_spike_idx, dtype=from_indices.dtype)
        comp_to = jnp.full((n_times,), non_spike_idx, dtype=to_indices.dtype)

        times = jnp.concatenate([times, comp_times], axis=0)
        from_indices = jnp.concatenate([from_indices, comp_from], axis=0)
        to_indices = jnp.concatenate([to_indices, comp_to], axis=0)

    return self.spike_buffer.init(
        self.buffer_capacity,
        n_ids,
        times,
        from_indices,
        to_indices,
        time_dtype=self.dtype,
    )

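Builds a spike buffer for the simulation. in_spike_times must have shape (in_size, K) where K is the maximum number of spikes per input slot. Use jnp.inf for unused slots. in_spike_times may also be None, in which case the buffer contains only the initial pseudospike at t0 (plus any comp_times pseudospikes).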

If comp_times is provided, pseudospike events are inserted at those times so that state_at_t can record the neuron state.


__call__

Integrate state between events (buffer spike or neuron spike).

Source code in eventax/evnn.py
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
def __call__(
    self,
    state: Any,  # PyTree
    buffer,
) -> Tuple[Float[Array, ""],
           Bool[Array, "neurons"],
           Any,
           eqx.Module]:
    """Integrate state between events (buffer spike or neuron spike).

    Pops the earliest event from ``buffer`` and applies it to ``state`` via the
    neuron model's ``input_spike``; pseudospikes leave the state unchanged.
    It then integrates the neuron ODE from that event's time up to the next
    buffered event time, clamped to ``max_solver_time``. Integration stops
    early if the spike condition triggers. On a neuron spike, the state is
    reset with ``reset_spiked`` and the resulting events are added to the
    buffer.

    Args:
        state: Neuron state PyTree (as produced by ``init_state``).
        buffer: Spike buffer (as produced by ``init_buffer``).

    Returns:
        ``(t, spike_mask, new_state, new_buffer)``:
        - ``t``: the time reached.
        - ``spike_mask``: boolean mask over neurons with at most one ``True``
          entry, the neuron that spiked.
        - ``new_state`` and ``new_buffer``: the updated state and buffer.
        If the buffer's next time is ``inf``, nothing is popped or integrated
        and ``t`` is ``max_solver_time``.
    """
    # peek at next event time
    t0, _, _ = self.spike_buffer.peek(buffer)

    def no_event(buf):
        # no spikes to integrate: return unchanged buffer
        # (only reached when t0 is inf, so the returned time is max_solver_time)
        return (
            jnp.minimum(t0, self.max_solver_time),
            jnp.zeros((self.n_neurons,), dtype=bool),
            state,
            buf,
        )

    def handle_event(buf):
        # pop one spike out, call it buf1
        t0, from_idx, to_idx, buf1 = self.spike_buffer.pop(buf)
        # Integrate up to the following event, but never past max_solver_time.
        t1, _, _ = self.spike_buffer.peek(buf1)
        t1 = jnp.minimum(t1, self.max_solver_time)
        t0_clamped = jnp.minimum(t0, t1)

        # Handle different types of spikes/pseudospikes
        def handle_spike_input(args):
            s, f_idx, t_idx = args
            ns = buf.index_non_spike_value

            # Non-delay spike: real sender, to_idx == non_spike (fan out to all targets)
            is_non_delay = (t_idx == ns) & (f_idx != ns)

            # Check if this is a state_at_t pseudospike (both indices are non_spike)
            is_state_pseudospike = (f_idx == ns) & (t_idx == ns)

            # Init / continuation pseudospike (from_idx=non_spike, to_idx=0)
            is_init_pseudospike = (f_idx == ns) & (t_idx == 0)

            def process_non_delay():
                # For non-delay: get all connections from the neuron and apply them
                conn_to = self.syn_conn[f_idx]
                valid = conn_to != ns
                return self.neuron_model.input_spike(s, f_idx, conn_to, valid)

            def process_regular():
                # Regular delayed spike: a single (sender, target) pair
                return self.neuron_model.input_spike(
                    s,
                    f_idx,
                    jnp.array([t_idx]),
                    jnp.array([True])
                )

            def no_change():
                return s

            # Chain of conditions to handle different spike types:
            # pseudospikes -> unchanged, non-delay -> fan-out, otherwise -> single target
            return jax.lax.cond(
                is_state_pseudospike | is_init_pseudospike,
                no_change,
                lambda: jax.lax.cond(
                    is_non_delay,
                    process_non_delay,
                    process_regular,
                ),
            )

        state1 = handle_spike_input((state, from_idx, to_idx))

        def integrate(buf_inner):

            # Scalar event function: the largest per-neuron spike-condition value.
            def spike_cond(t, y, args, **kwargs):
                return jnp.max(self.neuron_model.spike_condition(t, y)).astype(self.dtype)

            event = dfx.Event(spike_cond, self.root_finder, direction=True)

            # run ODE solve between t0 and t1 on state1
            # (only the final point is saved: either t1 or the event time)
            sol = dfx.diffeqsolve(
                dfx.ODETerm(self.neuron_model),
                self.solver,
                stepsize_controller=self.stepsize_controller,
                t0=t0_clamped,
                t1=t1,
                dt0=self.solver_stepsize,
                y0=state1,
                event=event,
                throw=True,
                max_steps=self.max_solver_steps,
                adjoint=self.adjoint,
                saveat=dfx.SaveAt(t0=False, t1=True, steps=False, dense=False),
            )
            t_spike = sol.ts[-1]
            # With SaveAt(t1=True) there is a single saved point; strip its
            # leading axis from every leaf of the state PyTree.
            y_spike = tree_util.tree_map(lambda x: x[0], sol.ys) if sol.ts.shape[0] == 1 else sol.ys[-1]

            cond_now = self.neuron_model.spike_condition(t_spike, y_spike)
            # Mark only the neuron with the largest spike-condition value, and
            # only if the solver actually reported an event.
            spiked = jnp.argmax(cond_now)
            spike_mask = jnp.zeros_like(cond_now, dtype=bool).at[spiked].set(sol.event_mask)
            state2 = self.neuron_model.reset_spiked(y_spike, spike_mask)

            # if any spikes, add them to buffer
            def add_spikes(op):
                state_, b = op
                spiked = jnp.argmax(spike_mask)

                # Index dtypes used by the buffer's internal storage
                idx_dtype_from = b.internal_state.from_indices.dtype
                idx_dtype_to = b.internal_state.to_indices.dtype

                conn_to = self.syn_conn[spiked].astype(idx_dtype_to)
                valid = conn_to != b.index_non_spike_value

                ns = b.index_non_spike_value

                # Get axonal delay for the spiking neuron (0 if not used)
                axonal_delay = self._get_axonal_delay(spiked)

                # Continuation pseudospike (from=ns, to=0) at t_spike: when popped
                # it leaves the state unchanged and restarts integration there
                curr_time = jnp.array([t_spike])
                curr_from = jnp.array([ns], dtype=idx_dtype_from)
                curr_to = jnp.array([0], dtype=idx_dtype_to)

                if self.use_delays:
                    # Synaptic delays + optional axonal delay: one event per target
                    delayed_times = jnp.where(
                        valid,
                        t_spike + axonal_delay + jnp.maximum(self.delays[spiked, conn_to], 0.0),
                        jnp.inf,
                    )

                    delayed_from = jnp.where(valid, spiked, ns).astype(idx_dtype_from)
                    delayed_to = conn_to

                    # concat current time pseudospike and delayed times
                    all_times = jnp.concatenate((curr_time, delayed_times), axis=0)
                    all_from = jnp.concatenate((curr_from, delayed_from), axis=0)
                    all_to = jnp.concatenate((curr_to, delayed_to), axis=0)

                    # mask out spikes at or beyond max_solver_time
                    time_mask = all_times < self.max_solver_time
                    all_times = jnp.where(time_mask, all_times, jnp.inf)
                    all_from = jnp.where(time_mask, all_from, ns)
                    all_to = jnp.where(time_mask, all_to, ns)

                    return state_, self.spike_buffer.add_multiple(b, all_times, all_from, all_to)

                else:
                    # Non-delay case: add a pseudospike with the spiked neuron as from_idx
                    # Include axonal delay if enabled
                    spike_time = t_spike + axonal_delay
                    non_delay_from = jnp.array([spiked], dtype=idx_dtype_from)
                    non_delay_to = jnp.array([ns], dtype=idx_dtype_to)

                    # If axonal delay is non-zero, we need a continuation pseudospike
                    # at t_spike to keep integration going until the delayed spike arrives
                    if self.use_axonal_delays:
                        all_times = jnp.concatenate((curr_time, jnp.array([spike_time])), axis=0)
                        all_from = jnp.concatenate((curr_from, non_delay_from), axis=0)
                        all_to = jnp.concatenate((curr_to, non_delay_to), axis=0)
                        return state_, self.spike_buffer.add_multiple(b, all_times, all_from, all_to)
                    else:
                        # Here spike_time == t_spike, so the spike event itself
                        # restarts integration; no separate continuation is needed
                        return state_, self.spike_buffer.add(
                            b, spike_time, non_delay_from[0], non_delay_to[0]
                        )

            # check if any neuron spiked and if new spikes need to be generated
            any_spike = (jnp.sum(spike_mask.astype(jnp.int32)) > 0)
            state2, new_buf = jax.lax.cond(
                any_spike,
                add_spikes,
                lambda op: op,
                (state2, buf_inner),
            )
            return t_spike, spike_mask, state2, new_buf

        # choose between integrate or skip if t0 == t1
        # (skip also covers events popped at or after max_solver_time)
        return jax.lax.cond(
            t0_clamped < t1,
            integrate,
            lambda b: (t0_clamped, jnp.zeros((self.n_neurons,), bool), state1, b),
            buf1,
        )

    # if there is no spike in the buffer -> do no-op
    t_spike, spike_mask, state_final, buffer_final = jax.lax.cond(
        t0 != jnp.inf,
        handle_event,
        no_event,
        buffer,
    )

    return t_spike, spike_mask, state_final, buffer_final

Processes a single event-integration window:

  1. Pops the next event from the buffer. If the buffer's next event time is inf, nothing is popped or integrated and the state is returned unchanged.
  2. Applies any incoming spike to the neuron state via [input_spike][eventax.neuron_models.NeuronModel.input_spike]; pseudospikes leave the state unchanged.
  3. Integrates the ODE from the current event time to the next event (clamped to max_solver_time), or until a neuron spikes.
  4. If a neuron spiked, resets its state via [reset_spiked][eventax.neuron_models.NeuronModel.reset_spiked] and enqueues the resulting spike event(s).

Returns a tuple (t_event, spike_mask, new_state, new_buffer).


ttfs

For each output neuron returns the time it first fired for a given input.

Source code in eventax/evnn.py
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
def ttfs(self, in_spike_times: Float[Array, "in_size K"]) -> Float[Array, "n_out"]:
    """For each output neuron returns the time it first fired for a given input.

    The event loop runs until every output neuron has fired or the simulated
    time reaches ``max_solver_time``.

    Args:
        in_spike_times: Input spike times of shape ``(in_size, K)``.

    Returns:
        Array of shape ``(n_out,)`` with each output neuron's first spike time,
        or ``output_no_spike_value`` for neurons that did not fire.

    Raises:
        An Equinox runtime error (``eqx.error_if``) if ``max_event_steps``
        iterations were exhausted before the loop finished.
    """
    in_spike_times = in_spike_times.astype(self.dtype)
    no_spike = self.output_no_spike_value
    n_outputs = len(self.output_indices)

    def cond_fn(carry):
        t_curr, _, _, _, first_spike_times_out = carry
        # "Has spiked" means "differs from the sentinel". This matches body_fn's
        # `==` test; a `<` comparison only works if the sentinel is larger
        # than every possible spike time.
        all_spiked = jnp.all(first_spike_times_out != no_spike)
        time_left = t_curr < self.max_solver_time
        return jnp.logical_and(jnp.logical_not(all_spiked), time_left)

    def body_fn(carry):
        _, _, state, spike_buffer, first_spike_times_out = carry

        t_spike_new, m_spike_new, state_new, spike_buffer_new = self(state, spike_buffer)

        # Only outputs that fire now for the first time take this event time.
        first_fire = (first_spike_times_out == no_spike) & (m_spike_new[self.output_indices] > 0)
        first_spike_times_out_new = jnp.where(first_fire, t_spike_new, first_spike_times_out)

        return (t_spike_new, m_spike_new, state_new, spike_buffer_new, first_spike_times_out_new)

    init_carry = (
        # Start the clock at t0, typed like the event times produced by __call__.
        jnp.asarray(self.t0, dtype=self.dtype),
        jnp.zeros((self.n_neurons,), dtype=jnp.bool_),
        self.init_state(),
        self.init_buffer(in_spike_times),
        jnp.full((n_outputs,), no_spike, dtype=self.dtype),
    )

    out_carry = eqx.internal.while_loop(
        cond_fn, body_fn, init_carry, max_steps=self.max_event_steps, kind="bounded"
    )

    out_carry = eqx.error_if(
        out_carry, cond_fn(out_carry), "Reached max event steps. Try to increase event_steps."
    )

    _, _, _, _, first_spike_times = out_carry
    return first_spike_times

Computes time-to-first-spike for each output neuron. Iterates the event loop until all outputs have spiked or max_solver_time is reached. Returns an array of shape (n_outputs,) filled with output_no_spike_value for neurons that did not spike.


spikes_until_t

Source code in eventax/evnn.py
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
def spikes_until_t(
    self,
    in_spike_times: Float[Array, "in_size K"],
    final_time: float,
    max_spikes: int = 100,
) -> Float[Array, "n_out max_spikes"]:
    """Record up to ``max_spikes`` spike times per output neuron until ``final_time``.

    Args:
        in_spike_times: Input spike times of shape ``(in_size, K)``.
        final_time: Spikes after this time are not recorded. The loop stops
            once an event at or beyond it has been processed.
        max_spikes: Number of spike slots per output neuron.

    Returns:
        Array of shape ``(n_out, max_spikes)``. Row ``i`` holds output neuron
        ``i``'s spike times in order of occurrence. Unused slots are
        ``output_no_spike_value``.

    Raises:
        An Equinox runtime error (``eqx.error_if``) if ``max_event_steps``
        iterations were exhausted before the loop finished.
    """
    in_spike_times = jnp.asarray(in_spike_times, dtype=self.dtype)
    n_outputs = len(self.output_indices)
    rows = jnp.arange(n_outputs)

    def cond_fn(carry):
        _, last_t, buffer, _, counter = carry
        # Slots are per neuron: stop only once every output has filled all of
        # its own slots. A total across outputs would stop too early.
        all_full = jnp.all(counter >= max_spikes)
        empty_buffer = self.spike_buffer.is_empty(buffer)
        final_time_reached = last_t >= final_time
        return ~(all_full | empty_buffer | final_time_reached)

    def body_fn(carry):
        state, _, buffer, out_spikes, counter = carry
        t_spike, m_spike, state_new, buffer_new = self(state, buffer)

        # Keep a spike only if it is on time and its neuron still has a free slot.
        keep = m_spike[self.output_indices] & (t_spike <= final_time) & (counter < max_spikes)
        # Clamp so full rows never index past the last slot (they are not written).
        slot = jnp.minimum(counter, max_spikes - 1)
        new_vals = jnp.where(keep, t_spike, out_spikes[rows, slot])
        out_spikes = out_spikes.at[rows, slot].set(new_vals)
        counter = counter + keep.astype(jnp.int32)
        return state_new, t_spike, buffer_new, out_spikes, counter

    out_spikes = jnp.full((n_outputs, max_spikes), self.output_no_spike_value, dtype=self.dtype)
    init_counter = jnp.zeros((n_outputs,), dtype=jnp.int32)
    init_carry = (
        self.init_state(),
        jnp.asarray(self.t0, dtype=self.dtype),
        self.init_buffer(in_spike_times),
        out_spikes,
        init_counter,
    )

    out_carry = eqx.internal.while_loop(
        cond_fn,
        body_fn,
        init_carry,
        max_steps=self.max_event_steps,
        kind="bounded",
    )
    out_carry = eqx.error_if(
        out_carry, cond_fn(out_carry), "Reached max event steps. Try to increase event_steps."
    )

    _, _, _, out_spikes, _ = out_carry

    return out_spikes

Records up to max_spikes spike times per output neuron up to final_time. Returns an array of shape (n_outputs, max_spikes) filled with output_no_spike_value where unused.


state_at_t

Source code in eventax/evnn.py
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
def state_at_t(
    self,
    in_spike_times: Float[Array, "in_size K"],
    comp_times: Float[Array, "n_times"],
) -> Float[Array, "n_out n_times obs_channels"]:
    """Observe the output neurons' state at the requested times.

    Every entry of ``comp_times`` is enqueued as a pseudospike whose from- and
    to-indices both equal the buffer's non-spike value. When such an event is
    at the head of the buffer, the state (already integrated up to that
    event's time) is observed before the event is processed.

    Args:
        in_spike_times: Input spike times of shape ``(in_size, K)``.
        comp_times: Observation times. Any shape (flattened) and any order.

    Returns:
        Array of shape ``(n_out, n_times, obs_channels)``. Index ``j`` along
        the time axis corresponds to ``ravel(comp_times)[j]``. Observations
        that were never reached stay NaN.

    Raises:
        An Equinox runtime error (``eqx.error_if``) if ``max_event_steps``
        iterations were exhausted before all observations were taken.
    """
    if in_spike_times is not None:
        # Same as ttfs/record: event times live in the network dtype.
        in_spike_times = jnp.asarray(in_spike_times, dtype=self.dtype)

    comp_times = jnp.ravel(comp_times)
    n_times = comp_times.shape[0]
    n_out = len(self.output_indices)

    # The buffer releases events in time order, so observations are written
    # in sorted order. Keep the permutation to restore the caller's order.
    order = jnp.argsort(comp_times)
    inverse_order = jnp.argsort(order)

    init_state = self.init_state()
    sample_obs = self.neuron_model.observe(init_state)
    obs_dim = sample_obs.shape[-1]
    obs_dtype = sample_obs.dtype

    # init buffer with inputs + t0 pseudospike + comp_times pseudospikes
    buf = self.init_buffer(in_spike_times, comp_times[order])

    acc = jnp.full((n_times, n_out, obs_dim), jnp.nan, dtype=obs_dtype)

    def cond_fn(carry):
        _, buf, _, cnt = carry
        # Continue while observations are pending and events remain.
        return (cnt < n_times) & (~self.spike_buffer.is_empty(buf))

    def body_fn(carry):
        state, buf, acc, cnt = carry

        ns = buf.index_non_spike_value
        _, from_idx, to_idx = self.spike_buffer.peek(buf)
        # state_at_t pseudospikes (as built by init_buffer) carry the
        # non-spike value in both index fields.
        is_comp = (from_idx == ns) & (to_idx == ns)

        def write_obs(a):
            obs = self.neuron_model.observe(state)
            return a.at[cnt].set(obs[self.output_indices])

        acc = jax.lax.cond(is_comp, write_obs, lambda a: a, acc)

        _, _, state, buf = self(state, buf)
        cnt = cnt + is_comp.astype(jnp.int32)

        return (state, buf, acc, cnt)

    init_carry = (init_state, buf, acc, jnp.asarray(0, dtype=jnp.int32))

    out_carry = eqx.internal.while_loop(
        cond_fn,
        body_fn,
        init_carry,
        max_steps=self.max_event_steps,
        kind="bounded",
    )

    out_carry = eqx.error_if(
        out_carry, cond_fn(out_carry), "Reached max event steps. Try to increase event_steps."
    )

    _, _, filled, _ = out_carry
    # Undo the time sort, then transpose to (n_outputs, n_times, obs_dim).
    return filled[inverse_order].transpose((1, 0, 2))

Returns the observable state of output neurons at the specified computation times. Uses pseudospike events to halt integration at the requested times. Returns an array of shape (n_outputs, n_times, obs_channels).


record

Source code in eventax/evnn.py
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
def record(self, in_spike_times: Float[Array, "in_size K"]):
    """Run the event loop step by step and log what happened at each step.

    Iteration stops when the spike buffer becomes empty or after
    ``max_event_steps`` steps. Hitting the step limit does not raise an error.

    Args:
        in_spike_times: Input spike times of shape ``(in_size, K)``.

    Returns:
        ``(spike_times, spike_ids, buffer_sizes)``, each of length
        ``max_event_steps``. Entry ``s`` holds the event time and neuron index
        when exactly one neuron spiked at step ``s``. Otherwise the entries
        keep ``output_no_spike_value`` and the buffer's non-spike index.
        ``buffer_sizes[s]`` is the buffer size after step ``s``, or ``-1`` if
        step ``s`` was never executed.
    """
    in_spike_times = in_spike_times.astype(self.dtype)
    n_steps = self.max_event_steps

    start_state = self.init_state()
    start_buffer = self.init_buffer(in_spike_times)
    id_dtype = start_buffer.internal_state.from_indices.dtype

    def keep_going(carry):
        _, buffer, _, _, _, step = carry
        has_events = jnp.logical_not(self.spike_buffer.is_empty(buffer))
        return jnp.logical_and(has_events, step < n_steps)

    def advance(carry):
        state, buffer, times, ids, sizes, step = carry
        t_event, spiked_mask, state, buffer = self(state, buffer)

        # Only log a spike when exactly one neuron is flagged at this step.
        single_spike = jnp.sum(spiked_mask.astype(jnp.int32)) == 1
        neuron_id = jnp.argmax(spiked_mask).astype(id_dtype)

        times = times.at[step].set(jnp.where(single_spike, t_event, times[step]))
        ids = ids.at[step].set(jnp.where(single_spike, neuron_id, ids[step]))
        size_now = jnp.asarray(self.spike_buffer.size(buffer), dtype=jnp.int32)
        sizes = sizes.at[step].set(size_now)

        return state, buffer, times, ids, sizes, step + 1

    start_carry = (
        start_state,
        start_buffer,
        jnp.full((n_steps,), self.output_no_spike_value, dtype=self.dtype),
        jnp.full((n_steps,), start_buffer.index_non_spike_value, dtype=id_dtype),
        jnp.full((n_steps,), -1, dtype=jnp.int32),
        jnp.array(0, dtype=jnp.int32),
    )

    final_carry = eqx.internal.while_loop(
        keep_going, advance, start_carry, max_steps=n_steps, kind="bounded"
    )
    _, _, spike_times, spike_ids, buffer_sizes, _ = final_carry
    return spike_times, spike_ids, buffer_sizes

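Runs the event loop until the spike buffer is empty or max_event_steps iterations have been executed. At each step it records the spike time and neuron ID (when a neuron spiked) and the resulting buffer size. Useful for debugging and visualisation.

Returns a tuple (spike_times, spike_ids, buffer_sizes), each of shape (max_event_steps,). Steps without a spike keep output_no_spike_value and the buffer's non-spike index; buffer sizes of steps that were never executed are -1.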


get_wmask

Source code in eventax/evnn.py
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
def get_wmask(self) -> Float[Array, "in_plus_neurons neurons"]:
    """Rebuild the dense {0, 1} connectivity mask from ``syn_conn``.

    Row ``r`` of ``syn_conn`` lists the target neurons of sender ``r``
    (neurons first, then input slots), padded with the non-spike index.

    Returns:
        Array of shape ``(n_neurons + in_size, n_neurons)`` and dtype
        ``self.dtype`` holding 1 where sender ``r`` connects to neuron ``c``.
    """
    n_senders = self.n_neurons + self.in_size
    _, no_ids_value = self.spike_buffer.calc_dtype_and_non_spike_value(n_senders)

    valid = self.syn_conn != no_ids_value
    # Padding entries are redirected to column 0 but scatter a 0 with `max`,
    # so they never set a bit. A fixed-shape scatter (rather than boolean
    # mask indexing, whose shape is data-dependent) also works under jit.
    cols = jnp.where(valid, self.syn_conn, 0)
    rows = jnp.broadcast_to(jnp.arange(n_senders)[:, None], self.syn_conn.shape)

    wmask = jnp.zeros((n_senders, self.n_neurons), dtype=self.dtype)
    return wmask.at[rows, cols].max(valid.astype(self.dtype))

Reconstructs a dense \(\{0, 1\}\) connectivity mask of shape \((N + \text{in\_size}) \times N\) from the internal syn_conn representation, where \(N\) is n_neurons and the extra rows correspond to the input slots.