Skip to content

Simulation

collimator.simulation

ODESolverOptions dataclass

Options for the ODE solver.

See the documentation for `simulate` for details on these options.

Source code in `collimator/backend/ode_solver.py` (lines 23–36)
@dataclasses.dataclass
class ODESolverOptions:
    """Configuration values consumed by the ODE solver.

    Every field corresponds to an ODE-related option of `simulate`; consult
    that function's documentation for the full meaning of each setting.
    """

    # Relative tolerance handed to the solver (default 1e-3).
    rtol: float = dataclasses.field(default=1e-3)
    # Absolute tolerance handed to the solver (default 1e-6).
    atol: float = dataclasses.field(default=1e-6)
    # Optional lower bound on the step size; None leaves it unset.
    min_step_size: float = dataclasses.field(default=None)
    # Optional upper bound on the step size; None leaves it unset.
    max_step_size: float = dataclasses.field(default=None)
    # Integration method; "auto" picks Dopri5 (jax/scipy) or BDF (jax).
    method: str = dataclasses.field(default="auto")
    # Whether the solver must support automatic differentiation.
    enable_autodiff: bool = dataclasses.field(default=False)
    # Checkpoint budget; only consulted for checkpointing under autodiff.
    max_checkpoints: int = dataclasses.field(default=None)

Simulator

Class for orchestrating simulations of hybrid dynamical systems.

See the `simulate` function for more details.

Source code in `collimator/simulation/simulator.py` (lines 451–1243)
class Simulator:
    """Class for orchestrating simulations of hybrid dynamical systems.

    See the `simulate` function for more details.
    """

    def __init__(
        self,
        system: SystemBase,
        ode_solver: ODESolverBase = None,
        options: SimulatorOptions = None,
    ):
        """Create a simulator for a hybrid dynamical system.

        Args:
            system (SystemBase): The hybrid dynamical system to be simulated.
            ode_solver (ODESolverBase):
                Solver for the continuous-time part of the system.  When omitted,
                an `ODESolver` is constructed from `options.ode_options`.
            options (SimulatorOptions):
                Settings controlling the simulation; `SimulatorOptions()` is used
                when omitted.  See `simulate` for details.
        """
        self.system = system
        opts = options if options is not None else SimulatorOptions()

        # Pick the math backend and find out whether JAX tracing (jit, grad,
        # vmap, ...) can be used, then make that backend the active one.
        math_backend, self.enable_tracing = _check_backend(opts)
        cnp.set_backend(math_backend)

        # When True, `advance_to` and the guarded integrator are wrapped with
        # custom autodiff rules further below.
        self.enable_autodiff = opts.enable_autodiff

        # The default solver is only built after the backend has been set
        # (presumably because it depends on the active backend).
        solver = (
            ode_solver
            if ode_solver is not None
            else ODESolver(system, options=opts.ode_options)
        )

        # Copy the remaining configuration out of the options object.
        self.buffer_length = opts.buffer_length
        self.max_major_steps = opts.max_major_steps
        step_cap = opts.max_major_step_length
        # No cap on the major step length is represented as infinity.
        self.max_major_step_length = np.inf if step_cap is None else step_cap
        self.save_time_series = opts.save_time_series
        self.recorded_outputs = opts.recorded_signals
        self.zc_bisection_loop_count = opts.zc_bisection_loop_count
        self.major_step_callback = opts.major_step_callback

        logger.debug("Simulator created with enable_tracing=%s", self.enable_tracing)

        self.ode_solver = solver

        # Custom VJP for `advance_to` so that derivatives capture variations in
        # the end time of the simulation interval.
        self.has_terminal_events = system.zero_crossing_events.has_terminal_events
        self.advance_to = self._override_advance_to_vjp()

        # Custom VJP for the guarded ODE integration so that derivatives capture
        # variations in the zero-crossing time.
        self.guarded_integrate = self._override_guarded_integrate_vjp()

    def while_loop(self, cond_fun, body_fun, val):
        """Run a structured `while` loop with the implementation suited to this run.

        With autodiff enabled, the loop is capped at `self.max_major_steps`
        iterations through `_bounded_while_loop`, which keeps it reverse-mode
        differentiable.  Otherwise the call is forwarded to the backend's
        `while_loop`, which is `lax.while_loop` under JAX or a pure-Python loop
        under NumPy.
        """
        if not self.enable_autodiff:
            # No reverse-mode derivatives needed: an unbounded loop is fine.
            return backend.while_loop(cond_fun, body_fun, val)
        # Reverse-mode autodiff needs a fixed upper bound on the iteration count.
        return _bounded_while_loop(cond_fun, body_fun, val, self.max_major_steps)

    def initialize(self, context: ContextBase) -> SimulatorState:
        """Build the initial simulator state for a run starting at `context`.

        Finds the first pending periodic events (including any scheduled at the
        start time itself), chooses the initial step-end reason, allocates the
        results buffer when time series are saved, and initializes the ODE solver.
        """
        logger.debug("Initializing simulator")

        # Start time as integer time (picoseconds).
        start_time = IntegerTime.from_decimal(context.time)

        # Query from one integer tick before the start so that _next_update_time()
        # is able to report events scheduled exactly at the start time.
        next_event_time, timed_events = _next_update_time(
            self.system.periodic_events, start_time - 1
        )

        # timed_events now flags the events that are active at next_event_time.
        logger.debug("Time of next timed event (int): %s", next_event_time)
        logger.debug(
            "Time of next event (sec): %s",
            IntegerTime.as_decimal(next_event_time),
        )
        timed_events.pprint(logger.debug)

        step_end_reason = cnp.where(
            next_event_time == start_time,
            StepEndReason.TimeTriggered,
            StepEndReason.NothingTriggered,
        )

        # Storage for recorded signals, only allocated when saving time series.
        results_data = (
            ResultsData.initialize(context, self.recorded_outputs, self.buffer_length)
            if self.save_time_series
            else None
        )

        solver_state = self.ode_solver.initialize(context)

        return SimulatorState(
            context=context,
            timed_events=timed_events,
            step_end_reason=step_end_reason,
            int_time=start_time,
            results_data=results_data,
            ode_solver_state=solver_state,
        )

    def save_results(
        self, results_data: ResultsData, context: ContextBase
    ) -> ResultsData:
        """Record a sample of `context` into `results_data` when saving is enabled.

        Returns `results_data.update(context)` if time series are being saved;
        otherwise `results_data` is handed back untouched.
        """
        if self.save_time_series:
            return results_data.update(context)
        return results_data

    def _override_advance_to_vjp(self) -> Callable:
        """Construct the `advance_to` method for the simulator.

        See the docstring for `Simulator._advance_to` for details.

        If autodiff is disabled (`self.enable_autodiff` is False), the bound
        `_advance_to` method is returned unchanged.  Otherwise `_advance_to` is
        wrapped with a `jax.custom_vjp` rule so that reverse-mode derivatives
        capture variation with respect to the end time.  The end-time adjoint is
        the inner product of the continuous-state cotangent with the final
        continuous time derivatives.  It is zero if the last major step ended on
        a terminal event.  Without this rule the derivatives would not account
        for variations in end time, for instance when finding limit cycles or
        when there are terminal conditions on the simulation.

        Returns:
            Callable: A function called as `advance_to(boundary_time, context)`
            that returns a `SimulatorState`.
        """
        if not self.enable_autodiff:
            return self._advance_to

        # This is the main function call whose autodiff rule will be overridden.
        def _wrapped_advance_to(
            sim: Simulator, boundary_time: float, context: ContextBase
        ) -> SimulatorState:
            return sim._advance_to(boundary_time, context)

        # The "forwards pass" to advance the simulation.  Also stores the nominal
        # VJP calculation and the continuous time derivative value, both of which
        # will be needed in the backwards pass.
        def _wrapped_advance_to_fwd(
            sim: Simulator, boundary_time: float, context: ContextBase
        ) -> tuple[SimulatorState, tuple]:
            primals, vjp_fun = jax.vjp(sim._advance_to, boundary_time, context)

            # Also need to keep the final continuous time derivative value for
            # computing the adjoint time variable
            xdot = sim.system.eval_time_derivatives(primals.context)

            # "Residual" information needed in the backwards pass
            res = (vjp_fun, xdot, primals.step_end_reason)

            return primals, res

        def _wrapped_advance_to_adj(
            _sim: Simulator, res: tuple, adjoints: SimulatorState
        ) -> tuple[float, ContextBase]:
            # Unpack the residuals from the forward pass
            vjp_fun, xdot, reason = res

            # Compute the standard adjoint variables with the automatically
            # derived VJP function.  Its first output is the automatically
            # computed tf_adj value.  That value is ignored in favor of the
            # manually derived value computed next.
            _, context_adj = vjp_fun(adjoints)

            # The derivative with respect to end time is the inner product of the
            # adjoint "seed" continuous state with the final time derivatives,
            # summed over all pytree leaves.  Use a full contraction
            # (`sum(a * b)`) rather than `jnp.dot`: for leaves of rank >= 2,
            # `jnp.dot` is a matrix product, not an inner product.
            vc = adjoints.context.continuous_state
            vT_xdot = jax.tree_util.tree_map(
                lambda xdot_leaf, vc_leaf: jnp.sum(xdot_leaf * vc_leaf), xdot, vc
            )

            # If the simulation ends early due to a terminal event, the
            # derivative with respect to the requested end time is zero.
            tf_adj = jnp.where(
                reason == StepEndReason.TerminalEventTriggered,
                0.0,
                sum(jax.tree_util.tree_leaves(vT_xdot)),
            )

            # Return adjoints matching the inputs to _wrapped_advance_to, except
            # for the first argument (Simulator), which is nondifferentiable.
            return (tf_adj, context_adj)

        advance_to = jax.custom_vjp(_wrapped_advance_to, nondiff_argnums=(0,))
        advance_to.defvjp(_wrapped_advance_to_fwd, _wrapped_advance_to_adj)

        # Copy the docstring and type hints to the overridden function
        advance_to.__doc__ = self._advance_to.__doc__
        advance_to.__annotations__ = self._advance_to.__annotations__

        return partial(advance_to, self)

    def _guarded_integrate(
        self,
        solver_state: ODESolverState,
        results_data: ResultsData,
        tf: float,
        context: ContextBase,
        zc_events: EventCollection,
    ) -> tuple[bool, ODESolverState, ContextBase, ResultsData, EventCollection]:
        """Guarded ODE integration.

        Advance continuous time using an ODE solver, localizing any zero-crossing events
        that occur during the requested interval.  If any zero-crossing events trigger,
        the dense interpolant is used to localize the events and the associated reset maps
        are handled.  The method then returns, guaranteeing that the major step terminates
        either at the end of the requested interval or at the time of a zero-crossing
        event.

        The solver takes minor steps (`solver.step`) in a loop until the solver time
        reaches `tf` or a zero-crossing event triggers.  A results sample is recorded at
        the start of every minor step and, when an event triggers, at the estimated
        crossing time.  No sample is recorded at the final time here.  After an event's
        reset maps are applied, the ODE solver is fully re-initialized from the
        post-reset context.

        Args:
            solver_state (ODESolverState): The current state of the ODE solver.
            results_data (ResultsData): The results data that will hold recorded time
                series data.
            tf (float): The end time of the integration interval.
            context (ContextBase): The current state of the system.
            zc_events (EventCollection): The current zero-crossing events.

        Returns:
            tuple[bool, ODESolverState, ContextBase, ResultsData, EventCollection]:
                A tuple containing the following:
                - A boolean indicating whether the major step was terminated early due to
                  a zero-crossing event.
                - The updated state of the ODE solver.
                - The updated state of the system.
                - The updated results data.
                - The updated zero-crossing events.
        """
        solver = self.ode_solver
        func = solver.flat_ode_rhs  # Raveled ODE RHS function

        # Close over the additional arguments so that the RHS function has the
        # signature `func(y, t)`.  Note this captures the `context` passed into
        # this call; the loop below exits as soon as any reset is applied.
        def _func(y, t):
            return func(y, t, context)

        # "True" branch of the per-step conditional: a zero crossing was detected
        # during the last minor step [context_t0.time, context_tf.time].
        def _localize_zc_minor(
            solver_state, context_t0, context_tf, zc_events, results_data
        ):
            # Using the ODE solver interpolant, employ bisection to find a 'small' time
            # interval within which the earliest zero crossing occurs. See
            # _bisection_step_fun for details about how bisection is employed for
            # localizing the zero crossing in time.
            int_t1 = IntegerTime.from_decimal(context_tf.time)
            int_t0 = IntegerTime.from_decimal(context_t0.time)
            _body_fun = partial(_bisection_step_fun, solver_state)
            carry = GuardIsolationData(int_t0, int_t1, zc_events, context_tf)
            search_data = backend.fori_loop(
                0, self.zc_bisection_loop_count, _body_fun, carry
            )
            context_tf = search_data.context
            zc_events = search_data.guards

            # record results sample for the ZC having 'occurred'
            # NOTE(review): presumably the bisection leaves a bracket of width
            # (t1 - t0) / 2**N ending at context_tf.time, so subtracting half of
            # that width estimates the bracket midpoint -- confirm against
            # _bisection_step_fun.
            minor_step_end_time = IntegerTime.as_decimal(int_t1)
            minor_step_start_time = IntegerTime.as_decimal(int_t0)
            zc_occur_time = context_tf.time - (
                minor_step_end_time - minor_step_start_time
            ) / (2 ** (self.zc_bisection_loop_count + 1))
            context_zc_time = context_tf.with_time(zc_occur_time)
            results_data = self.save_results(results_data, context_zc_time)

            # Handle any triggered zero-crossing events
            context_tf = self.system.handle_zero_crossings(zc_events, context_tf)

            # Re-initialize the solver, since the state may have been reset
            # Keep the last step size, since there's no reason to assume that the
            # dynamics have changed significantly as a result of the event.  If that is
            # the case, then we're relying on the adaptive stepping to re-calibrate.
            #
            # NOTE: this previously only updated the state and time of
            # the solver state, but with multistep solvers (e.g. BDF), the
            # solver needs to be fully reinitialized because the history
            # of differences needs to be cleared and rebuilt over the next few steps.
            solver_state = solver.initialize(context_tf)
            return solver_state, context_tf, zc_events, results_data

        # "False" branch: no event triggered, so pass the minor-step results
        # through unchanged.  `context_t0` is unused; it is only accepted so both
        # branches of `backend.cond` share one signature.
        def _no_events_fun(
            solver_state, context_t0, context_tf, zc_events, results_data
        ):
            return solver_state, context_tf, zc_events, results_data

        # Loop body: one minor step of the ODE solver, followed by event detection
        # and, if needed, localization and reset handling.
        def _ode_step(carry):
            _, solver_state, context_t0, results_data, zc_events = carry

            # Save results at the top of the loop. This will save data at t=t0,
            # but not at t=tf.  This is okay, since we will save the results at
            # the top of the next major step, as well as at the end of the main
            # simulation loop.
            results_data = self.save_results(results_data, context_t0)

            zc_events = guard_interval_start(zc_events, context_t0)

            # Advance ODE solver
            solver_state = solver.step(_func, tf, solver_state)
            xc = solver_state.unraveled_state
            context = context_t0.with_time(solver_state.t).with_continuous_state(xc)

            # Check for zero-crossing events
            zc_events = determine_triggered_guards(zc_events, context)

            triggered = zc_events.has_triggered

            args = (solver_state, context_t0, context, zc_events, results_data)
            solver_state, context, zc_events, results_data = backend.cond(
                triggered, _localize_zc_minor, _no_events_fun, *args
            )

            return (triggered, solver_state, context, results_data, zc_events)

        # Keep stepping while the solver is short of `tf` and no event has
        # triggered.
        def _cond_fun(carry):
            triggered, solver_state, _, _, _ = carry
            return (solver_state.t < tf) & (~triggered)

        # This uses `backend.while_loop` directly rather than `self.while_loop`.
        # When autodiff is enabled, this method is wrapped by
        # `_override_guarded_integrate_vjp`, whose forward pass re-integrates
        # with `_odeint` to obtain derivatives.
        carry = (False, solver_state, context, results_data, zc_events)
        triggered, solver_state, context, results_data, zc_events = backend.while_loop(
            _cond_fun,
            _ode_step,
            carry,
        )

        return triggered, solver_state, context, results_data, zc_events

    def _override_guarded_integrate_vjp(self) -> Callable:
        """Construct the `guarded_integrate` method for the simulator.

        If autodiff is disabled, the bound `_guarded_integrate` method is returned
        unchanged.  Otherwise `_guarded_integrate` is wrapped with a
        `jax.custom_vjp` rule:

        - Forward pass: `_guarded_integrate` is run for the primal outputs with
          `results_data=None`.  No time series is recorded on this path, and
          `None` is returned in the results slot.  A differentiable `_forward`
          function then re-integrates with `_odeint` up to the actual end time
          of the interval and applies the zero-crossing reset maps.  The VJP of
          `_forward` is kept for the backward pass.
        - Backward pass: the cotangents are pulled back through that VJP.  The
          adjoint of the interval end time is replaced by the inner product of the
          ODE right-hand side at the final state with the state cotangent.  This
          is done only if a zero crossing triggered; otherwise the adjoint is zero.

        Returns:
            Callable: A function with the same call signature as
            `_guarded_integrate`.
        """
        if not self.enable_autodiff:
            return self._guarded_integrate

        # Primal function whose autodiff rule is overridden.  The simulator is
        # passed explicitly (named `self` here) and marked nondifferentiable.
        def _wrapped_solve(
            self: Simulator, solver_state, results_data, tf, context, zc_events
        ):
            return self._guarded_integrate(
                solver_state, results_data, tf, context, zc_events
            )

        def _wrapped_solve_fwd(
            self: Simulator, solver_state, _results_data, tf, context, zc_events
        ):
            # Run the forward major step as usual (primal calculation). Do not save
            # any results here

            # The return from the forward call has the state post-reset, including the
            # solver state, the context, and the zero-crossing events.
            t0 = solver_state.t

            (
                triggered,
                solver_state_out,
                context_out,
                _,
                zc_events_out,
            ) = self._guarded_integrate(solver_state, None, tf, context, zc_events)
            tf = solver_state_out.t

            # Define a differentiable function for the forward pass, knowing where the
            # zero-crossing occurs. Note that `tf` here should be the _actual_ interval
            # end time, not the requested end time.
            solver = self.ode_solver
            func = solver.flat_ode_rhs  # Raveled ODE RHS function

            def _forward(solver_state, tf, context, zc_events_out):
                solver_state_out = _odeint(solver, func, solver_state, tf, context)
                context = context.with_time(solver_state_out.t)
                context = context.with_continuous_state(
                    solver_state_out.unraveled_state
                )
                context = self.system.handle_zero_crossings(zc_events_out, context)
                return context

            # Get the VJP for the event handling
            _primals, vjp_fun = jax.vjp(
                _forward, solver_state, tf, context, zc_events_out
            )

            primals = (triggered, solver_state_out, context_out, None, zc_events_out)
            # NOTE(review): `t0` is stored here but is not used by
            # `_wrapped_solve_adj`.
            residuals = (triggered, solver_state_out, t0, tf, context, vjp_fun)
            return primals, residuals

        def _wrapped_solve_adj(self: Simulator, residuals, adjoints):
            triggered, primal_solver_state, t0, tf, context, vjp_fun = residuals
            (
                _triggered_adj,
                solver_state_adj,
                context_adj,
                _results_data_adj,
                _zc_events_adj,
            ) = adjoints

            # Fold the solver-state cotangent (time and continuous state) into the
            # context cotangent, which is what `vjp_fun` consumes.
            context_adj = context_adj.with_time(solver_state_adj.t)
            context_adj = context_adj.with_continuous_state(
                solver_state_adj.unraveled_state
            )

            # The `_forward` function corresponding to `vjp_fun` has the signature
            # `context_out = _forward(solver_state, tf, context, zc_events_out)`.
            # For the adjoint, we have to call with `context_adj` as the input:
            solver_state_adj, tf_adj, context_adj, zc_events_adj = vjp_fun(context_adj)

            # The Jacobian with respect to the final time is just the time derivative of
            # the state at the final time.
            # NOTE(review): `primal_solver_state` is the solver state returned by
            # `_guarded_integrate`.  When an event triggered, that state was
            # re-initialized from the post-reset context.  Confirm that evaluating
            # the RHS at the post-reset state (rather than the pre-reset state) is
            # intended here.
            yf = primal_solver_state.y
            yf_bar = solver_state_adj.y
            func = self.ode_solver.flat_ode_rhs  # Raveled ODE RHS function
            tf_adj = jnp.dot(func(yf, tf, context), yf_bar)

            # Only a triggered zero crossing makes the interval end time depend on
            # the inputs; otherwise its adjoint is zeroed.
            tf_adj = jnp.where(
                triggered,
                tf_adj,
                jnp.zeros_like(tf_adj),
            )

            # One adjoint per differentiable input of `_wrapped_solve`:
            # (solver_state, results_data, tf, context, zc_events).
            return (solver_state_adj, None, tf_adj, context_adj, zc_events_adj)

        guarded_integrate = jax.custom_vjp(_wrapped_solve, nondiff_argnums=(0,))
        guarded_integrate.defvjp(_wrapped_solve_fwd, _wrapped_solve_adj)

        # Copy the docstring and type hints to the overridden function
        guarded_integrate.__doc__ = self._guarded_integrate.__doc__
        guarded_integrate.__annotations__ = self._guarded_integrate.__annotations__

        return partial(guarded_integrate, self)

    def _advance_continuous_time(
        self,
        cdata: ContinuousIntervalData,
    ) -> ContinuousIntervalData:
        """Advance the simulation to the next discrete update or zero-crossing event.

        Determines the currently active zero-crossing guards, then advances the
        continuous-time component of the system to the end of the interval
        (`cdata.tf`) or to the first zero-crossing event, whichever comes first.

        - With continuous state, integration and event localization are delegated
          to `self.guarded_integrate`.  That call bisects on the ODE solver's
          interpolant to localize crossings and applies the associated reset maps.
          The interval end time is replaced by the crossing time only when an
          event triggered.
        - Without continuous state, the ODE solver is skipped.  Results are saved
          at the start of the interval, time jumps to the end of the interval, and
          guards are checked for crossings caused by the passage of time.  Any
          triggered events are handled.

        Args:
            cdata (ContinuousIntervalData): Interval data.  This method reads its
                `tf` (integer end time), `context`, `results_data` and
                `ode_solver_state` fields.

        Returns:
            ContinuousIntervalData: `cdata` with `triggered`, `terminate_early`,
            `context`, `tf`, `results_data` and `ode_solver_state` replaced.
        """

        # Unpack inputs
        int_tf = cdata.tf
        context = cdata.context
        results_data = cdata.results_data

        zc_events = self.system.determine_active_guards(context)

        if self.system.has_continuous_state:
            solver_state = cdata.ode_solver_state
            tf = IntegerTime.as_decimal(int_tf)

            (
                triggered,
                solver_state,
                context,
                results_data,
                zc_events,
            ) = self.guarded_integrate(
                solver_state,
                results_data,
                tf,
                context,
                zc_events,
            )

            # Sync the context with the final solver time and continuous state.
            context = context.with_time(solver_state.t)
            context = context.with_continuous_state(solver_state.unraveled_state)

            # Converting from decimal -> integer time incurs a loss of precision.  This is
            # okay for unscheduled zero-crossing events, but problematic for timed events.
            # So only do this conversion if a zero-crossing was triggered.  Otherwise we
            # know we have reached the end of the interval and can keep the requested end
            # time.
            int_tf = cnp.where(
                triggered,
                IntegerTime.from_decimal(context.time),
                int_tf,
            )

        else:
            # Skip the ODE solver for systems without continuous state.  We still
            # have to check for triggered events here in case there are any
            # transitions triggered by time that need to be handled before the
            # periodic discrete update at the top of the next major step
            triggered = False
            solver_state = cdata.ode_solver_state

            zc_events = guard_interval_start(zc_events, context)
            results_data = self.save_results(results_data, context)

            # Advance time to the end of the interval
            context = context.with_time(IntegerTime.as_decimal(int_tf))

            # Record guard values after the discrete update and check if anything
            # triggered as a result of advancing time
            zc_events = guard_interval_end(zc_events, context)
            zc_events = determine_triggered_guards(zc_events, context)

            # Handle any triggered zero-crossing events
            context = self.system.handle_zero_crossings(zc_events, context)

        # Even though the zero-crossing events have already been "handled", the
        # information about whether a terminal event has been triggered is still in
        # the events collection (since "triggered" has not been cleared by a call
        # to determine_triggered_guards).
        terminate_early = zc_events.has_active_terminal

        return cdata._replace(
            triggered=triggered,
            terminate_early=terminate_early,
            context=context,
            tf=int_tf,
            results_data=results_data,
            ode_solver_state=solver_state,
        )

    def _handle_discrete_update(
        self, context: ContextBase, timed_events: EventCollection
    ) -> tuple[ContextBase, bool]:
        """Apply the time-triggered discrete updates at the start of a major step.

        The active periodic updates in `timed_events` are applied first.  Any
        zero-crossing events whose guards cross as a result of those updates are
        then detected and handled.  When no zero-crossing events exist, the guard
        collection is empty and only the periodic update takes effect.

        Args:
            context (ContextBase): The current state of the system.
            timed_events (EventCollection):
                Timed events, with the ones due at this time marked active.

        Returns:
            ContextBase: The system state after all updates and resets.
            bool: True if a terminal zero-crossing event was triggered, meaning
                the simulation should stop early.
        """
        # Guards that might fire given the current state.  Some may be switched
        # off, e.g. by the current state of a state machine.
        guards = self.system.determine_active_guards(context)

        # Snapshot guard values before anything changes.
        guards = guard_interval_start(guards, context)

        # Apply the periodic discrete updates that are due now.
        updated = self.system.handle_discrete_update(timed_events, context)

        # Snapshot guard values after the update, then see which guards crossed.
        guards = determine_triggered_guards(
            guard_interval_end(guards, updated), updated
        )
        stop_now = guards.has_active_terminal

        # Run the reset maps of every triggered zero-crossing event.
        updated = self.system.handle_zero_crossings(guards, updated)

        return updated, stop_now

    # This method is marked private because it will be wrapped with a custom autodiff
    # rule to get the correct derivatives with respect to the end time of the
    # simulation interval using `_override_advance_to_vjp`.  This also copies the
    # docstring to the overridden function. Normally the wrapped attribute `advance_to`
    # is what should be called by users.
    def _advance_to(self, boundary_time: float, context: ContextBase) -> SimulatorState:
        """Core control flow logic for running a simulation.

        This is the main loop for advancing the simulation.  It is called by `simulate`
        or can be called directly if more fine-grained control is needed. This method
        essentially loops over "major steps" until the boundary time is reached or a
        terminal event is triggered. See the documentation for `simulate` for details
        on the order of operations in a major step.

        Every call begins by running `initialize` on `context`. Each major step then:

        1. applies discrete updates for the currently active timed events (along
           with any zero-crossing events those updates trigger),
        2. looks up the time of the next periodic (timed) event,
        3. chooses the step end time as the minimum of `boundary_time`, the current
           time plus the maximum major step length, and the next timed-event time,
        4. advances continuous time toward that end time (skipped when a terminal
           event was triggered by the discrete update; the continuous advance may
           also end the step early and report the actual end time), and
        5. records why the step ended and activates timed events accordingly.

        When `save_time_series` is enabled, a final discrete update is applied and
        the results are saved one last time after the loop exits.

        Args:
            boundary_time (float): The time to advance to.
            context (ContextBase): The current state of the system.

        Returns:
            SimulatorState:
                A named tuple containing the final state of the simulation, including
                the final context, a collection of pending timed events, and a flag
                indicating the reason that the most recent major step ended.

        Notes:
            API will change slightly as a result of WC-87, which will break out the
            initialization from the main loop so that `advance_to` can be called
            repeatedly.  See:
            https://collimator.atlassian.net/browse/WC-87

            On `KeyboardInterrupt` the interruption is logged and execution continues
            with the loop carry built before the loop started (the loop result is
            never assigned), followed by the usual finalization.
        """

        system = self.system
        sim_state = self.initialize(context)
        end_reason = sim_state.step_end_reason
        context = sim_state.context
        timed_events = sim_state.timed_events
        int_boundary_time = IntegerTime.from_decimal(boundary_time)

        # Each major step is limited to at most max_major_step_length.  Capping this
        # length at the simulation end time means an unbounded (infinite) limit does
        # not overflow when converted to integer time.  This could be problematic if
        # the end time of the simulation is close to the maximum representable
        # integer time, but we can come back to that if it's an issue.
        int_max_step_length = IntegerTime.from_decimal(
            cnp.minimum(self.max_major_step_length, boundary_time)
        )

        # Only activate timed events if the major step ended on a time trigger
        timed_events = activate_timed_events(timed_events, end_reason)

        # Loop body passed to `self.while_loop`; runs while `_cond_fun` is True and
        # performs exactly one major step, returning the updated loop carry.
        def _major_step(sim_state: SimulatorState) -> SimulatorState:
            # NOTE(review): this `end_reason` is never read; it is reassigned below
            # from `_determine_step_end_reason` before any use.
            end_reason = sim_state.step_end_reason
            context = sim_state.context
            timed_events = sim_state.timed_events
            int_time = sim_state.int_time

            # Debug logging is skipped when tracing is enabled, presumably because
            # the values would be JAX tracers rather than concrete numbers.
            if not self.enable_tracing:
                logger.debug("Starting a simulation step at t=%s", context.time)
                logger.debug("   merged_events: %s", timed_events)

            # Handle any discrete updates that are triggered by time along with
            # any zero-crossing events that are triggered by the discrete update.
            context, terminate_early = self._handle_discrete_update(
                context, timed_events
            )
            logger.debug("Terminate early after discrete update: %s", terminate_early)

            # How far can we go before we have to handle timed events?
            # The time returned here will be the integer time representation.
            time_of_next_timed_event, timed_events = _next_update_time(
                system.periodic_events, int_time
            )
            if not self.enable_tracing:
                logger.debug(
                    "Next timed event at t=%s",
                    IntegerTime.as_decimal(time_of_next_timed_event),
                )
                timed_events.pprint(logger.debug)

            # Determine whether the events include a timed update.  With no timed
            # events, use the largest integer time so it never limits the step.
            update_time = IntegerTime.max_int_time

            if timed_events.num_events > 0:
                update_time = time_of_next_timed_event

            # Limit the major step end time to the simulation end time, major step limit,
            # or next periodic update time.
            # This is the mechanism used to advance time for systems that have
            # no states and no periodic events.
            # - Discrete systems: when there are discrete periodic events, we use those
            #   to determine each major step end time.
            # - Feedthrough systems: when there are just feedthrough blocks (no states
            #   or events), use max_major_step_length to determine each major step
            #   end time.
            int_tf_limit = int_time + int_max_step_length
            int_tf = cnp.min(
                cnp.array(
                    [
                        int_boundary_time,
                        int_tf_limit,
                        update_time,
                    ]
                )
            )
            if not self.enable_tracing:
                logger.debug(
                    "Expecting to integrate to t=%s",
                    IntegerTime.as_decimal(int_tf),
                )

            # Normally we will advance continuous time to the end of the major step
            # here. However, if a terminal event was triggered as part of the discrete
            # update, we should respect that and skip the continuous update.
            #
            # Construct the container used to hold various data related to advancing
            # continuous time.  This is passed to ODE solvers, zero-crossing
            # localization, and related functions.
            cdata = ContinuousIntervalData(
                context=context,
                terminate_early=terminate_early,
                triggered=False,
                t0=int_time,
                tf=int_tf,
                results_data=sim_state.results_data,
                ode_solver_state=sim_state.ode_solver_state,
            )
            cdata = backend.cond(
                (self.has_terminal_events & cdata.terminate_early),
                lambda cdata: cdata,  # Terminal event triggered - return immediately
                self._advance_continuous_time,  # Advance continuous time normally
                cdata,
            )

            # Unpack the results of the continuous time advance.  Note that `tf`
            # may now be earlier than requested if the advance stopped early.
            context = cdata.context
            terminate_early = cdata.terminate_early
            triggered = cdata.triggered
            int_tf = cdata.tf
            results_data = cdata.results_data
            ode_solver_state = cdata.ode_solver_state

            # Determine the reason why the major step ended.  Did a zero-crossing
            # trigger, did a timed event trigger, neither, or both?
            logger.debug("Terminate early after major step: %s", terminate_early)
            end_reason = _determine_step_end_reason(
                triggered, terminate_early, int_tf, update_time
            )
            logger.debug("Major step end reason: %s", end_reason)

            # Conditionally activate timed events depending on whether the major step
            # ended as a result of a time trigger or zero-crossing event.
            timed_events = activate_timed_events(timed_events, end_reason)

            # Report progress to the user-supplied callback (if any) with the time
            # at the end of this major step.
            if self.major_step_callback:
                io_callback(self.major_step_callback, (), context.time)

            return SimulatorState(
                step_end_reason=end_reason,
                context=context,
                timed_events=timed_events,
                int_time=int_tf,
                results_data=results_data,
                ode_solver_state=ode_solver_state,
            )

        # Keep looping until the boundary time is reached or a terminal event fires.
        def _cond_fun(sim_state: SimulatorState):
            return (sim_state.int_time < int_boundary_time) & (
                sim_state.step_end_reason != StepEndReason.TerminalEventTriggered
            )

        # Initialize the "carry" values for the main loop.
        sim_state = SimulatorState(
            context=context,
            timed_events=timed_events,
            step_end_reason=end_reason,
            int_time=sim_state.int_time,
            results_data=sim_state.results_data,
            ode_solver_state=sim_state.ode_solver_state,
        )

        logger.debug(
            "Running simulation from t=%s to t=%s", context.time, boundary_time
        )

        try:
            # Main loop call
            sim_state = self.while_loop(_cond_fun, _major_step, sim_state)
            logger.debug("Simulation complete at t=%s", sim_state.context.time)
        except KeyboardInterrupt:
            # The assignment above never completed, so `sim_state` still holds the
            # initial carry built before the loop.
            # TODO: flag simulation as interrupted somewhere in sim_state
            logger.info("Simulation interrupted at t=%s", sim_state.context.time)

        # At the end of the simulation we need to handle any pending discrete updates
        # and store the solution one last time.
        # FIXME (WC-87): The returned simulator state can't be used with advance_to again,
        # since the discrete updates have already been performed. Should be broken out
        # into a `finalize` method as part of WC-87.

        # update discrete state to x+ at the simulation end_time
        if self.save_time_series:
            logger.debug("Finalizing solution...")
            # 1] do discrete update (will skip if the simulation was terminated early)
            context, _terminate_early = self._handle_discrete_update(
                sim_state.context, sim_state.timed_events
            )
            # 2] do update solution
            results_data = self.save_results(sim_state.results_data, context)
            sim_state = sim_state._replace(
                context=context,
                results_data=results_data,
            )
            logger.debug("Done finalizing solution")

        return sim_state

__init__(system, ode_solver=None, options=None)

Initialize the simulator.

**Parameters:**

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `system` | `SystemBase` | The hybrid dynamical system to simulate. | *required* |
| `ode_solver` | `ODESolverBase` | The ODE solver to use for integrating the continuous-time component of the system. If not provided, a default solver will be used. | `None` |
| `options` | `SimulatorOptions` | Options for the simulation process. See `simulate` for details. | `None` |
Source code in `collimator/simulation/simulator.py` (lines 457–514)
def __init__(
    self,
    system: SystemBase,
    ode_solver: ODESolverBase = None,
    options: SimulatorOptions = None,
):
    """Set up a simulator for a hybrid dynamical system.

    Args:
        system (SystemBase): The hybrid dynamical system to simulate.
        ode_solver (ODESolverBase):
            Solver for the continuous-time part of ``system``.  When omitted, an
            ``ODESolver`` is constructed from the ODE subset of ``options``.
        options (SimulatorOptions):
            Simulation settings; ``SimulatorOptions()`` defaults are used when
            omitted.  See `simulate` for details.
    """
    self.system = system
    opts = SimulatorOptions() if options is None else options

    # Resolve the math backend and whether JAX transformations (jit, grad, vmap,
    # ...) are usable, then activate that backend before anything else is built.
    math_backend, self.enable_tracing = _check_backend(opts)
    cnp.set_backend(math_backend)

    # With autodiff on, `advance_to` is wrapped with a custom autodiff rule below.
    self.enable_autodiff = opts.enable_autodiff

    solver = (
        ODESolver(system, options=opts.ode_options)
        if ode_solver is None
        else ode_solver
    )

    # Copy the remaining configuration onto the instance.  An unset major-step
    # length means "no limit", represented as infinity.
    self.buffer_length = opts.buffer_length
    self.max_major_steps = opts.max_major_steps
    step_length = opts.max_major_step_length
    self.max_major_step_length = np.inf if step_length is None else step_length
    self.save_time_series = opts.save_time_series
    self.recorded_outputs = opts.recorded_signals
    self.zc_bisection_loop_count = opts.zc_bisection_loop_count
    self.major_step_callback = opts.major_step_callback

    logger.debug("Simulator created with enable_tracing=%s", self.enable_tracing)

    self.ode_solver = solver

    # Install custom autodiff rules: `advance_to` captures sensitivity to the end
    # time of the interval, and `guarded_integrate` captures sensitivity to
    # zero-crossing times.
    self.has_terminal_events = system.zero_crossing_events.has_terminal_events
    self.advance_to = self._override_advance_to_vjp()
    self.guarded_integrate = self._override_guarded_integrate_vjp()

initialize(context)

Perform initial setup for the simulation.

Source code in `collimator/simulation/simulator.py` (lines 532–575)
def initialize(self, context: ContextBase) -> SimulatorState:
    """Build the starting `SimulatorState` for a run beginning at `context`.

    Determines which periodic events are due next (including events scheduled
    exactly at the start time), whether the first major step begins on a time
    trigger, and sets up the results buffer and ODE solver state.
    """
    logger.debug("Initializing simulator")

    # Start time expressed in integer time units.
    start_int_time = IntegerTime.from_decimal(context.time)

    # Query from one integer tick before the start, so an event scheduled exactly
    # at the start time can still be returned as the "next" event.
    next_event_int_time, timed_events = _next_update_time(
        self.system.periodic_events, start_int_time - 1
    )

    # `timed_events` now marks the events that are active at `next_event_int_time`.
    logger.debug("Time of next timed event (int): %s", next_event_int_time)
    logger.debug(
        "Time of next event (sec): %s",
        IntegerTime.as_decimal(next_event_int_time),
    )
    timed_events.pprint(logger.debug)

    starts_on_timed_event = next_event_int_time == start_int_time
    end_reason = cnp.where(
        starts_on_timed_event,
        StepEndReason.TimeTriggered,
        StepEndReason.NothingTriggered,
    )

    # Recorded time-series storage exists only when saving is enabled.
    results_data = (
        ResultsData.initialize(context, self.recorded_outputs, self.buffer_length)
        if self.save_time_series
        else None
    )

    solver_state = self.ode_solver.initialize(context)

    return SimulatorState(
        context=context,
        timed_events=timed_events,
        step_end_reason=end_reason,
        int_time=start_int_time,
        results_data=results_data,
        ode_solver_state=solver_state,
    )

save_results(results_data, context)

Update the results data with the current state of the system.

Source code in `collimator/simulation/simulator.py` (lines 577–583)
def save_results(
    self, results_data: ResultsData, context: ContextBase
) -> ResultsData:
    """Record the current state of the system into the results buffer.

    Returns ``results_data.update(context)`` when time-series saving is enabled,
    and ``results_data`` unchanged otherwise.
    """
    if self.save_time_series:
        return results_data.update(context)
    return results_data

while_loop(cond_fun, body_fun, val)

Structured control flow primitive for a while loop.

Dispatches to a bounded while loop as necessary for autodiff. Otherwise it will call the "backend" implementation, which will either be lax.while_loop when JAX is the backend, or a pure-Python implementation when using NumPy.

Source code in `collimator/simulation/simulator.py` (lines 516–530)
def while_loop(self, cond_fun, body_fun, val):
    """Run ``body_fun`` on ``val`` repeatedly while ``cond_fun`` holds.

    Without autodiff this defers to the "backend" loop: `lax.while_loop` under
    JAX, or a pure-Python loop under NumPy.  With autodiff enabled, reverse-mode
    differentiation needs a loop with a fixed iteration bound, so a bounded loop
    capped at ``max_major_steps`` is used instead.
    """
    if not self.enable_autodiff:
        return backend.while_loop(cond_fun, body_fun, val)
    return _bounded_while_loop(cond_fun, body_fun, val, self.max_major_steps)

SimulatorOptions dataclass

Options for the hybrid simulator.

See documentation for simulate for details on these options. This also contains all configuration for the ODE solver as a subset of options so that multiple options classes don't need to be created separately.

Source code in `collimator/simulation/types.py` (lines 76–205)
@dataclasses.dataclass
class SimulatorOptions:
    """Options for the hybrid simulator.

    See documentation for `simulate` for details on these options.
    This also contains all configuration for the ODE solver as a subset of options
    so that multiple options classes don't need to be created separately; that
    subset is exposed as an `ODESolverOptions` through the `ode_options` property.
    """

    # Name of the math backend.  Defaults to the backend that is active at the
    # moment the options object is created.
    math_backend: str = dataclasses.field(
        default_factory=lambda: numpy_api.active_backend
    )

    # Request JAX tracing (jit, grad, vmap, ...).  The `Simulator` checks this
    # together with the backend to decide whether tracing is actually used.
    enable_tracing: bool = True

    # Run with autodiff enabled.  The `Simulator` then uses a bounded while loop
    # (see `max_major_steps`) and the flag is forwarded to the ODE solver options.
    enable_autodiff: bool = False

    # If autodiff is enabled, max_major_steps must be set in order to bound the number
    # of iterations in the while loop.  When running a simulation using the `simulate`
    # function, this can typically be determined automatically based on the number of
    # periodic events in the system.  However, it should be specified manually in the
    # following cases:
    #   - When running a simulation by creating a `Simulator` object and calling the
    #     `advance_to` method directly. In this case the `Simulator` object does not
    #     attempt to automatically determine a bound on the number of major steps.
    #   - When autodiff is used to compute the sensitivity with respect to simulation
    #     end time, for example when computing periodic limit cycles. In this case the
    #     time variables passed to `estimate_max_major_steps` are JAX tracers and cannot
    #     be used to determine a fixed (static) bound on the number of major steps.
    #   - When the system has frequent zero-crossing events.  In this case the "safety
    #     factor" in the heuristic for estimating the number of major steps may be too
    #     small, underestimating the bound on the number of major steps.
    # In any case, `estimate_max_major_steps` can still be called statically ahead
    # of time to determine a reasonable value for `max_major_steps`, using for instance
    # a conservative bound on end time and safety factor.
    max_major_steps: int = None

    # Upper bound on the length of a single major step.  `None` means no bound
    # other than the simulation end time.
    max_major_step_length: float = None

    # Length of the buffer for storing time series data.  When the buffer is full the
    # data will be dumped to NumPy arrays.  For best performance, set to a value that
    # can hold the entire simulation time series.  However, in most cases this should
    # not need to be modified by the user.
    buffer_length: int = 1000

    # ODE solver options
    ode_solver_method: str = "auto"  # Dopri5 (jax/scipy) or BDF (jax)
    rtol: float = 1e-6  # Relative tolerance for adaptive solvers
    atol: float = 1e-8  # Absolute tolerance for adaptive solvers
    # Bounds on the ODE solver's internal (minor) step size, forwarded to the
    # solver as `min_step_size` / `max_step_size`.
    min_minor_step_size: float = None
    max_minor_step_size: float = None

    # This is used to bound the number of "checkpoints" in the adjoint solver and
    # is used only when autodiff is enabled.  Increasing this may improve the
    # accuracy of the adjoint solver (especially over long integration times), but
    # will also increase memory usage.  Whether or not the resulting adjoint solve
    # is faster depends on the details of the problem, for instance on the number of
    # major steps and the ODE solver tolerance.  This can also be set to None to
    # disable checkpointing altogether.
    max_checkpoints: int = 16

    # This option determines whether the simulator saves any data.  If the
    # simulation is initiated from `simulate` this will be set automatically
    # depending on whether `recorded_signals` is provided.  Hence, this
    # should not need to be manually configured.
    # FIXME: remove this and use `recorded_signals` instead. There are usecases
    # where simulate() is not used and we use the Simulator's advance_to function
    # directly. In those cases, recorded_signals can be set while save_time_series
    # is False which is confusing.
    save_time_series: bool = False

    # Dictionary of ports (or other cache sources) for which the time series should
    # be recorded. Note that if the simulation is initiated from `simulate` and
    # `recorded_signals` is provided as a kwarg to `simulate`, anything set here
    # will be overridden.  Hence, this should not need to be manually configured.
    recorded_signals: dict[str, SystemCallback] = None

    # If the context is not needed for anything, opting to not return it can
    # speed up compilation times.  For instance, typical simulation calls from
    # the UI don't use the context for anything, so model_interface.py will
    # set `return_context=False` for performance.
    return_context: bool = True

    # Zero crossings are localized in time using the ODE solver interpolant,
    # which provides state values for any time value in the previous integration
    # time interval.
    # Bisection is used to search the time interval. Rather than run bisection
    # in a while loop until the time interval is _small_, bisection is run for
    # fixed number of iterations, as this results in localizing zero crossings in
    # time within some small fraction of the integrated time interval.
    # e.g. if the major step length is 1.0 second, and bisection is run for 40
    # loops, the zero crossing time tolerance is approx. 1e-12, a.k.a. picosecond.
    zc_bisection_loop_count: int = 40

    # Scale of integer time used for event synchronization.  The default value is
    # 1e-12, corresponding to picosecond resolution.  The maximum representable
    # time in this case is around 0.3 years.  If longer simulations are needed,
    # the time scale can be increased to 1e-9, 1e-6, etc.
    int_time_scale: float = None

    # Called at the end of each major step with the current time as its only
    # argument; it is expected to return nothing.
    major_step_callback: Callable[[Scalar], None] = None

    @property
    def ode_options(self) -> ODESolverOptions:
        """Return the ODE-solver subset of these options as `ODESolverOptions`."""
        return ODESolverOptions(
            rtol=self.rtol,
            atol=self.atol,
            min_step_size=self.min_minor_step_size,
            max_step_size=self.max_minor_step_size,
            method=self.ode_solver_method,
            enable_autodiff=self.enable_autodiff,
            max_checkpoints=self.max_checkpoints,
        )

    def __repr__(self) -> str:
        # Hand-written instead of the dataclass-generated repr: `recorded_signals`
        # is reported as its number of entries rather than printed in full, and
        # `major_step_callback` is omitted.  Fields that change simulation behavior
        # (autodiff, checkpointing, buffering, integer time scale) are appended
        # after the original fields so the existing prefix is unchanged.
        return (
            f"SimulatorOptions("
            f"math_backend={self.math_backend}, "
            f"enable_tracing={self.enable_tracing}, "
            f"max_major_step_length={self.max_major_step_length}, "
            f"max_major_steps={self.max_major_steps}, "
            f"ode_solver_method={self.ode_solver_method}, "
            f"rtol={self.rtol}, "
            f"atol={self.atol}, "
            f"min_minor_step_size={self.min_minor_step_size}, "
            f"max_minor_step_size={self.max_minor_step_size}, "
            f"zc_bisection_loop_count={self.zc_bisection_loop_count}, "
            f"save_time_series={self.save_time_series}, "
            f"recorded_signals={len(self.recorded_signals or [])}, "
            f"return_context={self.return_context}, "
            f"enable_autodiff={self.enable_autodiff}, "
            f"max_checkpoints={self.max_checkpoints}, "
            f"buffer_length={self.buffer_length}, "
            f"int_time_scale={self.int_time_scale}"
            f")"
        )

estimate_max_major_steps(system, tspan, max_major_step_length=None, safety_factor=2)

Heuristic for estimating the required number of major steps.

This is used to bound the number of iterations in the while loop in the simulate function when automatic differentiation is enabled. The number of major steps is determined by the smallest discrete period in the system and the length of the simulation interval. The number of major steps is bounded by the length of the simulation interval divided by the smallest discrete period, with a safety factor applied. The safety factor accounts for unscheduled major steps that may be triggered by zero-crossing events.

This function assumes static time variables, so cannot be called from within traced (JAX-transformed) functions. This is typically the case when the beginning or end time of the simulation is a variable that will be differentiated. In this case estimate_max_major_steps should be called statically ahead of time to determine a reasonable bound for max_major_steps.

**Parameters:**

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `system` | `SystemBase` | The system to simulate. | *required* |
| `tspan` | `tuple[float, float]` | The time interval to simulate over. | *required* |
| `max_major_step_length` | `float` | The maximum length of a major step. If provided, this will be used to bound the number of major steps. Otherwise it will be ignored. | `None` |
| `safety_factor` | `int` | The safety factor to apply to the number of major steps. Defaults to 2. | `2` |
Source code in `collimator/simulation/simulator.py` (lines 99–175)
def estimate_max_major_steps(
    system: SystemBase,
    tspan: tuple[float, float],
    max_major_step_length: float = None,
    safety_factor: int = 2,
) -> int:
    """Heuristic for estimating the required number of major steps.

    This is used to bound the number of iterations in the while loop in the
    `simulate` function when automatic differentiation is enabled.  The number
    of major steps is determined by the smallest discrete period in the system
    and the length of the simulation interval.  The number of major steps is
    bounded by the length of the simulation interval (`tspan[1] - tspan[0]`)
    divided by the smallest discrete period, with a safety factor applied.  The
    safety factor accounts for unscheduled major steps that may be triggered by
    zero-crossing events.

    This function assumes static time variables, so cannot be called from within
    traced (JAX-transformed) functions.  This is typically the case when the
    beginning or end time of the simulation is a variable that will be
    differentiated.  In this case `estimate_max_major_steps` should be called
    statically ahead of time to determine a reasonable bound for `max_major_steps`.

    Args:
        system (SystemBase): The system to simulate.
        tspan (tuple[float, float]): The time interval to simulate over.
        max_major_step_length (float, optional): The maximum length of a major
            step. If provided, this will be used to bound the number of major
            steps. Otherwise it will be ignored.
        safety_factor (int, optional): The safety factor to apply to the number of
            major steps.  Defaults to 2.

    Returns:
        int: The estimated bound on the number of major steps.  When the system
            has periodic events or `max_major_step_length` is given, this is
            `max(100, safety_factor * ceil(interval / smallest_step))`;
            otherwise it is a fixed default of 200.
    """
    # For autodiff of collimator.simulate, this path is not possible, JAX
    # throws an error. To work around this, create:
    #   options = SimulatorOptions(max_major_steps=<my value>)
    # outside collimator.simulate, and pass in like this:
    #   collimator.simulate(my_model, options=options)

    # Find the smallest period amongst the periodic events of the system
    if system.periodic_events.has_events or max_major_step_length is not None:
        # Initialize to infinity - will be overwritten by at least one conditional
        min_discrete_step = np.inf

        # Bound the number of major steps based on the smallest discrete period in
        # the system.
        if system.periodic_events.has_events:
            event_periods = jax.tree_util.tree_map(
                lambda event_data: event_data.period,
                system.periodic_events,
                is_leaf=is_event_data,
            )
            min_discrete_step = jax.tree_util.tree_reduce(min, event_periods)

        # Also bound the number of major steps based on the max major step length
        # in case that is shorter than any of the update periods.
        if max_major_step_length is not None:
            min_discrete_step = min(min_discrete_step, max_major_step_length)

        # Count periods over the length of the interval, not up to the end time:
        # using `tspan[1]` alone over-estimates when the start time is positive
        # and under-estimates when it is negative.
        interval = tspan[1] - tspan[0]
        # Round up rather than floor-divide: float floor division can drop a
        # whole period (e.g. `1.0 // 0.1 == 9.0`).
        num_periods = int(np.ceil(float(interval) / float(min_discrete_step)))

        # in this case, we assume that, on average, major steps triggered by
        # zero crossing event, will be as frequent or less frequent than major steps
        # triggered by the smallest discrete period.
        # anything less than 100 is considered inadequate. user can override if they want this.
        max_major_steps = max(100, safety_factor * num_periods)
        logger.info(
            "max_major_steps=%s based on smallest discrete period=%s",
            max_major_steps,
            min_discrete_step,
        )
    else:
        # in this case we really have no valuable information on which to make an
        # educated guess. who knows how many events might occur!!!
        # users will have to iterate.
        max_major_steps = 200
        logger.info(
            "max_major_steps=%s by default since no discrete period in system",
            max_major_steps,
        )
    return max_major_steps

simulate(system, context, tspan, options=None, results_options=None, recorded_signals=None, postprocess=True)

Simulate the hybrid dynamical system defined by system.

The parameters and initial state are defined by context. The simulation time runs from tspan[0] to tspan[1].

The simulation is "hybrid" in the sense that it handles dynamical systems with both discrete and continuous components. The continuous components are integrated using an ODE solver, while discrete components are updated periodically as specified by the individual system components. The continuous and discrete states can also be modified by "zero-crossing" events, which trigger when scalar-valued guard functions cross zero in a specified direction.

The simulation is thus broken into "major" steps, which consist of the following, in order:

(1) Perform any periodic updates to the discrete state. (2) Check if the discrete update triggered any zero-crossing events and handle associated reset maps if necessary. (3) Advance the continuous state using an ODE solver until the next discrete update or zero-crossing, localizing the zero-crossing with a bisection search. (4) Store the results data. (5) If the ODE solver terminated due to a zero-crossing, handle the reset map.

The steps taken by the ODE solver are "minor" steps in this simulation. The behavior of the ODE solver and the hybrid simulation in general can be controlled by configuring SimulatorOptions. Available settings are as follows:

SimulatorOptions

- enable_tracing (bool): Allow JAX tracing for JIT compilation.
- max_major_step_length (float): Maximum length of a major step.
- max_major_steps (int): The maximum number of major steps to take in the simulation. This is necessary for automatic differentiation - otherwise the "while" loop is non-differentiable. With the default value of None, a heuristic is used to determine the maximum number of steps based on the periodic update events and time interval.
- rtol (float): Relative tolerance for the ODE solver. Default is 1e-6.
- atol (float): Absolute tolerance for the ODE solver. Default is 1e-8.
- min_minor_step_size (float): Minimum step size for the ODE solver.
- max_minor_step_size (float): Maximum step size for the ODE solver.
- ode_solver_method (str): The ODE solver to use. Default is "auto", which will use Dopri5/JAX if JAX tracing is enabled, otherwise the SciPy Dopri5 solver.
- save_time_series (bool): This option determines whether the simulator saves any data. If the simulation is initiated from simulate this will be set automatically depending on whether recorded_signals is provided. Hence, this should not need to be manually configured.
- recorded_signals (dict[str, OutputPort]): Dictionary of ports or other cache sources for which the time series should be recorded. Note that if the simulation is initiated from simulate and recorded_signals is provided as a kwarg to simulate, anything set here will be overridden. Hence, this should not need to be manually configured.
- return_context (bool): If the context is not needed for anything, opting to not return it can speed up compilation times. For instance, typical simulation calls from the UI don't use the context for anything, so model_interface.py will set return_context=False for performance.

postprocess (bool) is passed as a keyword argument to simulate rather than through SimulatorOptions. If using buffered results recording (i.e. with the JAX numerical backend), it determines whether to automatically trim the buffer after the simulation is complete. This is the default behavior, which will serve unless the full call to simulate needs to be traced (e.g. with grad or vmap).

The return value is a SimulationResults object, which is a named tuple containing all recorded signals as well as the final context (if options.return_context is True). Signals can be recorded by providing a dict of (name, signal_source) pairs. Typically the signal sources will be output ports, but they can actually be any SystemCallback object in the system.

Parameters:

Name Type Description Default
system SystemBase

The hybrid dynamical system to simulate.

required
context ContextBase

The initial state and parameters of the system.

required
tspan tuple[float, float]

The start and end times of the simulation.

required
options SimulatorOptions

Options for the simulation process and ODE solver.

None
results_options ResultsOptions

Options related to how the outputs are stored, interpolated, and returned.

None
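recorded_signals dict[str, OutputPort]

Dictionary of ports for which the time series should be recorded.

None
postprocess bool

If using buffered results recording (i.e. with the JAX numerical backend), whether to automatically trim the results buffer after the simulation is complete. Disable it when the full call to simulate needs to be traced (e.g. with grad or vmap).

True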

Returns:

Name Type Description
SimulationResults SimulationResults

A named tuple containing the recorded signals and the final context (if options.return_context is True).

Notes

If recorded_signals is provided as a kwarg, it will override any entry in options.recorded_signals. This will be deprecated in the future in favor of only passing via options.

This function is meant to best handle single independent simulations. Calling this function repeatedly will always trigger a recompilation of the model when using the JAX backend. To avoid this, call advance_to directly.

Source code in collimator/simulation/simulator.py
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
def simulate(
    system: SystemBase,
    context: ContextBase,
    tspan: tuple[float, float],
    options: SimulatorOptions = None,
    results_options: ResultsOptions = None,
    recorded_signals: dict[str, OutputPort] = None,
    postprocess: bool = True,
) -> SimulationResults:
    """Simulate the hybrid dynamical system defined by `system`.

    The parameters and initial state are defined by `context`.  The simulation time
    runs from `tspan[0]` to `tspan[1]`.

    The simulation is "hybrid" in the sense that it handles dynamical systems with both
    discrete and continuous components.  The continuous components are integrated using
    an ODE solver, while discrete components are updated periodically as specified by
    the individual system components. The continuous and discrete states can also be
    modified by "zero-crossing" events, which trigger when scalar-valued guard
    functions cross zero in a specified direction.

    The simulation is thus broken into "major" steps, which consist of the following,
    in order:

    (1) Perform any periodic updates to the discrete state.
    (2) Check if the discrete update triggered any zero-crossing events and handle
        associated reset maps if necessary.
    (3) Advance the continuous state using an ODE solver until the next discrete
        update or zero-crossing, localizing the zero-crossing with a bisection search.
    (4) Store the results data.
    (5) If the ODE solver terminated due to a zero-crossing, handle the reset map.

    The steps taken by the ODE solver are "minor" steps in this simulation.  The
    behavior of the ODE solver and the hybrid simulation in general can be controlled
    by configuring `SimulatorOptions`.  Available settings are as follows:

    SimulatorOptions:
        enable_tracing (bool): Allow JAX tracing for JIT compilation
        max_major_step_length (float): Maximum length of a major step
        max_major_steps (int):
            The maximum number of major steps to take in the simulation. This is
            necessary for automatic differentiation - otherwise the "while" loop
            is non-differentiable.  With the default value of None, a heuristic
            is used to determine the maximum number of steps based on the periodic
            update events and time interval.
        rtol (float): Relative tolerance for the ODE solver. Default is 1e-6.
        atol (float): Absolute tolerance for the ODE solver. Default is 1e-8.
        min_minor_step_size (float): Minimum step size for the ODE solver.
        max_minor_step_size (float): Maximum step size for the ODE solver.
        ode_solver_method (str): The ODE solver to use.  Default is "auto", which
            will use the Dopri5/Jax if JAX tracing is enabled, otherwise the
            SciPy Dopri5 solver.
        save_time_series (bool):
            This option determines whether the simulator saves any data.  If the
            simulation is initiated from `simulate` this will be set automatically
            depending on whether `recorded_signals` is provided.  Hence, this
            should not need to be manually configured.
        recorded_signals (dict[str, OutputPort]):
            Dictionary of ports or other cache sources for which the time series should
            be recorded. Note that if the simulation is initiated from `simulate` and
            `recorded_signals` is provided as a kwarg to `simulate`, anything set here
            will be overridden.  Hence, this should not need to be manually configured.
        return_context (bool):
            If the context is not needed for anything, opting to not return it can
            speed up compilation times.  For instance, typical simulation calls from
            the UI don't use the context for anything, so model_interface.py will
            set `return_context=False` for performance.

    The return value is a `SimulationResults` object, which is a named tuple containing
    all recorded signals as well as the final context (if `options.return_context` is
    `True`). Signals can be recorded by providing a dict of (name, signal_source) pairs.
    Typically the signal sources will be output ports, but they can actually be any
    `SystemCallback` object in the system.

    Args:
        system (SystemBase): The hybrid dynamical system to simulate.
        context (ContextBase): The initial state and parameters of the system.
        tspan (tuple[float, float]): The start and end times of the simulation.
        options (SimulatorOptions): Options for the simulation process and ODE solver.
        results_options (ResultsOptions): Options related to how the outputs are
            stored, interpolated, and returned.
        recorded_signals (dict[str, OutputPort]):
            Dictionary of ports for which the time series should be recorded.
        postprocess (bool):
            If using buffered results recording (i.e. with JAX numerical backend), this
            determines whether to automatically trim the buffer after the simulation is
            complete. This is the default behavior, which will serve unless the full
            call to `simulate` needs to be traced (e.g. with `grad` or `vmap`).  When
            False, the returned `time` and `outputs` are None.

    Returns:
        SimulationResults: A named tuple containing the recorded signals and the final
            context (if `options.return_context` is `True`).

    Raises:
        NotImplementedError: If `results_options.mode` is not `ResultsMode.auto`.
        ValueError: If static parameters of `system` have been updated since the
            context was created.

    Notes:
        If `recorded_signals` is provided as a kwarg, it will override any entry in
        `options.recorded_signals`. This will be deprecated in the future in favor of
        only passing via `options`.

        This function is meant to best handle single independent simulations.
        Calling this function repeatedly will always trigger a recompilation of the
        model when using the JAX backend. To avoid this, call advance_to directly.
    """

    options = _check_options(system, options, tspan, recorded_signals)

    # Everything after option checking runs under try/finally so that the integer
    # time scale is restored even if validation or the simulation itself raises;
    # otherwise a failed run could leave the global scale modified for later runs.
    try:
        if results_options is None:
            results_options = ResultsOptions()

        if results_options.mode != ResultsMode.auto:
            raise NotImplementedError(
                f"Simulation output mode {results_options.mode.name} is not supported. "
                "Only 'auto' is presently supported."
            )

        if system.has_dirty_static_parameters:
            raise ValueError(
                "Some static parameters have been updated. Please create a new context."
            )

        # HACK: Wildcat presently does not use interpolant to produce
        # results sample between minor_step end times, so we clamp
        # the minor step size to the max_results_interval instead.
        # A `max_minor_step_size` of None means "no cap", so it is always clamped
        # (comparing a float against None would raise TypeError).
        max_results_interval = results_options.max_results_interval
        if (
            max_results_interval is not None
            and max_results_interval > 0
            and (
                options.max_minor_step_size is None
                or max_results_interval < options.max_minor_step_size
            )
        ):
            options = dataclasses.replace(
                options,
                max_minor_step_size=max_results_interval,
            )
            logger.info(
                "max_minor_step_size reduced to %s to match max_results_interval",
                options.max_minor_step_size,
            )

        ode_solver = ODESolver(system, options=options.ode_options)

        sim = Simulator(system, ode_solver=ode_solver, options=options)
        logger.info("Simulator ready to start: %s, %s", options, ode_solver)

        # Define a function to be traced by JAX, if allowed, closing over the
        # arguments to `_simulate`.
        def _wrapped_simulate() -> tuple[ContextBase, ResultsData]:
            t0, tf = tspan
            initial_context = context.with_time(t0)
            sim_state = sim.advance_to(tf, initial_context)
            error_end_time_not_reached(
                tf, sim_state.context.time, sim_state.step_end_reason
            )
            final_context = sim_state.context if options.return_context else None
            return final_context, sim_state.results_data

        # JIT-compile the simulation, if allowed
        if options.enable_tracing:
            _wrapped_simulate = jax.jit(_wrapped_simulate)
            _wrapped_simulate = Profiler.jaxjit_profiledfunc(
                _wrapped_simulate, "_wrapped_simulate"
            )

        # Run the simulation; the system gets its post-simulation hook whether or
        # not the run succeeds.
        try:
            final_context, results_data = _wrapped_simulate()

            if postprocess and results_data is not None:
                time, outputs = results_data.finalize()
            else:
                time, outputs = None, None
        finally:
            system.post_simulation_finalize()

        return SimulationResults(
            final_context,
            time=time,
            outputs=outputs,
        )
    finally:
        # Reset the integer time scale to the default value in case we decreased
        # precision to reach the end time of a long simulation.  Typically this
        # won't do anything.
        if options.int_time_scale is not None:
            IntegerTime.set_default_scale()