451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
class Simulator:
    """Class for orchestrating simulations of hybrid dynamical systems.

    See the `simulate` function for more details.
    """

    def __init__(
        self,
        system: SystemBase,
        ode_solver: ODESolverBase = None,
        options: SimulatorOptions = None,
    ):
        """Initialize the simulator.

        Args:
            system (SystemBase): The hybrid dynamical system to simulate.
            ode_solver (ODESolverBase):
                The ODE solver to use for integrating the continuous-time component
                of the system. If not provided, an `ODESolver` is constructed for
                `system` using `options.ode_options`.
            options (SimulatorOptions):
                Options for the simulation process. See `simulate` for details.
                If not provided, a default `SimulatorOptions()` is used.
        """
        self.system = system

        if options is None:
            options = SimulatorOptions()

        # Determine whether JAX tracing can be used (jit, grad, vmap, etc)
        math_backend, self.enable_tracing = _check_backend(options)

        # Set the math backend
        cnp.set_backend(math_backend)

        # Should the simulation be run with autodiff enabled? This will override
        # the `advance_to` method with a custom autodiff rule.
        self.enable_autodiff = options.enable_autodiff

        if ode_solver is None:
            ode_solver = ODESolver(system, options=options.ode_options)

        # Store configuration options
        self.buffer_length = options.buffer_length
        self.max_major_steps = options.max_major_steps
        self.max_major_step_length = options.max_major_step_length
        self.save_time_series = options.save_time_series
        self.recorded_outputs = options.recorded_signals
        self.zc_bisection_loop_count = options.zc_bisection_loop_count
        self.major_step_callback = options.major_step_callback

        # An unset major-step limit means "unbounded". `_advance_to` later clips
        # this to the boundary time before converting it to integer time.
        if self.max_major_step_length is None:
            self.max_major_step_length = np.inf

        logger.debug("Simulator created with enable_tracing=%s", self.enable_tracing)

        self.ode_solver = ode_solver

        # Modify the default autodiff rule slightly to correctly capture variations
        # in end time of the simulation interval.
        self.has_terminal_events = system.zero_crossing_events.has_terminal_events
        self.advance_to = self._override_advance_to_vjp()

        # Also override the guarded ODE integration with a custom autodiff rule
        # to capture variations due to zero-crossing time.
        self.guarded_integrate = self._override_guarded_integrate_vjp()

    def while_loop(self, cond_fun, body_fun, val):
        """Structured control flow primitive for a while loop.

        Dispatches to a bounded while loop as necessary for autodiff. Otherwise
        it will call the "backend" implementation, which will either be
        `lax.while_loop` when JAX is the backend, or a pure-Python implementation
        when using NumPy.
        """
        # If autodiff is enabled, we need to use a custom while loop with a maximum
        # number of steps so that the loop is reverse-mode differentiable.
        # Otherwise we can use a standard unbounded while loop with lax backend.
        if self.enable_autodiff:
            return _bounded_while_loop(cond_fun, body_fun, val, self.max_major_steps)
        else:
            return backend.while_loop(cond_fun, body_fun, val)

    def initialize(self, context: ContextBase) -> SimulatorState:
        """Perform initial setup for the simulation.

        Args:
            context (ContextBase): The initial state of the system.

        Returns:
            SimulatorState: The initial simulator state. It contains:
                - the timed events, marked with those active at the next update
                  time;
                - a step end reason of `TimeTriggered` if that update time equals
                  the initial time, and `NothingTriggered` otherwise;
                - an initialized results buffer, or None when `save_time_series`
                  is disabled;
                - a freshly initialized ODE solver state.
        """
        logger.debug("Initializing simulator")

        # Initial simulation time as integer (picoseconds)
        initial_int_time = IntegerTime.from_decimal(context.time)

        # Ensure that _next_update_time() can return the current time by perturbing
        # current time as slightly toward negative infinity as possible
        time_of_next_timed_event, timed_events = _next_update_time(
            self.system.periodic_events, initial_int_time - 1
        )

        # timed_events is now marked with the active events at the next update time
        logger.debug("Time of next timed event (int): %s", time_of_next_timed_event)
        logger.debug(
            "Time of next event (sec): %s",
            IntegerTime.as_decimal(time_of_next_timed_event),
        )
        timed_events.pprint(logger.debug)

        end_reason = cnp.where(
            time_of_next_timed_event == initial_int_time,
            StepEndReason.TimeTriggered,
            StepEndReason.NothingTriggered,
        )

        # Initialize the results data that will hold recorded time series data.
        if self.save_time_series:
            results_data = ResultsData.initialize(
                context, self.recorded_outputs, self.buffer_length
            )
        else:
            results_data = None

        return SimulatorState(
            context=context,
            timed_events=timed_events,
            step_end_reason=end_reason,
            int_time=initial_int_time,
            results_data=results_data,
            ode_solver_state=self.ode_solver.initialize(context),
        )

    def save_results(
        self, results_data: ResultsData, context: ContextBase
    ) -> ResultsData:
        """Update the results data with the current state of the system.

        Returns `results_data` unchanged when `save_time_series` is disabled.
        """
        if not self.save_time_series:
            return results_data
        return results_data.update(context)

    def _override_advance_to_vjp(self) -> Callable:
        """Construct the `advance_to` method for the simulator.

        See the docstring for `Simulator._advance_to` for details.

        When autodiff is enabled (`options.enable_autodiff`), the advance function
        is wrapped in a `jax.custom_vjp` rule. That rule correctly captures the
        variation with respect to the end time.

        When autodiff is disabled, the plain `_advance_to` is returned. Any
        derivatives taken through it will then not account for variations in end
        time (for instance in finding limit cycles, or when there are terminal
        conditions on the simulation).
        """
        if not self.enable_autodiff:
            return self._advance_to

        # This is the main function call whose autodiff rule will be overridden.
        def _wrapped_advance_to(
            sim: Simulator, boundary_time: float, context: ContextBase
        ) -> SimulatorState:
            return sim._advance_to(boundary_time, context)

        # The "forwards pass" to advance the simulation. Also stores the nominal
        # VJP calculation and the continuous time derivative value, both of which
        # will be needed in the backwards pass.
        def _wrapped_advance_to_fwd(
            sim: Simulator, boundary_time: float, context: ContextBase
        ) -> tuple[SimulatorState, tuple]:
            primals, vjp_fun = jax.vjp(sim._advance_to, boundary_time, context)

            # Also need to keep the final continuous time derivative value for
            # computing the adjoint time variable
            xdot = sim.system.eval_time_derivatives(primals.context)

            # "Residual" information needed in the backwards pass
            res = (vjp_fun, xdot, primals.step_end_reason)

            return primals, res

        def _wrapped_advance_to_adj(
            _sim: Simulator, res: tuple, adjoints: SimulatorState
        ) -> tuple[float, ContextBase]:
            # Unpack the residuals from the forward pass
            vjp_fun, xdot, reason = res

            # Compute whatever the standard adjoint variables are using the
            # automatically derived VJP function. The first return variable will
            # be the automatically computed tf_adj value, which we will ignore in
            # favor of the manually derived value computed next.
            _, context_adj = vjp_fun(adjoints)

            # The derivative with respect to end time is the inner product of the
            # adjoint "seed" continuous state with the final time derivatives.
            # We can overwrite whatever the calculated adjoint time was with this.
            #
            # Each leaf is reduced with an elementwise product followed by a sum,
            # rather than `jnp.dot`, so that leaves of any rank reduce to a scalar.
            # For 0-D and 1-D leaves this is identical to `jnp.dot`. For rank >= 2,
            # `jnp.dot` would instead be a matrix product.
            vc = adjoints.context.continuous_state
            vT_xdot = jax.tree_util.tree_map(
                lambda xdot, vc: jnp.sum(xdot * vc), xdot, vc
            )

            # On the other hand, if the simulation ends early due to a terminal
            # event, then the derivative with respect to end time is zero.
            tf_adj = jnp.where(
                reason == StepEndReason.TerminalEventTriggered,
                0.0,
                sum(jax.tree_util.tree_leaves(vT_xdot)),
            )

            # Return adjoints to match the inputs to _wrapped_advance_to, except for
            # the first argument (Simulator), which will be marked nondifferentiable.
            return (tf_adj, context_adj)

        advance_to = jax.custom_vjp(_wrapped_advance_to, nondiff_argnums=(0,))
        advance_to.defvjp(_wrapped_advance_to_fwd, _wrapped_advance_to_adj)

        # Copy the docstring and type hints to the overridden function
        advance_to.__doc__ = self._advance_to.__doc__
        advance_to.__annotations__ = self._advance_to.__annotations__

        return partial(advance_to, self)

    def _guarded_integrate(
        self,
        solver_state: ODESolverState,
        results_data: ResultsData,
        tf: float,
        context: ContextBase,
        zc_events: EventCollection,
    ) -> tuple[bool, ODESolverState, ContextBase, ResultsData, EventCollection]:
        """Guarded ODE integration.

        Advance continuous time using an ODE solver, localizing any zero-crossing
        events that occur during the requested interval. If any zero-crossing
        events trigger, the dense interpolant is used to localize the events and
        the associated reset maps are handled. The method then returns,
        guaranteeing that the major step terminates either at the end of the
        requested interval or at the time of a zero-crossing event.

        Args:
            solver_state (ODESolverState): The current state of the ODE solver.
            results_data (ResultsData): The results data that will hold recorded time
                series data.
            tf (float): The end time of the integration interval.
            context (ContextBase): The current state of the system.
            zc_events (EventCollection): The current zero-crossing events.

        Returns:
            tuple[bool, ODESolverState, ContextBase, ResultsData, EventCollection]:
                A tuple containing the following:
                - A boolean indicating whether the major step was terminated early
                  due to a zero-crossing event.
                - The updated state of the ODE solver.
                - The updated state of the system.
                - The updated results data.
                - The updated zero-crossing events.
        """
        solver = self.ode_solver
        func = solver.flat_ode_rhs  # Raveled ODE RHS function

        # Close over the additional arguments so that the RHS function has the
        # signature `func(y, t)`.
        def _func(y, t):
            return func(y, t, context)

        def _localize_zc_minor(
            solver_state, context_t0, context_tf, zc_events, results_data
        ):
            # Using the ODE solver interpolant, employ bisection to find a 'small'
            # time interval within which the earliest zero crossing occurs. See
            # _bisection_step_fun for details about how bisection is employed for
            # localizing the zero crossing in time.
            int_t1 = IntegerTime.from_decimal(context_tf.time)
            int_t0 = IntegerTime.from_decimal(context_t0.time)
            _body_fun = partial(_bisection_step_fun, solver_state)
            carry = GuardIsolationData(int_t0, int_t1, zc_events, context_tf)
            search_data = backend.fori_loop(
                0, self.zc_bisection_loop_count, _body_fun, carry
            )
            context_tf = search_data.context
            zc_events = search_data.guards

            # Record a results sample for the ZC having 'occurred'. The sample time
            # is the localized end time, backed off by
            # (t1 - t0) / 2**(loop_count + 1). Assuming each bisection step halves
            # the bracket, that is half the width of the final bracket.
            minor_step_end_time = IntegerTime.as_decimal(int_t1)
            minor_step_start_time = IntegerTime.as_decimal(int_t0)
            zc_occur_time = context_tf.time - (
                minor_step_end_time - minor_step_start_time
            ) / (2 ** (self.zc_bisection_loop_count + 1))
            context_zc_time = context_tf.with_time(zc_occur_time)
            results_data = self.save_results(results_data, context_zc_time)

            # Handle any triggered zero-crossing events
            context_tf = self.system.handle_zero_crossings(zc_events, context_tf)

            # Re-initialize the solver, since the state may have been reset.
            # Keep the last step size, since there's no reason to assume that the
            # dynamics have changed significantly as a result of the event. If
            # that is the case, then we're relying on the adaptive stepping to
            # re-calibrate.
            #
            # NOTE: this previously only updated the state and time of
            # the solver state, but with multistep solvers (e.g. BDF), the
            # solver needs to be fully reinitialized because the history
            # of differences needs to be cleared and rebuilt over the next few
            # steps.
            solver_state = solver.initialize(context_tf)

            return solver_state, context_tf, zc_events, results_data

        def _no_events_fun(
            solver_state, context_t0, context_tf, zc_events, results_data
        ):
            # Nothing triggered: pass the post-step values through unchanged.
            return solver_state, context_tf, zc_events, results_data

        def _ode_step(carry):
            _, solver_state, context_t0, results_data, zc_events = carry

            # Save results at the top of the loop. This will save data at t=t0,
            # but not at t=tf. This is okay, since we will save the results at
            # the top of the next major step, as well as at the end of the main
            # simulation loop.
            results_data = self.save_results(results_data, context_t0)

            zc_events = guard_interval_start(zc_events, context_t0)

            # Advance ODE solver
            solver_state = solver.step(_func, tf, solver_state)
            xc = solver_state.unraveled_state
            context = context_t0.with_time(solver_state.t).with_continuous_state(xc)

            # Check for zero-crossing events
            zc_events = determine_triggered_guards(zc_events, context)
            triggered = zc_events.has_triggered

            args = (solver_state, context_t0, context, zc_events, results_data)
            solver_state, context, zc_events, results_data = backend.cond(
                triggered, _localize_zc_minor, _no_events_fun, *args
            )

            return (triggered, solver_state, context, results_data, zc_events)

        def _cond_fun(carry):
            # Keep stepping until the requested end time is reached or a
            # zero-crossing event triggers.
            triggered, solver_state, _, _, _ = carry
            return (solver_state.t < tf) & (~triggered)

        carry = (False, solver_state, context, results_data, zc_events)
        triggered, solver_state, context, results_data, zc_events = backend.while_loop(
            _cond_fun,
            _ode_step,
            carry,
        )

        return triggered, solver_state, context, results_data, zc_events

    def _override_guarded_integrate_vjp(self):
        """Construct the `guarded_integrate` method for the simulator.

        When autodiff is disabled this is just `_guarded_integrate`.

        Otherwise `_guarded_integrate` is wrapped in a `jax.custom_vjp`:
            - The forward pass runs `_guarded_integrate` without saving results.
            - It then takes a VJP of a simpler differentiable path. That path
              integrates the ODE to the *actual* end time of the step, then applies
              `handle_zero_crossings` with the events found by the forward call.
            - The backward pass uses that VJP and replaces the end-time cotangent
              with one built from the ODE right-hand side at the final state.
        """
        if not self.enable_autodiff:
            return self._guarded_integrate

        def _wrapped_solve(
            self: Simulator, solver_state, results_data, tf, context, zc_events
        ):
            return self._guarded_integrate(
                solver_state, results_data, tf, context, zc_events
            )

        def _wrapped_solve_fwd(
            self: Simulator, solver_state, _results_data, tf, context, zc_events
        ):
            # Run the forward major step as usual (primal calculation). Do not save
            # any results here.
            # The return from the forward call has the state post-reset, including
            # the solver state, the context, and the zero-crossing events.
            # NOTE(review): because `None` is passed for results_data, the primal
            # output's results_data is None on this path. This presumably requires
            # `save_time_series` to be disabled whenever autodiff is used — confirm.
            t0 = solver_state.t
            (
                triggered,
                solver_state_out,
                context_out,
                _,
                zc_events_out,
            ) = self._guarded_integrate(solver_state, None, tf, context, zc_events)
            tf = solver_state_out.t

            # Define a differentiable function for the forward pass, knowing where
            # the zero-crossing occurs. Note that `tf` here should be the _actual_
            # interval end time, not the requested end time.
            solver = self.ode_solver
            func = solver.flat_ode_rhs  # Raveled ODE RHS function

            def _forward(solver_state, tf, context, zc_events_out):
                solver_state_out = _odeint(solver, func, solver_state, tf, context)
                context = context.with_time(solver_state_out.t)
                context = context.with_continuous_state(
                    solver_state_out.unraveled_state
                )
                context = self.system.handle_zero_crossings(zc_events_out, context)
                return context

            # Get the VJP for the event handling
            _primals, vjp_fun = jax.vjp(
                _forward, solver_state, tf, context, zc_events_out
            )

            primals = (triggered, solver_state_out, context_out, None, zc_events_out)
            residuals = (triggered, solver_state_out, t0, tf, context, vjp_fun)
            return primals, residuals

        def _wrapped_solve_adj(self: Simulator, residuals, adjoints):
            triggered, primal_solver_state, _t0, tf, context, vjp_fun = residuals

            (
                _triggered_adj,
                solver_state_adj,
                context_adj,
                _results_data_adj,
                _zc_events_adj,
            ) = adjoints

            # Fold the solver-state cotangent (time and state) into the context
            # cotangent before pulling back through `_forward`.
            context_adj = context_adj.with_time(solver_state_adj.t)
            context_adj = context_adj.with_continuous_state(
                solver_state_adj.unraveled_state
            )

            # The `_forward` function corresponding to `vjp_fun` has the signature
            # `context_out = _forward(solver_state, tf, context, zc_events_out)`.
            # For the adjoint, we have to call with `context_adj` as the input:
            solver_state_adj, tf_adj, context_adj, zc_events_adj = vjp_fun(context_adj)

            # The Jacobian with respect to the final time is just the time
            # derivative of the state at the final time.
            yf = primal_solver_state.y
            yf_bar = solver_state_adj.y
            func = self.ode_solver.flat_ode_rhs  # Raveled ODE RHS function
            tf_adj = jnp.dot(func(yf, tf, context), yf_bar)

            # NOTE(review): the end-time cotangent is kept only when a zero
            # crossing triggered, and zeroed otherwise. Confirm this polarity is
            # intended: `advance_to`'s rule zeroes it in the early-termination case.
            tf_adj = jnp.where(
                triggered,
                tf_adj,
                jnp.zeros_like(tf_adj),
            )

            # One cotangent per differentiable input of `_wrapped_solve`:
            # (solver_state, results_data, tf, context, zc_events).
            return (solver_state_adj, None, tf_adj, context_adj, zc_events_adj)

        guarded_integrate = jax.custom_vjp(_wrapped_solve, nondiff_argnums=(0,))
        guarded_integrate.defvjp(_wrapped_solve_fwd, _wrapped_solve_adj)

        # Copy the docstring and type hints to the overridden function
        guarded_integrate.__doc__ = self._guarded_integrate.__doc__
        guarded_integrate.__annotations__ = self._guarded_integrate.__annotations__

        return partial(guarded_integrate, self)

    def _advance_continuous_time(
        self,
        cdata: ContinuousIntervalData,
    ) -> ContinuousIntervalData:
        """Advance the simulation to the next discrete update or zero-crossing event.

        This determines the active guard functions and advances the
        continuous-time component of the system to the next discrete update or
        zero-crossing event, whichever comes first.

        For systems with continuous state:
            - The advance is done by `guarded_integrate`. It localizes zero
              crossings by bisection over the ODE solver's dense interpolant and
              handles their reset maps.
            - The returned end time is the localized crossing time if one
              triggered, and otherwise the requested end time.

        For systems without continuous state:
            - Time is stepped directly to the requested end time.
            - Guards are evaluated at both ends of the interval, and any triggered
              zero-crossing events are handled.

        Args:
            cdata (ContinuousIntervalData): Data for the current interval,
                including the requested (integer) end time `tf`.

        Returns:
            ContinuousIntervalData: `cdata` with these fields replaced:
                - `triggered`, `terminate_early`, `context`, and `tf`;
                - `results_data` and `ode_solver_state`.
        """
        # Unpack inputs
        int_tf = cdata.tf
        context = cdata.context
        results_data = cdata.results_data

        zc_events = self.system.determine_active_guards(context)

        if self.system.has_continuous_state:
            solver_state = cdata.ode_solver_state
            tf = IntegerTime.as_decimal(int_tf)
            (
                triggered,
                solver_state,
                context,
                results_data,
                zc_events,
            ) = self.guarded_integrate(
                solver_state,
                results_data,
                tf,
                context,
                zc_events,
            )

            context = context.with_time(solver_state.t)
            context = context.with_continuous_state(solver_state.unraveled_state)

            # Converting from decimal -> integer time incurs a loss of precision.
            # This is okay for unscheduled zero-crossing events, but problematic for
            # timed events. So only do this conversion if a zero-crossing was
            # triggered. Otherwise we know we have reached the end of the interval
            # and can keep the requested end time.
            int_tf = cnp.where(
                triggered,
                IntegerTime.from_decimal(context.time),
                int_tf,
            )

        else:
            # Skip the ODE solver for systems without continuous state. We still
            # have to check for triggered events here in case there are any
            # transitions triggered by time that need to be handled before the
            # periodic discrete update at the top of the next major step
            triggered = False
            solver_state = cdata.ode_solver_state
            zc_events = guard_interval_start(zc_events, context)

            results_data = self.save_results(results_data, context)

            # Advance time to the end of the interval
            context = context.with_time(IntegerTime.as_decimal(int_tf))

            # Record guard values after the discrete update and check if anything
            # triggered as a result of advancing time
            zc_events = guard_interval_end(zc_events, context)
            zc_events = determine_triggered_guards(zc_events, context)

            # Handle any triggered zero-crossing events
            context = self.system.handle_zero_crossings(zc_events, context)

        # Even though the zero-crossing events have already been "handled", the
        # information about whether a terminal event has been triggered is still in
        # the events collection (since "triggered" has not been cleared by a call
        # to determine_triggered_guards).
        terminate_early = zc_events.has_active_terminal

        return cdata._replace(
            triggered=triggered,
            terminate_early=terminate_early,
            context=context,
            tf=int_tf,
            results_data=results_data,
            ode_solver_state=solver_state,
        )

    def _handle_discrete_update(
        self, context: ContextBase, timed_events: EventCollection
    ) -> tuple[ContextBase, bool]:
        """Handle discrete updates triggered by time.

        This method is called at the beginning of each major step. It applies any
        active periodic discrete updates, then handles any zero-crossing events
        that are triggered as a result of those updates.

        This will also work when there are no zero-crossing events: the
        zero-crossing collection will be empty and only the periodic discrete
        update will happen.

        Args:
            context (ContextBase): The current state of the system.
            timed_events (EventCollection):
                The collection of timed events, with the active events marked.

        Returns:
            ContextBase: The updated state of the system.
            bool: Whether the simulation should terminate early as a result of a
                triggered terminal condition.
        """
        system = self.system

        # Get the collection of zero-crossing events that _might_ be activated
        # given the current state of the system. For example, some events may
        # be de-activated as a result of the current state of a state machine.
        zc_events = system.determine_active_guards(context)

        # Record guard values at the start of the interval
        zc_events = guard_interval_start(zc_events, context)

        # Handle any active periodic discrete updates
        context = system.handle_discrete_update(timed_events, context)

        # Record guard values after the discrete update
        zc_events = guard_interval_end(zc_events, context)

        # Check if guards have triggered as a result of these updates
        zc_events = determine_triggered_guards(zc_events, context)
        terminate_early = zc_events.has_active_terminal

        # Handle any zero-crossing events that were triggered
        context = system.handle_zero_crossings(zc_events, context)

        return context, terminate_early

    # This method is marked private because it will be wrapped with a custom
    # autodiff rule by `_override_advance_to_vjp`. That rule gives the correct
    # derivatives with respect to the end time of the simulation interval, and
    # also copies this docstring to the overridden function. Normally the wrapped
    # attribute `advance_to` is what should be called by users.
    def _advance_to(self, boundary_time: float, context: ContextBase) -> SimulatorState:
        """Core control flow logic for running a simulation.

        This is the main loop for advancing the simulation. It is called by
        `simulate`, or can be called directly if more fine-grained control is
        needed. This method essentially loops over "major steps" until the
        boundary time is reached. See the documentation for `simulate` for details
        on the order of operations in a major step.

        Args:
            boundary_time (float): The time to advance to.
            context (ContextBase): The current state of the system.

        Returns:
            SimulatorState:
                A named tuple containing the final state of the simulation,
                including the final context, a collection of pending timed events,
                and a flag indicating the reason that the most recent major step
                ended.

        Notes:
            API will change slightly as a result of WC-87, which will break out the
            initialization from the main loop so that `advance_to` can be called
            repeatedly. See:
            https://collimator.atlassian.net/browse/WC-87
        """
        system = self.system
        sim_state = self.initialize(context)
        end_reason = sim_state.step_end_reason
        context = sim_state.context
        timed_events = sim_state.timed_events

        int_boundary_time = IntegerTime.from_decimal(boundary_time)

        # We will be limiting each step by the max_major_step_length. However, if
        # this is infinite we should just use the end time of the simulation to
        # avoid integer overflow. This could be problematic if the end time of the
        # simulation is close to the maximum representable integer time, but we
        # can come back to that if it's an issue.
        int_max_step_length = IntegerTime.from_decimal(
            cnp.minimum(self.max_major_step_length, boundary_time)
        )

        # Only activate timed events if the major step ended on a time trigger
        timed_events = activate_timed_events(timed_events, end_reason)

        # Body of the main while loop: executes one major step.
        def _major_step(sim_state: SimulatorState) -> SimulatorState:
            end_reason = sim_state.step_end_reason
            context = sim_state.context
            timed_events = sim_state.timed_events
            int_time = sim_state.int_time

            if not self.enable_tracing:
                logger.debug("Starting a simulation step at t=%s", context.time)
                logger.debug(" merged_events: %s", timed_events)

            # Handle any discrete updates that are triggered by time along with
            # any zero-crossing events that are triggered by the discrete update.
            context, terminate_early = self._handle_discrete_update(
                context, timed_events
            )
            logger.debug("Terminate early after discrete update: %s", terminate_early)

            # How far can we go before we have to handle timed events?
            # The time returned here will be the integer time representation.
            time_of_next_timed_event, timed_events = _next_update_time(
                system.periodic_events, int_time
            )
            if not self.enable_tracing:
                logger.debug(
                    "Next timed event at t=%s",
                    IntegerTime.as_decimal(time_of_next_timed_event),
                )
                timed_events.pprint(logger.debug)

            # Determine whether the events include a timed update
            update_time = IntegerTime.max_int_time
            if timed_events.num_events > 0:
                update_time = time_of_next_timed_event

            # Limit the major step end time to the earliest of: the simulation end
            # time, the major step limit, or the next periodic update time.
            # This is also the mechanism used to advance time for systems that
            # have no states and no periodic events:
            # - Discrete systems: when there are discrete periodic events, those
            #   determine each major step end time.
            # - Feedthrough systems: when there are just feedthrough blocks (no
            #   states or events), max_major_step_length determines each major
            #   step end time.
            int_tf_limit = int_time + int_max_step_length
            int_tf = cnp.min(
                cnp.array(
                    [
                        int_boundary_time,
                        int_tf_limit,
                        update_time,
                    ]
                )
            )
            if not self.enable_tracing:
                logger.debug(
                    "Expecting to integrate to t=%s",
                    IntegerTime.as_decimal(int_tf),
                )

            # Normally we will advance continuous time to the end of the major step
            # here. However, if a terminal event was triggered as part of the
            # discrete update, we should respect that and skip the continuous
            # update.
            #
            # Construct the container used to hold various data related to
            # advancing continuous time. This is passed to ODE solvers,
            # zero-crossing localization, and related functions.
            cdata = ContinuousIntervalData(
                context=context,
                terminate_early=terminate_early,
                triggered=False,
                t0=int_time,
                tf=int_tf,
                results_data=sim_state.results_data,
                ode_solver_state=sim_state.ode_solver_state,
            )
            cdata = backend.cond(
                (self.has_terminal_events & cdata.terminate_early),
                lambda cdata: cdata,  # Terminal event triggered - return immediately
                self._advance_continuous_time,  # Advance continuous time normally
                cdata,
            )

            # Unpack the results of the continuous time advance
            context = cdata.context
            terminate_early = cdata.terminate_early
            triggered = cdata.triggered
            int_tf = cdata.tf
            results_data = cdata.results_data
            ode_solver_state = cdata.ode_solver_state

            # Determine the reason why the major step ended. Did a zero-crossing
            # trigger, did a timed event trigger, neither, or both?
            logger.debug("Terminate early after major step: %s", terminate_early)
            end_reason = _determine_step_end_reason(
                triggered, terminate_early, int_tf, update_time
            )
            logger.debug("Major step end reason: %s", end_reason)

            # Conditionally activate timed events depending on whether the major
            # step ended as a result of a time trigger or zero-crossing event.
            timed_events = activate_timed_events(timed_events, end_reason)

            if self.major_step_callback:
                io_callback(self.major_step_callback, (), context.time)

            return SimulatorState(
                step_end_reason=end_reason,
                context=context,
                timed_events=timed_events,
                int_time=int_tf,
                results_data=results_data,
                ode_solver_state=ode_solver_state,
            )

        def _cond_fun(sim_state: SimulatorState):
            # Continue until the boundary time is reached or a terminal event ends
            # the simulation.
            return (sim_state.int_time < int_boundary_time) & (
                sim_state.step_end_reason != StepEndReason.TerminalEventTriggered
            )

        # Initialize the "carry" values for the main loop.
        sim_state = SimulatorState(
            context=context,
            timed_events=timed_events,
            step_end_reason=end_reason,
            int_time=sim_state.int_time,
            results_data=sim_state.results_data,
            ode_solver_state=sim_state.ode_solver_state,
        )

        logger.debug(
            "Running simulation from t=%s to t=%s", context.time, boundary_time
        )
        try:
            # Main loop call
            sim_state = self.while_loop(_cond_fun, _major_step, sim_state)
            logger.debug("Simulation complete at t=%s", sim_state.context.time)
        except KeyboardInterrupt:
            # TODO: flag simulation as interrupted somewhere in sim_state
            # NOTE: the loop assignment above never completes on interrupt. So
            # `sim_state` here is still the pre-loop state, and the reported time
            # (and the returned state) are those at the start of this call.
            logger.info("Simulation interrupted at t=%s", sim_state.context.time)

        # At the end of the simulation we need to handle any pending discrete
        # updates and store the solution one last time.
        # FIXME (WC-87): The returned simulator state can't be used with advance_to
        # again, since the discrete updates have already been performed. Should be
        # broken out into a `finalize` method as part of WC-87.
        # NOTE(review): the final discrete update below is applied only when
        # `save_time_series` is enabled, so the returned context depends on that
        # option — confirm this is intended.

        # update discrete state to x+ at the simulation end_time
        if self.save_time_series:
            logger.debug("Finalizing solution...")
            # 1] do discrete update (will skip if the simulation was terminated early)
            context, _terminate_early = self._handle_discrete_update(
                sim_state.context, sim_state.timed_events
            )

            # 2] do update solution
            results_data = self.save_results(sim_state.results_data, context)

            sim_state = sim_state._replace(
                context=context,
                results_data=results_data,
            )
            logger.debug("Done finalizing solution")

        return sim_state
|