Skip to content

Framework

collimator.framework

BlockInitializationError

Bases: CollimatorError

A generic error raised when a block fails at init time. It is used in place of the full exceptions because those are known to cause issues, e.g., with ray serialization.

Source code in collimator/framework/error.py (lines 225–230)
class BlockInitializationError(CollimatorError):
    """Generic error raised when a block fails during initialization.

    Raised instead of the underlying exception, because the full exceptions
    are known to cause problems, e.g. with ray serialization.
    """

BlockParameterError

Bases: StaticError

Block parameters are missing or have invalid values.

Source code in collimator/framework/error.py (lines 173–176)
class BlockParameterError(StaticError):
    """Raised when block parameters are missing or hold invalid values."""

BlockRuntimeError

Bases: CollimatorError

A generic error raised when a block fails at runtime. It is used in place of the full exceptions because those are known to cause issues, e.g., with ray serialization.

Source code in collimator/framework/error.py (lines 239–244)
class BlockRuntimeError(CollimatorError):
    """Generic error raised when a block fails while the simulation runs.

    Raised instead of the underlying exception, because the full exceptions
    are known to cause problems, e.g. with ray serialization.
    """

CollimatorError

Bases: Exception

Base class for all custom collimator errors.

Source code in collimator/framework/error.py (lines 18–158)
class CollimatorError(Exception):
    """Base class for all custom collimator errors.

    Besides the message, an instance can record where the error happened:
    system id, block name/ui_id paths, port, parameter, or the ports of an
    algebraic loop. `str()` appends that location info to the message. When
    the error has a chained `__cause__`, the cause's message is appended too.
    """

    # Ideally we'd always have a system to pass but there are at least 2 cases
    # where we may not have one:
    # 1. parsing from json, the block hasn't been built yet
    # 2. other errors not specific to a block
    # In case 1, we should pass all name, path & ui_id info to the error

    def __init__(
        self,
        message=None,
        *,
        system: "SystemBase" = None,  # noqa
        system_id: Hashable = None,
        name_path: list[str] = None,
        ui_id_path: list[str] = None,
        port_index: int = None,
        port_name: str = None,
        port_direction: str = None,  # 'in' or 'out'
        parameter_name: str = None,
        loop: list["DirectedPortLocator"] = None,
    ):
        """Create a new CollimatorError.

        Only `message` is a positional argument, all others are keyword arguments.

        Args:
            message: A custom error message, defaults to the error class name.
            system: The system that the error occurred in, if available. Its
                `system_id` is always used. Its `name_path`/`ui_id_path` are
                used only when those arguments are not given (or are empty).
            system_id: The id of the system that the error occurred in, use if
                system can't be passed.
            name_path: The name path of the block that the error occurred in,
                use if system can't be passed.
            ui_id_path: The ui_id (uuid) path of the block that the error
                occurred in, use if system can't be passed.
            port_index: The index of the port that the error occurred at.
            port_name: The name of the port that the error occurred at.
            port_direction: The direction of the port that the error occurred
                at ('in' or 'out').
            parameter_name: The name of the parameter that the error occurred at.
            loop: A list of I/O ports where the error occurred (e.g.
                AlgebraicLoopError). Each entry is indexed as: `loc[0]` is a
                system exposing `name_path`/`ui_id_path`, `loc[1]` is the port
                direction, and `loc[2]` is the port index.
        """
        super().__init__(message)

        if system and system_id:
            warnings.warn(
                "Should not specify both system and system_id when raising exceptions"
            )

        if system:
            self.system_id = system.system_id
            self.name_path = name_path or system.name_path
            self.ui_id_path = ui_id_path or system.ui_id_path
        else:
            self.system_id = system_id
            self.name_path = name_path
            self.ui_id_path = ui_id_path

        self.message = message
        self.port_index = port_index
        self.port_name = port_name
        self.port_direction = port_direction
        self.parameter_name = parameter_name

        # Extract serializable info from loop
        # NOTE: we could compact it a bit if the JSON becomes too large...
        self.loop: list[LoopItem] = None
        if loop is not None:
            self.loop = [
                LoopItem(
                    name_path=loc[0].name_path,
                    ui_id_path=loc[0].ui_id_path,
                    port_direction=loc[1],
                    port_index=loc[2],
                )
                for loc in loop
            ]

    def __str__(self):
        """Return the message (or the class name) followed by location context."""
        message = self.message or self.default_message
        return f"{message}{self._context_info()}"

    def _context_info(self) -> str:
        """Build the location/cause suffix that `__str__` appends to the message."""
        strbuf = []

        if self.name_path:
            # FIXME: this is known to be too verbose when looking at errors from
            # the UI but makes it better when running pytest or from code.
            # For now, be verbose.
            name_path = ".".join(self.name_path)
            strbuf.append(f" in block {name_path}")
        elif self.system_id:  # Unnamed blocks, likely from code
            strbuf.append(f" in system {self.system_id}")

        if self.port_direction:
            # 'in' / 'out' + "put" -> "input" / "output"
            strbuf.append(
                f" at {self.port_direction}put port {self.port_name or self.port_index}"
            )
        elif self.port_name:
            strbuf.append(f" at port {self.port_name}")
        elif self.port_index is not None:
            strbuf.append(f" at port {self.port_index}")
        if self.parameter_name:
            strbuf.append(f" with parameter {self.parameter_name}")
        if self.__cause__ is not None:
            strbuf.append(f": {self.__cause__}")

        return "".join(strbuf)

    @property
    def block_name(self):
        """Last element of `name_path`.

        Returns None when no name path is set, and "root" when it is empty.
        """
        if self.name_path is None:
            return None
        if len(self.name_path) == 0:
            return "root"
        return self.name_path[-1]

    @property
    def default_message(self):
        """Fallback message used when none was given: the error class name."""
        return type(self).__name__

    def caused_by(self, exc_type: type) -> bool:
        """Check if this error is or was caused by another error type.

        For instance, if a CollimatorError is raised because of a TypeError,
        this method will return True when called with TypeError as exc_type.

        The explicit `__cause__` chain (set by `raise ... from ...`) is
        followed starting from this error. Each exception is visited at most
        once, so a cyclic chain terminates instead of recursing forever.

        Args:
            exc_type: The type of exception to check for (e.g. TypeError). A
                tuple of types is accepted, as with `isinstance`.

        Returns:
            bool: True if the error is or was caused by the given exception type.
        """
        if not exc_type:
            return False

        visited = set()
        current = self
        while current is not None and id(current) not in visited:
            if isinstance(current, exc_type):
                return True
            visited.add(id(current))
            # Walk the chain of the exception being inspected (not `self`).
            current = getattr(current, "__cause__", None)
        return False

__init__(message=None, *, system=None, system_id=None, name_path=None, ui_id_path=None, port_index=None, port_name=None, port_direction=None, parameter_name=None, loop=None)

Create a new CollimatorError.

Only message is a positional argument, all others are keyword arguments.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `message` | | A custom error message; defaults to the error class name. | `None` |
| `system` | `SystemBase` | The system that the error occurred in, if available. | `None` |
| `system_id` | `Hashable` | The id of the system that the error occurred in; use if `system` can't be passed. | `None` |
| `name_path` | `list[str]` | The name path of the block that the error occurred in; use if `system` can't be passed. | `None` |
| `ui_id_path` | `list[str]` | The ui_id (UUID) path of the block that the error occurred in; use if `system` can't be passed. | `None` |
| `port_index` | `int` | The index of the port where the error occurred. | `None` |
| `port_name` | `str` | The name of the port where the error occurred. | `None` |
| `port_direction` | `str` | The direction of the port where the error occurred (`'in'` or `'out'`). | `None` |
| `parameter_name` | `str` | The name of the parameter associated with the error. | `None` |
| `loop` | `list[DirectedPortLocator]` | A list of I/O ports where the error occurred (e.g., for an `AlgebraicLoopError`). | `None` |
Source code in collimator/framework/error.py (lines 27–91)
def __init__(
    self,
    message=None,
    *,
    system: "SystemBase" = None,  # noqa
    system_id: Hashable = None,
    name_path: list[str] = None,
    ui_id_path: list[str] = None,
    port_index: int = None,
    port_name: str = None,
    port_direction: str = None,  # 'in' or 'out'
    parameter_name: str = None,
    loop: list["DirectedPortLocator"] = None,
):
    """Initialize the error with a message and optional location details.

    `message` is the only positional argument; everything else must be passed
    by keyword.

    Args:
        message: Custom error text; when omitted the class name is shown instead.
        system: System in which the error happened, when one is available.
        system_id: Id of that system; for use when `system` cannot be passed.
        name_path: Name path of the offending block; for use when `system`
            cannot be passed.
        ui_id_path: ui_id (uuid) path of the offending block; for use when
            `system` cannot be passed.
        port_index: Index of the port involved in the error.
        port_name: Name of the port involved in the error.
        port_direction: Direction of the port involved ('in' or 'out').
        parameter_name: Name of the parameter involved in the error.
        loop: I/O ports involved in the error (e.g. AlgebraicLoopError).
            Entries are indexed: [0] system, [1] direction, [2] port index.
    """
    super().__init__(message)

    if system and system_id:
        warnings.warn(
            "Should not specify both system and system_id when raising exceptions"
        )

    if not system:
        self.system_id = system_id
        self.name_path = name_path
        self.ui_id_path = ui_id_path
    else:
        # Location comes from the system; explicit, non-empty paths win.
        self.system_id = system.system_id
        self.name_path = name_path or system.name_path
        self.ui_id_path = ui_id_path or system.ui_id_path

    self.message = message
    self.port_index = port_index
    self.port_name = port_name
    self.port_direction = port_direction
    self.parameter_name = parameter_name

    # Keep only serializable data from each loop locator.
    self.loop: list[LoopItem] = None
    if loop is None:
        return
    collected = []
    for locator in loop:
        owner = locator[0]
        collected.append(
            LoopItem(
                name_path=owner.name_path,
                ui_id_path=owner.ui_id_path,
                port_direction=locator[1],
                port_index=locator[2],
            )
        )
    self.loop = collected

caused_by(exc_type)

Check if this error is or was caused by another error type.

For instance, if a CollimatorError is raised because of a TypeError, this method will return True when called with TypeError as exc_type.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `exc_type` | `type` | The type of exception to check for (e.g., `TypeError`). | required |

Returns:

| Type | Description |
|------|-------------|
| `bool` | True if the error is or was caused by the given exception type. |

Source code in collimator/framework/error.py (lines 136–158)
def caused_by(self, exc_type: type) -> bool:
    """Check if this error is or was caused by another error type.

    For instance, if a CollimatorError is raised because of a TypeError,
    this method will return True when called with TypeError as exc_type.

    The explicit `__cause__` chain (set by `raise ... from ...`) is followed
    starting from `self`. Each exception is visited at most once, so a cyclic
    chain terminates instead of recursing forever.

    Args:
        exc_type: The type of exception to check for (e.g. TypeError). A tuple
            of types is accepted, as with `isinstance`.

    Returns:
        bool: True if the error is or was caused by the given exception type.
    """
    if not exc_type:
        return False

    visited = set()
    current = self
    while current is not None and id(current) not in visited:
        if isinstance(current, exc_type):
            return True
        visited.add(id(current))
        # Walk the chain of the exception being inspected (not `self`).
        current = getattr(current, "__cause__", None)
    return False

ContextBase dataclass

Context object containing state, parameters, etc for a system.

NOTE: Type hints in ContextBase indicate the union between what would be returned by a LeafContext and a DiagramContext. See type hints of the subclasses for the specific argument and return types.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `owning_system` | `SystemBase` | The owning system of the context. |
| `time` | `Scalar` | The time associated with the context. Will be `None` unless the context is the root context. |
| `is_initialized` | `bool` | Flag indicating if the context is initialized. This should only be set by the `ContextFactory` during creation. |

Source code in collimator/framework/context.py (lines 74–206)
@dataclasses.dataclass(frozen=True)
class ContextBase(metaclass=abc.ABCMeta):
    """Context object containing state, parameters, etc for a system.

    NOTE: Type hints in ContextBase indicate the union between what would be returned
    by a LeafContext and a DiagramContext. See type hints of the subclasses for
    the specific argument and return types.

    Contexts are frozen dataclasses: the `with_*` methods return modified
    copies instead of mutating the context in place.

    Attributes:
        owning_system (SystemBase):
            The owning system of the context.
        time (Scalar):
            The time associated with the context. Will be None unless the context
            is the root context.
        is_initialized (bool):
            Flag indicating if the context is initialized. This should only be set
            by the ContextFactory during creation.
        parameters (Mapping[str, Array]):
            Parameter values held by the context, keyed by parameter name.
            Defaults to None.
    """

    owning_system: SystemBase
    time: Scalar = None
    is_initialized: bool = False
    parameters: Mapping[str, Array] = None

    @abc.abstractmethod
    def __getitem__(self, key: Hashable) -> LeafContext:
        """Get the subcontext associated with the given system ID.

        For leaf contexts, this will return `self`, but the method is provided
        so that there is a consistent interface for working with either an
        individual LeafSystem or tree-structured Diagram.

        For nested diagrams, intermediate diagrams do not have associated contexts,
        so indexing will fail.
        """
        pass

    @abc.abstractmethod
    def with_subcontext(self, key: Hashable, ctx: LeafContext) -> ContextBase:
        """Create a copy of this context, replacing the specified subcontext."""
        pass

    def with_time(self, value: Scalar) -> ContextBase:
        """Create a copy of this context, replacing time with the given value.

        This should only be called on the root context, since it is expected that all
        subcontexts will have a time value of None to avoid any conflicts.

        Side effect: calls `invalidate_output_caches()` on the owning system
        before the copy is made.
        """

        # This looks really odd here, but this seems to be the most correct
        # and efficient place to call the cache invalidation.
        # FIXME: figure out where this actually belongs.
        self.owning_system.invalidate_output_caches()

        return dataclasses.replace(self, time=value)

    # NOTE: `abc.abstractproperty` is deprecated since Python 3.3; stacking
    # `@property` on top of `@abc.abstractmethod` is the supported equivalent.

    @property
    @abc.abstractmethod
    def state(self) -> State:
        """The full state held by this context."""

    @abc.abstractmethod
    def with_state(self, state: State) -> ContextBase:
        """Create a copy of this context, replacing the entire state."""
        pass

    @abc.abstractmethod
    def with_new_state(self) -> ContextBase:
        """Create a copy of this context, replacing the state with a new state."""
        pass

    @property
    @abc.abstractmethod
    def continuous_state(self) -> StateComponent:
        """The continuous component of the state."""

    @abc.abstractmethod
    def with_continuous_state(self, value: StateComponent) -> ContextBase:
        """Create a copy of this context, replacing the continuous state."""
        pass

    @property
    @abc.abstractmethod
    def num_continuous_states(self) -> int:
        """Number of continuous states."""

    @property
    @abc.abstractmethod
    def has_continuous_state(self) -> bool:
        """Whether the context holds any continuous state."""

    @property
    @abc.abstractmethod
    def discrete_state(self) -> StateComponent:
        """The discrete component of the state."""

    @abc.abstractmethod
    def with_discrete_state(self, value: StateComponent) -> ContextBase:
        """Create a copy of this context, replacing the discrete state."""
        pass

    @property
    @abc.abstractmethod
    def num_discrete_states(self) -> int:
        """Number of discrete states."""

    @property
    @abc.abstractmethod
    def has_discrete_state(self) -> bool:
        """Whether the context holds any discrete state."""

    @property
    @abc.abstractmethod
    def mode(self) -> Mode:
        """The mode component of the state."""

    @property
    @abc.abstractmethod
    def has_mode(self) -> bool:
        """Whether the context holds a mode."""

    @abc.abstractmethod
    def with_mode(self, value: Mode) -> ContextBase:
        """Create a copy of this context, replacing the mode."""
        pass

    def mark_initialized(self) -> ContextBase:
        """Return a copy of this context with `is_initialized` set to True."""
        return dataclasses.replace(self, is_initialized=True)

    @abc.abstractmethod
    def with_updated_parameters(self) -> ContextBase:
        """Create a copy of this context, updating all parameters to their current values."""
        pass

    def with_parameter(self, name: str, value: ArrayLike) -> ContextBase:
        """Create a copy of this context, replacing the specified parameter."""
        return self.with_parameters({name: value})

    @abc.abstractmethod
    def with_parameters(self, new_parameters: Mapping[str, ArrayLike]) -> ContextBase:
        """Create a copy of this context, replacing only the specified parameters."""
        pass

__getitem__(key) abstractmethod

Get the subcontext associated with the given system ID.

For leaf contexts, this will return self, but the method is provided so that there is a consistent interface for working with either an individual LeafSystem or tree-structured Diagram.

For nested diagrams, intermediate diagrams do not have associated contexts, so indexing will fail.

Source code in collimator/framework/context.py (lines 98–109)
@abc.abstractmethod
def __getitem__(self, key: Hashable) -> LeafContext:
    """Return the subcontext belonging to the system identified by `key`.

    A leaf context simply returns itself. The method exists so that single
    LeafSystems and tree-structured Diagrams share one indexing interface.
    Intermediate diagrams in a nested hierarchy own no context of their own,
    so indexing them fails.
    """

with_continuous_state(value) abstractmethod

Create a copy of this context, replacing the continuous state.

Source code in collimator/framework/context.py (lines 148–151)
@abc.abstractmethod
def with_continuous_state(self, value: StateComponent) -> ContextBase:
    """Return a copy of this context whose continuous state is `value`."""

with_discrete_state(value) abstractmethod

Create a copy of this context, replacing the discrete state.

Source code in collimator/framework/context.py (lines 165–168)
@abc.abstractmethod
def with_discrete_state(self, value: StateComponent) -> ContextBase:
    """Return a copy of this context whose discrete state is `value`."""

with_mode(value) abstractmethod

Create a copy of this context, replacing the mode.

Source code in collimator/framework/context.py (lines 186–189)
@abc.abstractmethod
def with_mode(self, value: Mode) -> ContextBase:
    """Return a copy of this context whose mode is `value`."""

with_new_state() abstractmethod

Create a copy of this context, replacing the state with a new state.

Source code in collimator/framework/context.py (lines 139–142)
@abc.abstractmethod
def with_new_state(self) -> ContextBase:
    """Return a copy of this context whose state is replaced by a new state."""

with_parameter(name, value)

Create a copy of this context, replacing the specified parameter.

Source code in collimator/framework/context.py (lines 199–201)
def with_parameter(self, name: str, value: ArrayLike) -> ContextBase:
    """Return a copy of this context with the single parameter `name` set to `value`.

    Thin convenience wrapper that delegates to `with_parameters`.
    """
    single_update = {name: value}
    return self.with_parameters(single_update)

with_parameters(new_parameters) abstractmethod

Create a copy of this context, replacing only the specified parameters.

Source code in collimator/framework/context.py (lines 203–206)
@abc.abstractmethod
def with_parameters(self, new_parameters: Mapping[str, ArrayLike]) -> ContextBase:
    """Return a copy of this context in which only the given parameters are replaced."""

with_state(state) abstractmethod

Create a copy of this context, replacing the entire state.

Source code in collimator/framework/context.py (lines 134–137)
@abc.abstractmethod
def with_state(self, state: State) -> ContextBase:
    """Return a copy of this context whose whole state is `state`."""

with_subcontext(key, ctx) abstractmethod

Create a copy of this context, replacing the specified subcontext.

Source code in collimator/framework/context.py (lines 111–114)
@abc.abstractmethod
def with_subcontext(self, key: Hashable, ctx: LeafContext) -> ContextBase:
    """Return a copy of this context with the subcontext at `key` swapped for `ctx`."""

with_time(value)

Create a copy of this context, replacing time with the given value.

This should only be called on the root context, since it is expected that all subcontexts will have a time value of None to avoid any conflicts.

Source code in collimator/framework/context.py (lines 116–128)
def with_time(self, value: Scalar) -> ContextBase:
    """Return a copy of this context whose time is `value`.

    Only the root context should be given a time. Subcontexts are expected to
    keep a time of None so that conflicting times cannot arise.

    Side effect: the owning system's output caches are invalidated before the
    copy is made.
    """
    # Odd placement, but currently the most correct and efficient spot for
    # cache invalidation.
    # FIXME: figure out where this actually belongs.
    owner = self.owning_system
    owner.invalidate_output_caches()

    updated = dataclasses.replace(self, time=value)
    return updated

with_updated_parameters() abstractmethod

Create a copy of this context, updating all parameters to their current values.

Source code in collimator/framework/context.py (lines 194–197)
@abc.abstractmethod
def with_updated_parameters(self) -> ContextBase:
    """Return a copy of this context with every parameter refreshed to its current value."""

DependencyTicket

Class-level registry of the predefined dependency tickets, with a counter (`next_available_ticket`) for allocating new unique tickets.

Source code in collimator/framework/dependency_graph.py (lines 80–100)
class DependencyTicket:
    """Predefined dependency tickets plus a counter for allocating new ones.

    All members live at class level. The integer constants identify the
    built-in dependency sources, and `next_available_ticket` hands out fresh
    ticket numbers from a class-level counter.
    """

    nothing = 0  # Not dependent on anything.
    time = 1  # Time.
    xc = 2  # All continuous state variables.
    xd = 3  # All discrete state variables.
    mode = 4  # All modes.
    x = 5  # All state variables x = {xc, xd, mode}.
    p = 6  # All parameters.
    all_sources_except_input_ports = 7  # Everything except input ports.
    u = 8  # All input ports u.
    all_sources = 9  # All of the above.
    xcdot = 10  # Continuous state time derivative.

    # Counter advanced by next_available_ticket(). It is incremented *before*
    # each value is returned, so the first ticket handed out is 12.
    _next_available = 11

    @classmethod
    def next_available_ticket(cls):
        """Advance the class-level counter and return its new value."""
        ticket = cls._next_available + 1
        cls._next_available = ticket
        return ticket

Diagram dataclass

Bases: SystemBase

Composite block-diagram representation of a dynamical system.

A Diagram is a collection of Systems connected together to form a larger hybrid dynamical system. Diagrams can be nested to any depth, creating a tree-structured block diagram.

NOTE: The Diagram class is not intended to be constructed directly. Instead, use the DiagramBuilder to construct a Diagram, which will pass the appropriate information to this constructor.

Source code in collimator/framework/diagram.py (lines 71–690)
@dataclasses.dataclass
class Diagram(SystemBase):
    """Composite block-diagram representation of a dynamical system.

    A Diagram is a collection of Systems connected together to form a larger hybrid
    dynamical system. Diagrams can be nested to any depth, creating a tree-structured
    block diagram.

    NOTE: The Diagram class is not intended to be constructed directly.  Instead,
    use the `DiagramBuilder` to construct a Diagram, which will pass the appropriate
    information to this constructor.
    """

    # Redefine here to make pylint happy
    system_id: Hashable = dataclasses.field(default_factory=next_system_id, init=False)
    name: str = None  # Human-readable name for this system (optional)
    ui_id: str = None  # UUID of the block when loaded from JSON (optional)

    # None of these attributes are intended to be modified or accessed directly after
    # construction.  Instead, use the interface defined by `SystemBase`.

    # Direct children of this Diagram
    nodes: List[SystemBase] = dataclasses.field(default_factory=list)

    # Mapping from input ports to output ports of child subsystems
    connection_map: Mapping[InputPortLocator, OutputPortLocator] = dataclasses.field(
        default_factory=dict,
    )

    # Optional identifier for "reference diagrams"
    ref_id: str = None

    # for serialization
    instance_parameters: set[str] = dataclasses.field(default_factory=set)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.name}, {len(self.nodes)} nodes)"

    def _pprint(self, prefix="", fancy=True) -> str:
        if fancy:
            return pprint_fancy(prefix, self)
        return f"{prefix}|-- {self.name}\n"

    def _pprint_helper(self, prefix="", fancy=True) -> str:
        # Render this Diagram, then recurse into every child with a deeper indent.
        s = self._pprint(prefix=prefix, fancy=fancy)
        for node in self.nodes:
            s += node._pprint_helper(prefix=f"{prefix}    ", fancy=fancy)
        return s

    def __hash__(self) -> int:
        return hash(self.system_id)

    def __getitem__(self, name: str) -> SystemBase:
        """Return the direct child whose `name` equals `name`.

        Raises:
            KeyError: if no direct child has that name.
        """
        # Access by name - for convenient user interface only.  Programmatic
        #  access should use the nodes directly, e.g. `self.nodes[idx]`
        lookup = {node.name: node for node in self.nodes}
        return lookup[name]

    def __iter__(self) -> Iterator[SystemBase]:
        return iter(self.nodes)

    def __post_init__(self):
        super().__post_init__()

        # Set parent for all the immediate child systems
        for node in self.nodes:
            node.parent = self

        # The map of subsystem inputs/outputs to inputs/outputs of this Diagram.
        self._input_port_map: Mapping[InputPortLocator, int] = {}
        self._output_port_map: Mapping[OutputPortLocator, int] = {}

        # Also need the inverse output map, for determining feedthrough paths.
        self._inv_output_port_map: Mapping[int, OutputPortLocator] = {}

        # Leaves of the system tree (not necessarily the same as the direct
        # children of this Diagram, which may themselves be Diagrams)
        self.leaf_systems: List[LeafSystem] = []
        for sys in self.nodes:
            if isinstance(sys, Diagram):
                # FIXME: In case of 'param estimation' optimization run, we end
                # up here and sys.leaf_systems is now None. Using or [] fixes
                # the crash but something is a bit fishy.
                self.leaf_systems.extend(sys.leaf_systems or [])
                # No longer need the child leaf systems, since methods using this
                # should only be called from the top level.
                sys.leaf_systems = None
            else:
                self.leaf_systems.append(sys)

    def post_simulation_finalize(self) -> None:
        """Perform any post-simulation cleanup for this system."""
        for system in self.nodes:
            system.post_simulation_finalize()

    # Inherits docstrings from SystemBase
    @property
    def has_feedthrough_side_effects(self) -> bool:
        # See explanation in `SystemBase.has_feedthrough_side_effects`.
        return any(sys.has_feedthrough_side_effects for sys in self.nodes)

    # Inherits docstrings from SystemBase
    @property
    def has_ode_side_effects(self) -> bool:
        # Return true if either of the following are true:
        # 1. At least one subsystem has ODE side effects
        # 2. At least one subsystem has feedthrough side effects and the output
        #    ports of the diagram are used as ODE inputs.

        if self.dependency_graph is None:
            raise ValueError("Must create dependency graph first.")

        # If no subsystems have feedthrough side effects, we're done.
        if not self.has_feedthrough_side_effects:
            return False

        # If any subsystem is already known to have this property, we're done.
        if any(sys.has_ode_side_effects for sys in self.nodes):
            return True

        # If we get here, we need to actually test the dependency graph.
        for sys in self.nodes:
            if sys.has_feedthrough_side_effects:
                for port in sys.output_ports:
                    tracker = port.tracker
                    if tracker.is_prerequisite_of([DependencyTicket.xcdot]):
                        return True
        return False

    @property
    def has_continuous_state(self) -> bool:
        return any(sys.has_continuous_state for sys in self.nodes)

    @property
    def has_discrete_state(self) -> bool:
        return any(sys.has_discrete_state for sys in self.nodes)

    @property
    def has_zero_crossing_events(self) -> bool:
        return any(sys.has_zero_crossing_events for sys in self.nodes)

    @property
    def num_systems(self) -> int:
        # Number of subsystems _at this level_
        return len(self.nodes)

    def check_types(
        self,
        context: DiagramContext,
        error_collector: ErrorCollector = None,
    ) -> None:
        """Perform any system-specific static analysis."""
        for system in self.nodes:
            system.check_types(
                context,
                error_collector=error_collector,
            )

    #
    # Simulation interface
    #

    # Inherits docstrings from SystemBase
    def eval_time_derivatives(self, root_context: DiagramContext) -> List[Array]:
        leaf_systems = [
            subctx.owning_system for subctx in root_context.continuous_subcontexts
        ]
        return [sys.eval_time_derivatives(root_context) for sys in leaf_systems]

    @property
    def mass_matrix(self) -> List[Array]:
        return [sys.mass_matrix for sys in self.leaf_systems]

    @property
    def has_mass_matrix(self) -> bool:
        return any(sys.has_mass_matrix for sys in self.leaf_systems)

    #
    # Event handling
    #
    @property
    def state_update_events(self) -> FlatEventCollection:
        assert self.parent is None, (
            "Can only get periodic events from top-level Diagram, not "
            f"{self.system_id} with parent {self.parent.system_id}"
        )
        events = sum(
            [sys.state_update_events for sys in self.leaf_systems],
            start=FlatEventCollection(),
        )
        return events

    @property
    def zero_crossing_events(self) -> DiagramEventCollection:
        assert self.parent is None, (
            "Can only get zero-crossing events from top-level Diagram, not "
            f"{self.system_id} with parent {self.parent.system_id}"
        )
        return DiagramEventCollection(
            OrderedDict(
                {sys.system_id: sys.zero_crossing_events for sys in self.leaf_systems}
            )
        )

    # Inherits docstrings from SystemBase
    def determine_active_guards(
        self, root_context: DiagramContext
    ) -> DiagramEventCollection:
        assert self.parent is None, (
            "Can only get zero-crossing events from top-level Diagram, not "
            f"{self.system_id} with parent {self.parent.system_id}"
        )
        return DiagramEventCollection(
            OrderedDict(
                {
                    sys.system_id: sys.determine_active_guards(root_context)
                    for sys in self.leaf_systems
                }
            )
        )

    # Inherits docstrings from SystemBase
    def eval_zero_crossing_updates(
        self,
        root_context: DiagramContext,
        events: DiagramEventCollection,
    ) -> dict[Hashable, LeafState]:
        substates = OrderedDict()
        for system_id, subctx in root_context.subcontexts.items():
            sys = subctx.owning_system
            substates[system_id] = sys.eval_zero_crossing_updates(root_context, events)

        return substates

    #
    # I/O ports
    #
    @property
    def _flat_callbacks(self) -> List[SystemCallback]:
        """Return a flat list of all SystemCallbacks in the Diagram."""
        return [cb for sys in self.nodes for cb in sys._flat_callbacks]

    @property
    def exported_input_ports(self):
        return self._input_port_map

    @property
    def exported_output_ports(self):
        return self._output_port_map

    def eval_subsystem_input_port(
        self, context: DiagramContext, port_locator: InputPortLocator
    ) -> Array:
        """Evaluate the input port for a child of this system given the root context.

        Args:
            context (ContextBase): root context for this system
            port_locator (InputPortLocator): tuple of (system, port_index) identifying
                the input port to evaluate

        Returns:
            Array: Value returned from evaluating the subsystem port.

        Raises:
            InputNotConnectedError: if the input port is not connected
        """

        is_exported = port_locator in self._input_port_map
        if is_exported:
            # The upstream source is an input to this whole Diagram; evaluate that
            # input port and use the result as the value for this one.
            port_index = self._input_port_map[port_locator]  # Diagram-level index
            return self.input_ports[port_index].eval(context)  # Return upstream value

        is_connected = port_locator in self.connection_map
        if is_connected:
            # The upstream source is an output port of one of this Diagram's child
            # subsystems; evaluate the upstream output.
            upstream_locator = self.connection_map[port_locator]

            # This will return the value of the upstream port
            return self.eval_subsystem_output_port(context, upstream_locator)

        block, port_index = port_locator
        raise InputNotConnectedError(
            system=block,
            port_index=port_index,
            port_direction="in",
            message=f"Input port {block.name}[{port_index}] is not connected",
        )

    def eval_subsystem_output_port(
        self, context: DiagramContext, port_locator: OutputPortLocator
    ) -> Array:
        """Evaluate the output port for a child of this system given the root context.

        Args:
            context (ContextBase): root context for this system
            port_locator (OutputPortLocator): tuple of (system, port_index) identifying
                the output port to evaluate

        Returns:
            Array: Value returned from evaluating the subsystem port.

        Raises:
            UpstreamEvalError: if the context is not yet initialized, the port has
                no default value, and evaluating it returns None.
        """
        system, port_index = port_locator
        port = system.output_ports[port_index]

        # During simulation all we should need to do is evaluate the port.
        if context.is_initialized:
            return port.eval(context)

        # If the context is not initialized, we have to determine the signal data type.
        # In the easy case, the port has a default value, so we can just use that.
        if port.default_value is not None:
            logger.debug(
                "Using default output value of %s for %s",
                port.default_value,
                port_locator[0].name,
            )
            return port.default_value

        logger.debug(
            "Evaluating output port %s for system %s. Context initialized: %s",
            port_locator,
            port_locator[0].name,
            context.is_initialized,
        )

        # If there is no default value, try to evaluate the port to pull a "template"
        # value with an appropriate data type from upstream.  This will return None if
        # the port is not yet connected (e.g. if its upstream is an exported input of
        # a Diagram), so we can defer evaluation.

        # Try again to evaluate the port
        val = port.eval(context)
        logger.debug(
            "  ---> %s returns %s", (port_locator[0].name, port_locator[1]), val
        )

        # If there is still no value, the port is not connected to anything.
        # Post-initialization this would be an error, but pre-initialization
        # it may be the case that the upstream is an exported input port of
        # the Diagram, so we can defer evaluation. Expect the block that is
        # doing this to handle the UpstreamEvalError appropriately.
        if val is None:
            system_name = system.name_path_str
            logger.debug(
                "Upstream evaluation of %s.out[%s] returned None. Deferring evaluation.",
                system_name,
                port_index,
            )
            raise UpstreamEvalError(port_locator=(system, "out", port_index))
        return val

    def invalidate_output_caches(self):
        # Nothing to do if this Diagram's own output cache is inactive; otherwise
        # invalidate it and recurse into every child.
        if not self._basic_output_cache.is_active():
            return

        self._basic_output_cache.invalidate()
        for system in self.nodes:
            system.invalidate_output_caches()

    #
    # System-level declarations (should be done via DiagramBuilder)
    #
    def export_input(self, locator: InputPortLocator, port_name: str) -> int:
        """Export a subsystem input port as a diagram-level input.

        This should typically only be called during construction by DiagramBuilder.
        The standard workflow will be to call export_input on the _builder_ object,
        which will automatically call this method on the Diagram once created.

        Args:
            locator (InputPortLocator): tuple of (system, port_index) identifying
                the input port to export
            port_name (str): name of the new exported input port

        Returns:
            int: index of the exported input port in the diagram input_ports list
        """
        diagram_port_index = self.declare_input_port(name=port_name)
        self._input_port_map[locator] = diagram_port_index

        return diagram_port_index

    def export_output(self, locator: OutputPortLocator, port_name: str) -> int:
        """Export a subsystem output port as a diagram-level output.

        This should typically only be called during construction by DiagramBuilder.
        The standard workflow will be to call export_output on the _builder_ object,
        which will automatically call this method on the Diagram once created.

        Args:
            locator (OutputPortLocator): tuple of (system, port_index) identifying
                the output port to export
            port_name (str): name of the new exported output port

        Returns:
            int: index of the exported output port in the diagram output_ports list
        """
        subsystem, subsystem_port_index = locator
        source_port = subsystem.output_ports[subsystem_port_index]
        # The diagram-level port simply forwards the subsystem port's value and
        # declares that port's ticket as its prerequisite.
        diagram_port_index = self.declare_output_port(
            source_port.eval,
            name=port_name,
            prerequisites_of_calc=[source_port.ticket],
        )
        self._output_port_map[locator] = diagram_port_index
        self._inv_output_port_map[diagram_port_index] = locator

        return diagram_port_index

    #
    # Initialization
    #
    @property
    def context_factory(self) -> DiagramContextFactory:
        return DiagramContextFactory(self)

    @property
    def dependency_graph_factory(self) -> DiagramDependencyGraphFactory:
        return DiagramDependencyGraphFactory(self)

    def initialize_static_data(self, context: DiagramContext) -> DiagramContext:
        """Perform any system-specific static analysis."""
        # Each child may return an updated context, which is threaded to the next.
        for system in self.nodes:
            context = system.initialize_static_data(context)
        return context

    def _has_feedthrough(self, input_port_index: int, output_port_index: int) -> bool:
        """Check if there is a direct-feedthrough path from the input port to the output port.

        Internal function used by `get_feedthrough`.  Should not typically need to
        be called directly.

        The search walks backwards from the subsystem output exported as
        `output_port_index`, following each subsystem's feedthrough pairs and the
        internal `connection_map`, until it reaches a subsystem input exported as
        `input_port_index` (feedthrough) or runs out of upstream ports.
        """
        # Subsystem input ports exported as diagram input `input_port_index`.
        input_ids = {
            locator
            for locator, index in self._input_port_map.items()
            if index == input_port_index
        }

        # Maintain a set of the output port identifiers that are known to have a
        # direct-feedthrough path to the output_port.  `visited` records every
        # locator ever queued so that a feedback cycle inside the diagram cannot
        # make this search loop forever.
        start = self._inv_output_port_map[output_port_index]
        active_set: Set[OutputPortLocator] = {start}
        visited: Set[OutputPortLocator] = {start}

        while active_set:
            sys, sys_output = active_set.pop()
            for u, v in sys.get_feedthrough():
                if v != sys_output:
                    continue
                curr_input_id = (sys, u)
                if curr_input_id in input_ids:
                    # Found a direct-feedthrough path to the input_port
                    return True
                if curr_input_id in self.connection_map:
                    # Intermediate input port has a direct-feedthrough path to
                    # output_port. Add the upstream output port (if not already
                    # explored) to the active set.
                    upstream = self.connection_map[curr_input_id]
                    if upstream not in visited:
                        visited.add(upstream)
                        active_set.add(upstream)

        # If there are no intermediate output ports with a direct-feedthrough path
        # to the output port, there is no direct feedthrough from the input port
        return False

    # Inherits docstring from SystemBase.get_feedthrough
    def get_feedthrough(self) -> List[Tuple[int, int]]:
        # Result is cached on `self.feedthrough_pairs` after the first call.
        if self.feedthrough_pairs is not None:
            return self.feedthrough_pairs

        pairs = []
        for u in range(self.num_input_ports):
            for v in range(self.num_output_ports):
                if self._has_feedthrough(u, v):
                    pairs.append((u, v))

        self.feedthrough_pairs = pairs
        return self.feedthrough_pairs

    def find_system_with_path(self, path: str | list[str]) -> SystemBase | None:
        """Look up a descendant system by a dot-separated name path.

        Args:
            path: Either a string such as ``"sub.block"`` or an already-split list
                of names.  Each element is matched against the `name` of a system
                one level further down the tree.

        Returns:
            The matching system, or None if nothing matches (including an empty
            list).  When called on a non-root Diagram, a single-element path equal
            to this Diagram's own name returns the Diagram itself.
        """
        if isinstance(path, str):
            path = path.split(".")

        if not path:
            return None

        def _find_in_children():
            for child in self.nodes:
                if child.name == path[0]:
                    if len(path) == 1:
                        return child
                    if isinstance(child, Diagram):
                        return child.find_system_with_path(path[1:])
                    return None
            return None

        if self.parent is None:
            return _find_in_children()

        if self.name == path[0] and len(path) == 1:
            return self

        return _find_in_children()

    def declare_dynamic_parameter(
        self, name: str, parameter: Array | Parameter
    ) -> None:
        """Declare a parameter for this system.

        Parameters:
            name (str): The name of the parameter.
            parameter (Parameter): The parameter object.
        """
        # Force the parameter to have the correct name, all diagram parameters
        # should be named.
        parameter.name = name
        # do not wrap in a new Parameter object like we do for LeafSystem because
        # the parameter could be used in multiple places.
        self._dynamic_parameters[name] = parameter

    # TODO: move this to context? it can't be called without the context first
    # being created (which creates the dependency graph)
    def check_no_algebraic_loops(self):
        """Check for algebraic loops in the diagram.

        This is a more or less direct port of the Drake method
        DiagramBuilder::ThrowIfAlgebraicLoopExists. Some comments are verbatim
        explanations of the algorithm implemented there.

        Raises:
            AlgebraicLoopError: if a cycle of internal connections and
                direct-feedthrough paths exists among this Diagram's children.
        """

        # The nodes in the graph are the input/output ports defined as part of
        # the diagram's internal connections.  Ports that are not internally
        # connected cannot participate in a cycle at this level, so we don't include them
        # in the nodes set.
        nodes: Set[PortBase] = set()

        # For each `value` in `edges[key]`, the `key` directly influences `value`.
        edges: Mapping[PortBase, Set[PortBase]] = {}

        # Add the diagram's internal connections to the digraph nodes and edges
        for input_port_locator, output_port_locator in self.connection_map.items():
            # Directly using the port locator does not result in a unique identifier
            # since (sys, 0) represents both input port 0 and output port 0.  Instead,
            # use the port directly as a key, since it is a unique hashable object.
            input_system, input_index = input_port_locator
            input_port = input_system.input_ports[input_index]
            logger.debug("Adding locator %s to nodes", input_port)
            nodes.add(input_port)

            output_system, output_index = output_port_locator
            output_port = output_system.output_ports[output_index]
            logger.debug("Adding locator %s to nodes", output_port)
            nodes.add(output_port)

            logger.debug("Adding edge[%s] = %s", output_port, input_port)
            edges.setdefault(output_port, set()).add(input_port)

        # Add more edges based on each System's direct feedthrough.
        # input -> output port iff there is direct feedthrough from input -> output
        # If a feedthrough edge refers to a port not in `nodes`, omit it because ports
        # that are not connected inside the diagram cannot participate in a cycle at
        # the level of this diagram (higher-level diagrams will test for cycles at
        # their level).
        for system in self.nodes:
            logger.debug("Checking feedthrough for system %s", system.name)
            for input_index, output_index in system.get_feedthrough():
                input_port = system.input_ports[input_index]
                output_port = system.output_ports[output_index]
                logger.debug("Feedthrough from %s to %s", input_port, output_port)
                if input_port in nodes and output_port in nodes:
                    edges.setdefault(input_port, set()).add(output_port)

        def _graph_has_cycle(
            node: PortBase,
            visited: Set[DirectedPortLocator],
            stack: List[DirectedPortLocator],
        ) -> bool:
            # Helper to do the algebraic loop test by depth-first search on the graph
            # to find cycles. Modifies `visited` and `stack` in place.  When a cycle
            # is found, `stack` is intentionally left holding the current DFS path
            # so the caller can report it.

            logger.debug("Checking node %s", node)

            locator = node.directed_locator
            assert locator not in visited
            visited.add(locator)

            if node in edges:
                # `stack` holds directed locators, so test the node's locator
                # (testing the port object itself could never match).
                assert locator not in stack
                stack.append(locator)
                for target in edges[node]:
                    if target.directed_locator not in visited and _graph_has_cycle(
                        target, visited, stack
                    ):
                        logger.debug("Found cycle at %s", target)
                        return True
                    elif target.directed_locator in stack:
                        logger.debug("Found target %s in stack %s", target, stack)
                        return True
                stack.pop()

            # If we get this far there is no cycle
            return False

        # Evaluate the graph for cycles
        visited: Set[DirectedPortLocator] = set()
        stack: List[DirectedPortLocator] = []
        for node in nodes:
            if node.directed_locator in visited:
                continue
            if _graph_has_cycle(node, visited, stack):
                raise AlgebraicLoopError(self.name, stack)

    @property
    def has_dirty_static_parameters(self) -> bool:
        """Check if any static parameters have been modified."""
        return any(n.has_dirty_static_parameters for n in self.nodes)

has_dirty_static_parameters: bool property

Check if any static parameters have been modified.

check_no_algebraic_loops()

Check for algebraic loops in the diagram.

This is a more or less direct port of the Drake method DiagramBuilder::ThrowIfAlgebraicLoopExists. Some comments are verbatim explanations of the algorithm implemented there.

Source code in collimator/framework/diagram.py
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
def check_no_algebraic_loops(self):
    """Check for algebraic loops in the diagram.

    This is a more or less direct port of the Drake method
    DiagramBuilder::ThrowIfAlgebraicLoopExists. Some comments are verbatim
    explanations of the algorithm implemented there.

    Raises:
        AlgebraicLoopError: if a cycle of internal connections and
            direct-feedthrough paths exists among this Diagram's children.
    """

    # The nodes in the graph are the input/output ports defined as part of
    # the diagram's internal connections.  Ports that are not internally
    # connected cannot participate in a cycle at this level, so we don't include them
    # in the nodes set.
    nodes: Set[PortBase] = set()

    # For each `value` in `edges[key]`, the `key` directly influences `value`.
    edges: Mapping[PortBase, Set[PortBase]] = {}

    # Add the diagram's internal connections to the digraph nodes and edges
    for input_port_locator, output_port_locator in self.connection_map.items():
        # Directly using the port locator does not result in a unique identifier
        # since (sys, 0) represents both input port 0 and output port 0.  Instead,
        # use the port directly as a key, since it is a unique hashable object.
        input_system, input_index = input_port_locator
        input_port = input_system.input_ports[input_index]
        logger.debug("Adding locator %s to nodes", input_port)
        nodes.add(input_port)

        output_system, output_index = output_port_locator
        output_port = output_system.output_ports[output_index]
        logger.debug("Adding locator %s to nodes", output_port)
        nodes.add(output_port)

        logger.debug("Adding edge[%s] = %s", output_port, input_port)
        edges.setdefault(output_port, set()).add(input_port)

    # Add more edges based on each System's direct feedthrough.
    # input -> output port iff there is direct feedthrough from input -> output
    # If a feedthrough edge refers to a port not in `nodes`, omit it because ports
    # that are not connected inside the diagram cannot participate in a cycle at
    # the level of this diagram (higher-level diagrams will test for cycles at
    # their level).
    for system in self.nodes:
        logger.debug("Checking feedthrough for system %s", system.name)
        for input_index, output_index in system.get_feedthrough():
            input_port = system.input_ports[input_index]
            output_port = system.output_ports[output_index]
            logger.debug("Feedthrough from %s to %s", input_port, output_port)
            if input_port in nodes and output_port in nodes:
                edges.setdefault(input_port, set()).add(output_port)

    def _graph_has_cycle(
        node: PortBase,
        visited: Set[DirectedPortLocator],
        stack: List[DirectedPortLocator],
    ) -> bool:
        # Helper to do the algebraic loop test by depth-first search on the graph
        # to find cycles. Modifies `visited` and `stack` in place.  When a cycle
        # is found, `stack` is intentionally left holding the current DFS path
        # so the caller can report it.

        logger.debug("Checking node %s", node)

        locator = node.directed_locator
        assert locator not in visited
        visited.add(locator)

        if node in edges:
            # `stack` holds directed locators, so test the node's locator
            # (testing the port object itself could never match).
            assert locator not in stack
            stack.append(locator)
            for target in edges[node]:
                if target.directed_locator not in visited and _graph_has_cycle(
                    target, visited, stack
                ):
                    logger.debug("Found cycle at %s", target)
                    return True
                elif target.directed_locator in stack:
                    logger.debug("Found target %s in stack %s", target, stack)
                    return True
            stack.pop()

        # If we get this far there is no cycle
        return False

    # Evaluate the graph for cycles
    visited: Set[DirectedPortLocator] = set()
    stack: List[DirectedPortLocator] = []
    for node in nodes:
        if node.directed_locator in visited:
            continue
        if _graph_has_cycle(node, visited, stack):
            raise AlgebraicLoopError(self.name, stack)

check_types(context, error_collector=None)

Perform any system-specific static analysis.

Source code in collimator/framework/diagram.py
217
218
219
220
221
222
223
224
225
226
227
def check_types(
    self,
    context: DiagramContext,
    error_collector: ErrorCollector = None,
) -> None:
    """Run static type checks on every direct child of this Diagram.

    Args:
        context: root context, passed through unchanged to each child.
        error_collector: forwarded unchanged to each child's `check_types`.
    """
    children = self.nodes
    for child in children:
        child.check_types(context, error_collector=error_collector)

declare_dynamic_parameter(name, parameter)

Declare a parameter for this system.

Parameters:

Name Type Description Default
name str

The name of the parameter.

required
parameter Parameter

The parameter object.

required
Source code in collimator/framework/diagram.py
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
def declare_dynamic_parameter(
    self, name: str, parameter: Array | Parameter
) -> None:
    """Register `parameter` on this Diagram under `name`.

    Parameters:
        name (str): The name of the parameter.
        parameter (Parameter): The parameter object.
    """
    # Diagram-level parameters are always named: overwrite whatever name the
    # object carried with the key it is registered under.
    parameter.name = name
    # Stored as-is rather than re-wrapped (unlike LeafSystem), since the same
    # object may be referenced from several places.
    registry = self._dynamic_parameters
    registry[name] = parameter

eval_subsystem_input_port(context, port_locator)

Evaluate the input port for a child of this system given the root context.

Parameters:

Name Type Description Default
context ContextBase

root context for this system

required
port_locator InputPortLocator

tuple of (system, port_index) identifying the input port to evaluate

required

Returns:

Name Type Description
Array Array

Value returned from evaluating the subsystem port.

Raises:

Type Description
InputNotConnectedError

if the input port is not connected

Source code in collimator/framework/diagram.py
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
def eval_subsystem_input_port(
    self, context: DiagramContext, port_locator: InputPortLocator
) -> Array:
    """Compute the value feeding a child system's input port.

    The value is resolved in priority order:

    1. If the child port is exported as a Diagram-level input, the matching
       Diagram input port is evaluated.
    2. Otherwise, if the child port is wired to another child's output port,
       that upstream output is evaluated.
    3. Otherwise the port has no source and an error is raised.

    Args:
        context (DiagramContext): root context for this system
        port_locator (InputPortLocator): tuple of (system, port_index) identifying
            the input port to evaluate

    Returns:
        Array: Value returned from evaluating the subsystem port.

    Raises:
        InputNotConnectedError: if the input port is neither exported nor connected
    """
    # Exported input: delegate to the Diagram-level port with this index.
    # (Compare against None explicitly; index 0 is a valid port.)
    diagram_index = self._input_port_map.get(port_locator)
    if diagram_index is not None:
        return self.input_ports[diagram_index].eval(context)

    # Internal connection: pull the value from the upstream child's output.
    upstream_locator = self.connection_map.get(port_locator)
    if upstream_locator is not None:
        return self.eval_subsystem_output_port(context, upstream_locator)

    # No source at all.
    block, port_index = port_locator
    raise InputNotConnectedError(
        system=block,
        port_index=port_index,
        port_direction="in",
        message=f"Input port {block.name}[{port_index}] is not connected",
    )

eval_subsystem_output_port(context, port_locator)

Evaluate the output port for a child of this system given the root context.

Parameters:

Name Type Description Default
context ContextBase

root context for this system

required
port_locator OutputPortLocator

tuple of (system, port_index) identifying the output port to evaluate

required

Returns:

Name Type Description
Array Array

Value returned from evaluating the subsystem port.

Source code in collimator/framework/diagram.py
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
def eval_subsystem_output_port(
    self, context: DiagramContext, port_locator: OutputPortLocator
) -> Array:
    """Evaluate the output port for a child of this system given the root context.

    Once the context is initialized this simply evaluates the port. Before
    initialization it is used to obtain a "template" value with an appropriate
    data type: the port's default value if it has one, otherwise whatever
    evaluating the port yields.

    Args:
        context (DiagramContext): root context for this system
        port_locator (OutputPortLocator): tuple of (system, port_index) identifying
            the output port to evaluate

    Returns:
        Array: Value returned from evaluating the subsystem port (or the port's
            default value when the context is not yet initialized and a default
            is available).

    Raises:
        UpstreamEvalError: if the context is not initialized and evaluating the
            port returns None, so that evaluation can be deferred.
    """
    system, port_index = port_locator
    port = system.output_ports[port_index]

    # During simulation all we should need to do is evaluate the port.
    if context.is_initialized:
        return port.eval(context)

    # If the context is not initialized, we have to determine the signal data type.
    # In the easy case, the port has a default value, so we can just use that.
    if port.default_value is not None:
        logger.debug(
            "Using default output value of %s for %s",
            port.default_value,
            system.name,
        )
        return port.default_value

    logger.debug(
        "Evaluating output port %s for system %s. Context initialized: %s",
        port_locator,
        system.name,
        context.is_initialized,
    )

    # With no default value, evaluate the port to pull a "template" value with an
    # appropriate data type from upstream. This returns None if the port is not
    # yet resolvable (e.g. if its upstream is an exported input of a Diagram), in
    # which case evaluation is deferred below.
    val = port.eval(context)
    logger.debug("  ---> %s returns %s", (system.name, port_index), val)

    # If there is still no value, the port is not connected to anything.
    # Post-initialization this would be an error, but pre-initialization
    # it may be the case that the upstream is an exported input port of
    # the Diagram, so we can defer evaluation. Expect the block that is
    # doing this to handle the UpstreamEvalError appropriately.
    if val is None:
        system_name = system.name_path_str
        logger.debug(
            "Upstream evaluation of %s.out[%s] returned None. Deferring evaluation.",
            system_name,
            port_index,
        )
        raise UpstreamEvalError(port_locator=(system, "out", port_index))
    return val

export_input(locator, port_name)

Export a subsystem input port as a diagram-level input.

This should typically only be called during construction by DiagramBuilder. The standard workflow will be to call export_input on the builder object, which will automatically call this method on the Diagram once created.

Parameters:

Name Type Description Default
locator InputPortLocator

tuple of (system, port_index) identifying the input port to export

required
port_name str

name of the new exported input port

required

Returns:

Name Type Description
int int

index of the exported input port in the diagram input_ports list

Source code in collimator/framework/diagram.py
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
def export_input(self, locator: InputPortLocator, port_name: str) -> int:
    """Expose a child system's input port as a new Diagram-level input.

    Normally invoked by ``DiagramBuilder.build`` rather than by users: call
    ``export_input`` on the builder, and the builder forwards the request to
    the Diagram it creates.

    Args:
        locator (InputPortLocator): tuple of (system, port_index) identifying
            the child input port to export
        port_name (str): name given to the new Diagram-level input port

    Returns:
        int: index of the new port in the diagram's input_ports list
    """
    new_index = self.declare_input_port(name=port_name)
    # Record which Diagram input feeds this child port; eval_subsystem_input_port
    # consults this map to route values.
    self._input_port_map[locator] = new_index
    return new_index

export_output(locator, port_name)

Export a subsystem output port as a diagram-level output.

This should typically only be called during construction by DiagramBuilder. The standard workflow will be to call export_output on the builder object, which will automatically call this method on the Diagram once created.

Parameters:

Name Type Description Default
locator OutputPortLocator

tuple of (system, port_index) identifying the output port to export

required
port_name str

name of the new exported output port

required

Returns:

Name Type Description
int int

index of the exported output port in the diagram output_ports list

Source code in collimator/framework/diagram.py
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
def export_output(self, locator: OutputPortLocator, port_name: str) -> int:
    """Expose a child system's output port as a new Diagram-level output.

    Normally invoked by ``DiagramBuilder.build`` rather than by users: call
    ``export_output`` on the builder, and the builder forwards the request to
    the Diagram it creates.

    Args:
        locator (OutputPortLocator): tuple of (system, port_index) identifying
            the child output port to export
        port_name (str): name given to the new Diagram-level output port

    Returns:
        int: index of the new port in the diagram's output_ports list
    """
    child, child_index = locator
    child_port = child.output_ports[child_index]
    # The Diagram-level port forwards the child port: its callback is the
    # child's ``eval`` and the child's ticket is listed as a calc prerequisite.
    new_index = self.declare_output_port(
        child_port.eval,
        name=port_name,
        prerequisites_of_calc=[child_port.ticket],
    )
    # Keep the mapping in both directions (child locator <-> diagram index).
    self._output_port_map[locator] = new_index
    self._inv_output_port_map[new_index] = locator
    return new_index

initialize_static_data(context)

Perform any system-specific static analysis.

Source code in collimator/framework/diagram.py
494
495
496
497
498
def initialize_static_data(self, context: DiagramContext) -> DiagramContext:
    """Let each child initialize its static data, threading the context through.

    Children are visited in ``self.nodes`` order; each receives the context
    returned by the previous child, and the final one is returned.
    """
    updated = context
    for child in self.nodes:
        updated = child.initialize_static_data(updated)
    return updated

post_simulation_finalize()

Perform any post-simulation cleanup for this system.

Source code in collimator/framework/diagram.py
161
162
163
164
def post_simulation_finalize(self) -> None:
    """Invoke ``post_simulation_finalize`` on every child, in ``self.nodes`` order."""
    children = self.nodes
    for child in children:
        child.post_simulation_finalize()

DiagramBuilder

Class for constructing block diagram systems.

The DiagramBuilder class is responsible for building a diagram by adding systems, connecting ports, and exporting inputs and outputs. It keeps track of the registered systems, input and output ports, and the connection map between input and output ports of the child systems.

Source code in collimator/framework/diagram_builder.py
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
class DiagramBuilder:
    """Class for constructing block diagram systems.

    The `DiagramBuilder` class is responsible for building a diagram by adding systems, connecting ports,
    and exporting inputs and outputs. It keeps track of the registered systems, input and output ports,
    and the connection map between input and output ports of the child systems.

    A builder is single-use: after `build` succeeds, every mutating method
    (`add`, `connect`, `export_input`, `export_output`, `build`) raises
    `BuilderError`.
    """

    def __init__(self):
        # Child input ports that are exported as diagram-level inputs
        self._input_port_ids: List[InputPortLocator] = []
        self._input_port_names: List[str] = []
        # Child output ports that are exported as diagram-level outputs
        self._output_port_ids: List[OutputPortLocator] = []
        self._output_port_names: List[str] = []

        # Connection map between input and output ports of the child systems
        self._connection_map: dict[InputPortLocator, OutputPortLocator] = {}

        # List of registered systems
        self._registered_systems: List[SystemBase] = []

        # Name lookup for exported input ports: port name -> diagram-level index
        self._diagram_input_indices: dict[str, int] = {}

        # All input ports of child systems (for use in ensuring proper connectivity)
        self._all_input_ports: List[InputPortLocator] = []

        # Each DiagramBuilder can only be used to build a single diagram.  This is to
        # avoid creating multiple diagrams that reference the same LeafSystem. Doing so
        # may or may not actually lead to problems, since the LeafSystems themselves
        # should act like a collection of pure functions, but best practice is to have
        # each leaf system be fully unique.
        self._already_built = False

    @overload
    def add(self, system: SystemBase) -> SystemBase: ...

    @overload
    def add(self, system: SystemBase, *systems: SystemBase) -> List[SystemBase]: ...

    def add(self, *systems: SystemBase) -> List[SystemBase] | SystemBase:
        """Add one or more systems to the diagram.

        Args:
            *systems SystemBase:
                System(s) to add to the diagram.

        Returns:
            List[SystemBase] | SystemBase:
                The added system(s). Will return a single system if there is only
                a single system in the argument list.

        Raises:
            BuilderError: If the diagram has already been built.
            BuilderError: If the system is already registered.
            BuilderError: If the system name is not unique.
        """
        for system in systems:
            self._check_not_already_built()
            self._check_system_not_registered(system)
            self._check_system_name_is_unique(system)
            self._registered_systems.append(system)

            # Add the system's input ports to the list of all input ports
            # So that we can make sure they're all connected before building.
            self._all_input_ports.extend([port.locator for port in system.input_ports])

            logger.debug("Added system %s to DiagramBuilder", system.name)
            logger.debug(
                "    Registered systems: %s",
                [s.name for s in self._registered_systems],
            )
        build_recorder.add_block(self, systems)

        return systems[0] if len(systems) == 1 else systems

    def connect(self, src: OutputPort, dest: InputPort):
        """Connect an output port to an input port.

        The input port and output port must both belong to systems that have
        already been added to the diagram.  The input port must not already be
        connected to another output port.

        Args:
            src (OutputPort): The output port to connect.
            dest (InputPort): The input port to connect.

        Raises:
            BuilderError: If the diagram has already been built.
            BuilderError: If the source system is not registered.
            BuilderError: If the destination system is not registered.
            BuilderError: If the input port is already connected.
        """
        self._check_not_already_built()
        self._check_system_is_registered(src.system)
        self._check_system_is_registered(dest.system)
        self._check_input_not_connected(dest.locator)

        build_recorder.connect_ports(self, src, dest)

        self._connection_map[dest.locator] = src.locator

        # Lazy %-style arguments: the message is only formatted if DEBUG is enabled.
        logger.debug(
            "Connected port %s of system %s to port %s of system %s",
            src.name,
            src.system.name,
            dest.name,
            dest.system.name,
        )

    def export_input(self, port: InputPort, name: str = None) -> int:
        """Export an input port of a child system as a diagram-level input.

        The input port must belong to a system that has already been added to the
        diagram. The input port must not already be connected to another output port.

        Args:
            port (InputPort): The input port to export.
            name (str, optional):
                The name to assign to the exported input port. If not provided, a
                name of the form "<system name>_<port name>" will be generated.

        Returns:
            int: The index (in the to-be-built diagram) of the exported input port.

        Raises:
            BuilderError: If the diagram has already been built.
            BuilderError: If the system is not registered.
            BuilderError: If the input port is already connected.
            BuilderError: If the input port name (explicit or generated) is not unique.
        """
        self._check_not_already_built()
        self._check_system_is_registered(port.system)
        self._check_input_not_connected(port.locator)

        if name is None:
            # System names are unique, so generated names normally are too at the
            # level of _this_ diagram, but an explicitly named port may already have
            # claimed the same string; the check below covers both cases.
            name = f"{port.system.name}_{port.name}"
        if name in self._diagram_input_indices:
            raise BuilderError(
                f"Input port name {name} is not unique",
                system=port.system,
                port_index=port.index,
                port_direction="in",
            )

        # Index at the diagram (not subsystem) level
        port_index = len(self._input_port_ids)
        self._input_port_ids.append(port.locator)
        self._input_port_names.append(name)

        self._diagram_input_indices[name] = port_index

        build_recorder.export_port(self, port.system, "input", port.index, name)

        return port_index

    def export_output(self, port: OutputPort, name: str = None) -> int:
        """Export an output port of a child system as a diagram-level output.

        The output port must belong to a system that has already been added to the
        diagram.

        Args:
            port (OutputPort): The output port to export.
            name (str, optional):
                The name to assign to the exported output port. If not provided, a
                name of the form "<system name>_<port name>" will be generated.

        Returns:
            int: The index (in the to-be-built diagram) of the exported output port.

        Raises:
            BuilderError: If the diagram has already been built.
            BuilderError: If the system is not registered.
            BuilderError: If the output port name (explicit or generated) is not unique.
        """
        self._check_not_already_built()
        self._check_system_is_registered(port.system)

        if name is None:
            # See export_input: generated names are checked for collisions too.
            name = f"{port.system.name}_{port.name}"
        if name in self._output_port_names:
            raise BuilderError(
                f"Output port name {name} is not unique",
                system=port.system,
                port_index=port.index,
                port_direction="out",
            )

        # Index at the diagram (not subsystem) level
        port_index = len(self._output_port_ids)
        self._output_port_ids.append(port.locator)
        self._output_port_names.append(name)

        build_recorder.export_port(self, port.system, "output", port.index, name)

        return port_index

    def _check_not_already_built(self):
        """Raise BuilderError if `build` has already been called on this builder."""
        if self._already_built:
            raise BuilderError(
                "DiagramBuilder: build has already been called to "
                "create a diagram; this DiagramBuilder may no longer be used."
            )

    def _check_system_name_is_unique(self, system: SystemBase):
        """Raise SystemNameNotUniqueError if a registered system has the same name."""
        if system.name in (s.name for s in self._registered_systems):
            raise SystemNameNotUniqueError(system)

    def _system_is_registered(self, system: SystemBase) -> bool:
        """Return True if a system with the same `system_id` is registered."""
        if system.system_id is None:  # system.__init__ is not done yet
            return False
        return system.system_id in (s.system_id for s in self._registered_systems)

    def _check_system_not_registered(self, system: SystemBase):
        """Raise BuilderError if `system` is already registered."""
        if self._system_is_registered(system):
            raise BuilderError(
                f"System {system.name} is already registered",
                system=system,
            )

    def _check_system_is_registered(self, system: SystemBase):
        """Raise BuilderError if `system` has not been added to this builder."""
        if not self._system_is_registered(system):
            raise BuilderError(
                f"System {system.name} is not registered",
                system=system,
            )

    def _check_input_not_connected(self, input_port_locator: InputPortLocator):
        """Raise BuilderError if the input is already exported or connected."""
        if (
            input_port_locator in self._input_port_ids
            or input_port_locator in self._connection_map
        ):
            system, port_index = input_port_locator
            raise BuilderError(
                f"Input port {port_index} for {system} is already connected",
                system=system,
                port_index=port_index,
                port_direction="in",
            )

    def _check_input_is_connected(self, input_port_locator: InputPortLocator):
        """Raise DisconnectedInputError if the input is neither exported nor connected."""
        if not (
            input_port_locator in self._input_port_ids
            or input_port_locator in self._connection_map
        ):
            raise DisconnectedInputError(input_port_locator)

    def _check_contents_are_complete(self):
        """Make sure every system referenced by the builder state is registered."""
        # Every entry of the registered-systems list must pass the registration check
        for system in self._registered_systems:
            self._check_system_is_registered(system)

        # Check that connection_map only refers to registered systems
        for input_port_locator, output_port_locator in self._connection_map.items():
            self._check_system_is_registered(input_port_locator[0])
            self._check_system_is_registered(output_port_locator[0])

        # Check that input_port_ids and output_port_ids only refer to registered systems
        for port_locator in [*self._input_port_ids, *self._output_port_ids]:
            self._check_system_is_registered(port_locator[0])

    def _check_ports_are_valid(self):
        """Raise BuilderError if any connection uses an out-of-range port index."""
        for dst, src in self._connection_map.items():
            dst_sys, dst_idx = dst
            if (dst_idx < 0) or (dst_idx >= dst_sys.num_input_ports):
                raise BuilderError(
                    f"Input port index {dst_idx} is out of range "
                    f"(0-{dst_sys.num_input_ports-1})",
                    system=dst_sys,
                    port_index=dst_idx,
                    port_direction="in",
                )
            src_sys, src_idx = src
            if (src_idx < 0) or (src_idx >= src_sys.num_output_ports):
                raise BuilderError(
                    f"Output port index {src_idx} is out of range "
                    f"(0-{src_sys.num_output_ports-1})",
                    system=src_sys,
                    port_index=src_idx,
                    port_direction="out",
                )

    def build(
        self,
        name: str = "root",
        ui_id: str = None,
        parameters: dict[str, Parameter] = None,
    ) -> Diagram:
        """Builds a Diagram system with the specified name and UI identifier.

        Args:
            name (str, optional): The name of the diagram. Defaults to "root".
            ui_id (str, optional): The unique identifier for the diagram.
            parameters (dict[str, Parameter], optional):
                A dictionary of dynamic parameters to declare for the diagram.

        Returns:
            Diagram: The newly constructed diagram.

        Raises:
            EmptyDiagramError: If no systems are registered in the diagram.
            BuilderError: If the diagram has already been built.
            AlgebraicLoopError: If an algebraic loop is detected in the diagram.
            DisconnectedInputError: If an input port is not connected.
        """
        self._check_not_already_built()
        self._check_contents_are_complete()
        self._check_ports_are_valid()

        # NOTE(review): this only visits exported inputs, which are "connected" by
        # construction, so it can never raise. `_all_input_ports` (every child
        # input, collected in `add`) is not checked here; unconnected child inputs
        # may instead surface later via Diagram.eval_subsystem_input_port
        # (InputNotConnectedError). Confirm whether `_all_input_ports` was intended.
        for input_port_locator in self._input_port_ids:
            self._check_input_is_connected(input_port_locator)

        if len(self._registered_systems) == 0:
            raise EmptyDiagramError(name)

        diagram = Diagram(
            nodes=self._registered_systems,
            name=name,
            connection_map=self._connection_map,
            ui_id=ui_id,
        )

        build_recorder.build_diagram(self, diagram, parameters)

        if parameters:
            # Distinct loop name so the `name` argument is not shadowed.
            for param_name, parameter in parameters.items():
                diagram.declare_dynamic_parameter(param_name, parameter)
                diagram.instance_parameters.add(param_name)

        # Export diagram-level inputs
        for locator, port_name in zip(self._input_port_ids, self._input_port_names):
            diagram.export_input(locator, port_name)

        # Export diagram-level outputs
        assert len(self._output_port_ids) == len(self._output_port_names)
        for locator, port_name in zip(self._output_port_ids, self._output_port_names):
            diagram.export_output(locator, port_name)

        self._already_built = True  # Prevent further use of this builder
        return diagram

add(*systems)

add(system: SystemBase) -> SystemBase
add(system: SystemBase, *systems: SystemBase) -> List[SystemBase]

Add one or more systems to the diagram.

Parameters:

Name Type Description Default
*systems SystemBase

System(s) to add to the diagram.

()

Returns:

Type Description
List[SystemBase] | SystemBase

List[SystemBase] | SystemBase: The added system(s). Will return a single system if there is only a single system in the argument list.

Raises:

Type Description
BuilderError

If the diagram has already been built.

BuilderError

If the system is already registered.

BuilderError

If the system name is not unique.

Source code in collimator/framework/diagram_builder.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def add(self, *systems: SystemBase) -> List[SystemBase] | SystemBase:
    """Register one or more child systems with this builder.

    Args:
        *systems SystemBase:
            System(s) to add to the diagram.

    Returns:
        List[SystemBase] | SystemBase:
            The same system(s) that were passed in; a single system is returned
            unwrapped when exactly one was given.

    Raises:
        BuilderError: If the diagram has already been built.
        BuilderError: If the system is already registered.
        BuilderError: If the system name is not unique.
    """
    for new_system in systems:
        self._check_not_already_built()
        self._check_system_not_registered(new_system)
        self._check_system_name_is_unique(new_system)
        self._registered_systems.append(new_system)

        # Remember every input port of the new system so connectivity can be
        # checked before the diagram is built.
        self._all_input_ports.extend(p.locator for p in new_system.input_ports)

        logger.debug("Added system %s to DiagramBuilder", new_system.name)
        registered_names = [s.name for s in self._registered_systems]
        logger.debug("    Registered systems: %s", registered_names)
    build_recorder.add_block(self, systems)

    if len(systems) == 1:
        return systems[0]
    return systems

build(name='root', ui_id=None, parameters=None)

Builds a Diagram system with the specified name and UI identifier.

Parameters:

Name Type Description Default
name str

The name of the diagram. Defaults to "root".

'root'
ui_id str

The unique identifier for the diagram.

None
parameters dict[str, Parameter]

A dictionary of dynamic parameters to declare for the diagram.

None

Returns:

Name Type Description
Diagram Diagram

The newly constructed diagram.

Raises:

Type Description
EmptyDiagramError

If no systems are registered in the diagram.

BuilderError

If the diagram has already been built.

AlgebraicLoopError

If an algebraic loop is detected in the diagram.

DisconnectedInputError

If an input port is not connected.

Source code in collimator/framework/diagram_builder.py
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
def build(
    self,
    name: str = "root",
    ui_id: str = None,
    parameters: dict[str, Parameter] = None,
) -> Diagram:
    """Builds a Diagram system with the specified name and UI identifier.

    Args:
        name (str, optional): The name of the diagram. Defaults to "root".
        ui_id (str, optional): The unique identifier for the diagram.
        parameters (dict[str, Parameter], optional):
            A dictionary of dynamic parameters to declare for the diagram.

    Returns:
        Diagram: The newly constructed diagram.

    Raises:
        EmptyDiagramError: If no systems are registered in the diagram.
        BuilderError: If the diagram has already been built.
        AlgebraicLoopError: If an algebraic loop is detected in the diagram.
        DisconnectedInputError: If an input port is not connected.
    """
    self._check_not_already_built()
    self._check_contents_are_complete()
    self._check_ports_are_valid()

    # NOTE(review): this only visits exported inputs, which are "connected" by
    # construction, so it can never raise. `_all_input_ports` (every child
    # input) is not checked here; unconnected child inputs may instead surface
    # later as InputNotConnectedError from Diagram.eval_subsystem_input_port.
    # Confirm whether `_all_input_ports` was intended.
    for input_port_locator in self._input_port_ids:
        self._check_input_is_connected(input_port_locator)

    if len(self._registered_systems) == 0:
        raise EmptyDiagramError(name)

    diagram = Diagram(
        nodes=self._registered_systems,
        name=name,
        connection_map=self._connection_map,
        ui_id=ui_id,
    )

    build_recorder.build_diagram(self, diagram, parameters)

    if parameters:
        # Distinct loop name so the `name` argument is not shadowed.
        for param_name, parameter in parameters.items():
            diagram.declare_dynamic_parameter(param_name, parameter)
            diagram.instance_parameters.add(param_name)

    # Export diagram-level inputs
    for locator, port_name in zip(self._input_port_ids, self._input_port_names):
        diagram.export_input(locator, port_name)

    # Export diagram-level outputs
    assert len(self._output_port_ids) == len(self._output_port_names)
    for locator, port_name in zip(self._output_port_ids, self._output_port_names):
        diagram.export_output(locator, port_name)

    self._already_built = True  # Prevent further use of this builder
    return diagram

connect(src, dest)

Connect an output port to an input port.

The input port and output port must both belong to systems that have already been added to the diagram. The input port must not already be connected to another output port.

Parameters:

Name Type Description Default
src OutputPort

The output port to connect.

required
dest InputPort

The input port to connect.

required

Raises:

Type Description
BuilderError

If the diagram has already been built.

BuilderError

If the source system is not registered.

BuilderError

If the destination system is not registered.

BuilderError

If the input port is already connected.

Source code in collimator/framework/diagram_builder.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
def connect(self, src: OutputPort, dest: InputPort):
    """Wire an output port into an input port.

    Both ports must belong to systems that are already registered with this
    builder, and ``dest`` must not yet have an upstream connection.

    Args:
        src (OutputPort): The output port to connect.
        dest (InputPort): The input port to connect.

    Raises:
        BuilderError: If the diagram has already been built.
        BuilderError: If the source system is not registered.
        BuilderError: If the destination system is not registered.
        BuilderError: If the input port is already connected.
    """
    # Validate builder state and both endpoints before recording anything.
    self._check_not_already_built()
    for endpoint_system in (src.system, dest.system):
        self._check_system_is_registered(endpoint_system)
    self._check_input_not_connected(dest.locator)

    build_recorder.connect_ports(self, src, dest)

    # The map is keyed by destination: each input has exactly one source.
    self._connection_map[dest.locator] = src.locator

    logger.debug(
        f"Connected port {src.name} of system {src.system.name} to port {dest.name} of system {dest.system.name}"
    )

export_input(port, name=None)

Export an input port of a child system as a diagram-level input.

The input port must belong to a system that has already been added to the diagram. The input port must not already be connected to another output port.

Parameters:

Name Type Description Default
port InputPort

The input port to export.

required
name str

The name to assign to the exported input port. If not provided, a unique name will be generated.

None

Returns:

Name Type Description
int int

The index (in the to-be-built diagram) of the exported input port.

Raises:

Type Description
BuilderError

If the diagram has already been built.

BuilderError

If the system is not registered.

BuilderError

If the input port is already connected.

BuilderError

If the input port name is not unique.

Source code in collimator/framework/diagram_builder.py
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def export_input(self, port: InputPort, name: str = None) -> int:
    """Export an input port of a child system as a diagram-level input.

    The input port must belong to a system that has already been added to the
    diagram. The input port must not already be connected to another output port.

    Args:
        port (InputPort): The input port to export.
        name (str, optional):
            The name to assign to the exported input port. If not provided, a
            name of the form ``"<system name>_<port name>"`` is generated.

    Returns:
        int: The index (in the to-be-built diagram) of the exported input port.

    Raises:
        BuilderError: If the diagram has already been built.
        BuilderError: If the system is not registered.
        BuilderError: If the input port is already connected.
        BuilderError: If the input port name is not unique, whether it was
            given explicitly or generated.
    """
    self._check_not_already_built()
    self._check_system_is_registered(port.system)
    self._check_input_not_connected(port.locator)

    if name is None:
        # System names are unique, so auto-generated names cannot collide with
        # each other at the level of _this_ diagram (subsystems can have ports
        # with the same name). They CAN still collide with an explicitly chosen
        # name, e.g. a user-supplied "foo_in" vs. the generated name for port
        # "in" of system "foo", so the uniqueness check below covers both cases.
        name = f"{port.system.name}_{port.name}"

    if name in self._diagram_input_indices:
        raise BuilderError(
            f"Input port name {name} is not unique",
            system=port.system,
            port_index=port.index,
            port_direction="in",
        )

    # Index at the diagram (not subsystem) level
    port_index = len(self._input_port_ids)
    self._input_port_ids.append(port.locator)
    self._input_port_names.append(name)

    self._diagram_input_indices[name] = port_index

    build_recorder.export_port(self, port.system, "input", port.index, name)

    return port_index

export_output(port, name=None)

Export an output port of a child system as a diagram-level output.

The output port must belong to a system that has already been added to the diagram.

Parameters:

Name Type Description Default
port OutputPort

The output port to export.

required
name str

The name to assign to the exported output port. If not provided, a unique name will be generated.

None

Returns:

Name Type Description
int int

The index (in the to-be-built diagram) of the exported output port.

Raises:

Type Description
BuilderError

If the diagram has already been built.

BuilderError

If the system is not registered.

BuilderError

If the output port name is not unique.

Source code in collimator/framework/diagram_builder.py
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
def export_output(self, port: OutputPort, name: str = None) -> int:
    """Export an output port of a child system as a diagram-level output.

    The output port must belong to a system that has already been added to the
    diagram.

    Args:
        port (OutputPort): The output port to export.
        name (str, optional):
            The name to assign to the exported output port. If not provided, a
            name of the form ``"<system name>_<port name>"`` is generated.

    Returns:
        int: The index (in the to-be-built diagram) of the exported output port.

    Raises:
        BuilderError: If the diagram has already been built.
        BuilderError: If the system is not registered.
        BuilderError: If the output port name is not unique, whether it was
            given explicitly or generated.
    """
    self._check_not_already_built()
    self._check_system_is_registered(port.system)

    if name is None:
        # System names are unique, so auto-generated names cannot collide with
        # each other at the level of _this_ diagram (subsystems can have ports
        # with the same name). They CAN still collide with an explicitly chosen
        # name, so the uniqueness check below covers both cases.
        name = f"{port.system.name}_{port.name}"

    if name in self._output_port_names:
        raise BuilderError(
            f"Output port name {name} is not unique",
            system=port.system,
            port_index=port.index,
            port_direction="out",
        )

    # Index at the diagram (not subsystem) level
    port_index = len(self._output_port_ids)
    self._output_port_ids.append(port.locator)
    self._output_port_names.append(name)

    build_recorder.export_port(self, port.system, "output", port.index, name)

    return port_index

DiagramContext dataclass

Bases: ContextBase

Source code in collimator/framework/context.py
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
@dataclasses.dataclass(frozen=True)
class DiagramContext(ContextBase):
    """Context for a diagram, holding one subcontext per leaf system.

    Subcontexts are keyed by leaf system ID. Intermediate diagrams do not have
    contexts of their own; only the root diagram and leaf systems do. State
    accessors aggregate over the leaf subcontexts. Every ``with_*`` method
    returns a new context via ``dataclasses.replace`` instead of mutating this
    one. ``with_parameters`` additionally updates the owning system's dynamic
    parameters in place.
    """

    # Leaf subcontexts keyed by system ID. Every method below rebuilds this as
    # an OrderedDict so the container type stays consistent. This matters, for
    # example, if contexts are flattened as JAX pytrees: JAX treats dict and
    # OrderedDict as different node types.
    subcontexts: OrderedDict[Hashable, LeafContext] = dataclasses.field(
        default_factory=OrderedDict
    )

    def _check_key(self, key: Hashable) -> None:
        # Debug-only check (plain assert): the key must be the root diagram's
        # own ID or the ID of one of its leaf subcontexts.
        assert key == self.owning_system.system_id or key in self.subcontexts, (
            f"System ID {key} not found in DiagramContext {self}.\nIf this ID "
            "references an intermediate diagram, note that intermediate diagrams do "
            "not have associated contexts. Only the root diagram and leaf systems have "
            "contexts."
        )

    def __getitem__(self, key: Hashable) -> ContextBase:
        """Return the leaf subcontext for ``key``, or ``self`` for the diagram's own ID."""
        self._check_key(key)
        if key == self.owning_system.system_id:
            return self
        return self.subcontexts[key]

    def find_context_with_path(self, path: list[str]) -> ContextBase:
        """Return the context of the system located at ``path``.

        Raises:
            ValueError: If the owning system has no system at ``path``.
        """
        system = self.owning_system.find_system_with_path(path)
        if system is None:
            raise ValueError(
                f"No system with path {path} found in {self.owning_system}"
            )
        return self[system.system_id]

    def with_subcontext(self, key: Hashable, ctx: LeafContext) -> DiagramContext:
        """Return a copy with the subcontext stored under ``key`` replaced by ``ctx``."""
        self._check_key(key)
        subcontexts = self.subcontexts.copy()
        subcontexts[key] = ctx
        return dataclasses.replace(self, subcontexts=subcontexts)

    #
    # Simulation interface
    #
    @property
    def state(self) -> Mapping[Hashable, LeafState]:
        """Leaf states keyed by system ID, in subcontext order."""
        return OrderedDict(
            (system_id, subctx.state) for system_id, subctx in self.subcontexts.items()
        )

    @property
    def continuous_subcontexts(self) -> List[LeafContext]:
        """Leaf subcontexts that have a continuous state, in subcontext order."""
        return [
            subctx
            for subctx in self.subcontexts.values()
            if subctx.has_continuous_state
        ]

    @property
    def continuous_state(self) -> List[Array]:
        """Continuous states of ``continuous_subcontexts``, in the same order."""
        return [subctx.continuous_state for subctx in self.continuous_subcontexts]

    def with_continuous_state(self, sub_xcs: List[Array]) -> DiagramContext:
        """Return a copy with new continuous states.

        ``sub_xcs`` is paired positionally with ``continuous_subcontexts``.
        """
        # Shallow copy the subcontexts - only modify the ones that have continuous states
        new_subcontexts = self.subcontexts.copy()
        for subctx, sub_xc in zip(self.continuous_subcontexts, sub_xcs):
            new_subcontexts[subctx.system_id] = subctx.with_continuous_state(sub_xc)
        return dataclasses.replace(self, subcontexts=new_subcontexts)

    @property
    def num_continuous_states(self) -> int:
        """Total number of continuous states summed over all leaf subcontexts."""
        return sum(subctx.num_continuous_states for subctx in self.subcontexts.values())

    @property
    def has_continuous_state(self) -> bool:
        return self.num_continuous_states > 0

    @property
    def discrete_subcontexts(self) -> List[LeafContext]:
        """Leaf subcontexts that have a discrete state, in subcontext order."""
        return [
            subctx for subctx in self.subcontexts.values() if subctx.has_discrete_state
        ]

    @property
    def discrete_state(self) -> List[List[Array]]:
        """Discrete states of ``discrete_subcontexts``, in the same order."""
        return [subctx.discrete_state for subctx in self.discrete_subcontexts]

    def with_discrete_state(self, sub_xds: List[List[Array]]) -> DiagramContext:
        """Return a copy with new discrete states.

        ``sub_xds`` is paired positionally with ``discrete_subcontexts``.
        """
        # Shallow copy the subcontexts - only modify the ones that have discrete states
        new_subcontexts = self.subcontexts.copy()
        for subctx, sub_xd in zip(self.discrete_subcontexts, sub_xds):
            new_subcontexts[subctx.system_id] = subctx.with_discrete_state(sub_xd)
        return dataclasses.replace(self, subcontexts=new_subcontexts)

    @property
    def num_discrete_states(self) -> int:
        """Total number of discrete states summed over all leaf subcontexts."""
        return sum(subctx.num_discrete_states for subctx in self.subcontexts.values())

    @property
    def has_discrete_state(self) -> bool:
        return self.num_discrete_states > 0

    @property
    def mode_subcontexts(self) -> List[LeafContext]:
        """Leaf subcontexts that have a mode, in subcontext order."""
        return [subctx for subctx in self.subcontexts.values() if subctx.has_mode]

    @property
    def mode(self) -> List[int]:
        """Modes of ``mode_subcontexts``, in the same order."""
        return [subctx.mode for subctx in self.mode_subcontexts]

    def with_mode(self, sub_modes: List[int]) -> DiagramContext:
        """Return a copy with new modes, paired positionally with ``mode_subcontexts``."""
        new_subcontexts = self.subcontexts.copy()
        for subctx, sub_mode in zip(self.mode_subcontexts, sub_modes):
            new_subcontexts[subctx.system_id] = subctx.with_mode(sub_mode)
        return dataclasses.replace(self, subcontexts=new_subcontexts)

    @property
    def has_mode(self) -> bool:
        return any(subctx.has_mode for subctx in self.subcontexts.values())

    def with_state(self, sub_states: Mapping[Hashable, LeafState]) -> DiagramContext:
        """Return a copy whose subcontexts carry the states in ``sub_states``.

        Only the system IDs present in ``sub_states`` are kept. Subcontexts
        whose IDs are absent from the mapping are dropped from the copy.
        """
        new_subcontexts = OrderedDict()
        for system_id, sub_state in sub_states.items():
            new_subcontexts[system_id] = dataclasses.replace(
                self.subcontexts[system_id], state=sub_state
            )
        return dataclasses.replace(self, subcontexts=new_subcontexts)

    def with_new_state(self) -> ContextBase:
        """Return a copy in which every subcontext gets a freshly created state."""
        new_subcontexts = OrderedDict()
        for system_id, subctx in self.subcontexts.items():
            new_subcontexts[system_id] = subctx.with_new_state()
        return dataclasses.replace(self, subcontexts=new_subcontexts)

    def with_updated_parameters(self) -> ContextBase:
        """Return a copy whose parameters are re-read from the owning systems.

        The diagram's own parameters come from ``param.get()`` on each of the
        owning system's dynamic parameters. Each subcontext is refreshed via its
        own ``with_updated_parameters``.
        """
        new_parameters = {
            name: param.get()
            for name, param in self.owning_system.dynamic_parameters.items()
        }
        # Rebuild as an OrderedDict to match the declared field type and every
        # other method in this class. A plain dict would silently change the
        # container type of `subcontexts`.
        new_subcontexts = OrderedDict(
            (system_id, subctx.with_updated_parameters())
            for system_id, subctx in self.subcontexts.items()
        )

        return dataclasses.replace(
            self, subcontexts=new_subcontexts, parameters=new_parameters
        )

    def with_parameters(self, new_parameters: Mapping[str, ArrayLike]) -> ContextBase:
        """Create a copy of this context, replacing only the specified parameters.

        All names are validated before any parameter is touched. On success, the
        owning system's dynamic parameters are updated in place (via ``set``)
        and the returned context is refreshed with ``with_updated_parameters``.

        Raises:
            ValueError: If a name is not a dynamic parameter of the owning system.
            StaticParameterError: If a parameter is used by static parameters and
                therefore cannot be updated dynamically.
        """
        parameters = {**self.parameters}

        # First validate that all parameters exist and are dynamic
        for name in new_parameters:
            if name not in self.owning_system.dynamic_parameters:
                raise ValueError(
                    f"Parameter {name} not found in {self.owning_system.name}"
                )
            param = self.owning_system.dynamic_parameters[name]

            if param.static_dependents:
                static_dependents = ", ".join(
                    [f"{dep.system.name}" for dep in param.static_dependents]
                )
                raise StaticParameterError(
                    f"Parameter {name} is used in static parameters"
                    " and cannot be updated dynamically. Please create a new context."
                    f" Static dependents in blocks: {static_dependents}"
                )

        for name, value in new_parameters.items():
            self.owning_system.dynamic_parameters[name].set(value)
            parameters[name] = value

        context = dataclasses.replace(self, parameters=parameters)
        return context.with_updated_parameters()

with_parameters(new_parameters)

Create a copy of this context, replacing only the specified parameters.

Source code in collimator/framework/context.py
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
def with_parameters(self, new_parameters: Mapping[str, ArrayLike]) -> ContextBase:
    """Create a copy of this context, replacing only the specified parameters.

    All names are validated before any parameter is touched, so a call that
    raises leaves the owning system's dynamic parameters unchanged. On success
    those dynamic parameters are updated in place (via ``set``) before the new
    context is built.

    Args:
        new_parameters: Mapping from dynamic-parameter name to its new value.

    Returns:
        A new context, refreshed via ``with_updated_parameters``.

    Raises:
        ValueError: If a name is not a dynamic parameter of the owning system.
        StaticParameterError: If a parameter is used by static parameters and
            therefore cannot be updated dynamically.
    """
    parameters = {**self.parameters}

    # First validate that all parameters exist and are dynamic (names only).
    for name in new_parameters:
        if name not in self.owning_system.dynamic_parameters:
            raise ValueError(
                f"Parameter {name} not found in {self.owning_system.name}"
            )
        param = self.owning_system.dynamic_parameters[name]

        if param.static_dependents:
            static_dependents = ", ".join(
                [f"{dep.system.name}" for dep in param.static_dependents]
            )
            raise StaticParameterError(
                f"Parameter {name} is used in static parameters"
                " and cannot be updated dynamically. Please create a new context."
                f" Static dependents in blocks: {static_dependents}"
            )

    for name, value in new_parameters.items():
        self.owning_system.dynamic_parameters[name].set(value)
        parameters[name] = value

    context = dataclasses.replace(self, parameters=parameters)
    return context.with_updated_parameters()

DiscreteUpdateEvent dataclass

Bases: Event

Event representing a discrete update in a hybrid system.

Source code in collimator/framework/event.py
285
286
287
288
289
290
291
292
293
294
295
296
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class DiscreteUpdateEvent(Event):
    """Event representing a discrete update in a hybrid system.

    This subclass redeclares the ``callback`` and ``passthrough`` fields and
    overrides ``handle`` so that their type hints specialize to callables that
    take a context and return an ``Array``. ``handle`` simply delegates to
    ``Event.handle``.
    """

    # Supersede type hints in Event with the specific signature for discrete updates.
    # Both fields default to None here.
    callback: Callable[[ContextBase], Array] = None
    passthrough: Callable[[ContextBase], Array] = None

    # Inherits docstring from Event. This is only needed to specialize type hints.
    def handle(self, context: ContextBase) -> Array:
        return super().handle(context)

DtypeMismatchError

Bases: StaticError

Block parameters or input/outputs have mismatched dtypes.

Source code in collimator/framework/error.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
class DtypeMismatchError(StaticError):
    """Block parameters or input/outputs have mismatched dtypes.

    Args:
        expected_dtype: The dtype that was expected, if known.
        actual_dtype: The dtype that was actually found, if known.
        **kwargs: Forwarded unchanged to the ``StaticError`` constructor.
    """

    def __init__(self, expected_dtype=None, actual_dtype=None, **kwargs):
        super().__init__(**kwargs)
        self.expected_dtype = expected_dtype
        self.actual_dtype = actual_dtype

    def __str__(self):
        # Compare against None explicitly. Dtype-like objects are not reliably
        # truthy: older NumPy versions evaluate ``bool(np.dtype(...))`` through
        # ``len(dtype)``, which is 0 for non-structured dtypes. That would
        # suppress the informative message below.
        if self.expected_dtype is not None or self.actual_dtype is not None:
            return (
                f"Data type mismatch: "
                f"expected {self.expected_dtype}, got {self.actual_dtype}"
                + self._context_info()
            )
        return f"Dtype mismatch{self._context_info()}"

ErrorCollector

Tool used to collect errors related to the user's model specification. Such errors are identified during static analysis of the model, e.g. context creation, type checking, etc.

An instance of this tool can be created, and then passed down a tree of function calls to collect errors found anywhere in the tree. Locally in the tree it can be determined whether it is ok to continue or not. This tool enables collecting errors up until the point when continuation is no longer possible.

Note: this latter behavior, where an early exit is sometimes desired and all other "pipeline" operations are "nullified", might be better implemented using the Either class from pymonad.

Source code in collimator/framework/error.py
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
class ErrorCollector:
    """
    Tool used to collect errors related to the user's model specification.
    Such errors are identified during static analysis of the model,
    e.g. context creation, type checking, etc.

    An instance of this tool can be created, and then passed down a
    tree of function calls to collect errors found anywhere in
    the tree. Locally in the tree it can be determined whether it is
    ok to continue or not. This tool enables collecting errors up until
    the point when continuation is no longer possible.

    Note: this latter behavior, where an early exit is sometimes desired and
    all other "pipeline" operations are "nullified", might be better
    implemented using the ``Either`` class from pymonad.
    """

    def __init__(self):
        # When True, errors that reach this (parentless) collector are discarded.
        self._disable_collection = False
        # When set, every error is forwarded to this collector instead of
        # being stored locally.
        self._parent: Optional["ErrorCollector"] = None
        self.errors: list[BaseException] = []

    def add_error(self, error: BaseException):
        """Add an error to the collection.

        If this collector has a parent, the error is forwarded to the parent
        and not stored here. Otherwise it is appended to ``errors`` unless
        collection has been disabled.
        """

        if self._parent is not None:
            self._parent.add_error(error)
            return

        if not self._disable_collection:
            self.errors.append(error)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Route an exception raised inside the ``with`` block.

        With a parent, the exception is handed to the parent and suppressed.
        Without a parent, it is recorded (unless collection is disabled) and
        then re-raised.
        """
        # Return values: True to suppress the exception, False to propagate it

        if exc_type is not None:
            if self._parent is not None:
                self._parent.add_error(exc_value)
                return True

            self.add_error(exc_value)
            return False

        return True

    @classmethod
    def context(cls, parent: Optional["ErrorCollector"] = None) -> "ErrorCollector":
        """A context manager convenience to use when tracing errors.

        Use as:
        ```
        with ErrorCollector.context(parent) as ec:
            ...
        ```

        If ``parent`` is None, the returned collector has collection disabled,
        so exceptions raised inside the ``with`` block pass through without
        being collected. Otherwise, errors are forwarded to ``parent``, and
        exceptions raised inside the block are collected there and suppressed.

        Args:
            parent: The collector that should receive errors, or None.

        Returns:
            A new collector configured as described above.
        """
        collector = cls()
        if parent is None:
            collector._disable_collection = True
        else:
            collector._parent = parent
        return collector

add_error(error)

Add an error to the collection.

Source code in collimator/framework/error.py
325
326
327
328
329
330
331
332
333
def add_error(self, error: BaseException):
    """Record ``error`` on this collector, or hand it to the parent collector.

    A collector that has a parent never stores errors itself. A parentless
    collector appends to ``errors`` unless collection has been disabled.
    """
    parent = self._parent
    if parent is None:
        if not self._disable_collection:
            self.errors.append(error)
    else:
        parent.add_error(error)

context(parent=None) classmethod

A context manager convenience to use when tracing errors.

Use as:

with ErrorCollector.context(parent) as ec:
    ...

If the parent context is None, then exceptions will pass through without being collected. Else, exceptions will be collected in the parent context.

Source code in collimator/framework/error.py
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
@classmethod
def context(cls, parent: Optional["ErrorCollector"] = None) -> "ErrorCollector":
    """A context manager convenience to use when tracing errors.

    Use as:
    ```
    with ErrorCollector.context(parent) as ec:
        ...
    ```

    If ``parent`` is None, the returned collector has collection disabled, so
    exceptions raised inside the ``with`` block pass through without being
    collected. Otherwise, the returned collector forwards errors to ``parent``,
    and exceptions raised inside the block are collected there.

    Args:
        parent: The collector that should receive errors, or None.

    Returns:
        A new collector configured as described above.
    """
    collector = cls()
    if parent is None:
        collector._disable_collection = True
    else:
        collector._parent = parent
    return collector

EventCollection

A collection of events owned by a system.

Users should not need to interact with these objects directly. They are intended to be used internally by the simulation framework for handling events in hybrid system simulation.

These contain callback functions that update the context in various ways when the event is triggered. There will be different "collections" for each trigger type in simulation (e.g. periodic vs zero-crossing). Within the collections, events are broken out by function (e.g. discrete vs unrestricted updates).

There are separate implementations for leaf and diagram systems, where the DiagramCEventCollection preserves the tree structure of the underlying Diagram. However, the interface in both cases is the same and is identical to the interface defined by EventCollection.

Source code in collimator/framework/event.py
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
class EventCollection(metaclass=abc.ABCMeta):
    """Abstract collection of the events owned by a system.

    These objects are internal to the simulation framework, which uses them to
    handle events while simulating hybrid systems; users should not normally
    need to interact with them directly.

    Each event carries callbacks that update the context in some way when the
    event is triggered. Simulation keeps a separate collection per trigger type
    (e.g. periodic vs. zero-crossing), and within a collection events are
    grouped by function (e.g. discrete vs. unrestricted updates).

    Leaf and diagram systems have separate implementations; the diagram
    implementation preserves the tree structure of the underlying Diagram.
    Both expose exactly the interface declared by this class.
    """

    # ------------------------------------------------------------------
    # Abstract interface
    # ------------------------------------------------------------------

    @abc.abstractmethod
    def __getitem__(self, key: Hashable) -> EventCollection:
        pass

    @property
    @abc.abstractmethod
    def events(self) -> List[Event]:
        pass

    @property
    @abc.abstractmethod
    def num_events(self) -> int:
        pass

    @abc.abstractmethod
    def activate(self, activation_fn) -> EventCollection:
        pass

    @property
    @abc.abstractmethod
    def terminal_events(self) -> EventCollection:
        pass

    # ------------------------------------------------------------------
    # Size and iteration
    # ------------------------------------------------------------------

    @property
    def has_events(self) -> bool:
        return 0 < self.num_events

    def __len__(self):
        return self.num_events

    def __iter__(self):
        event_list = self.events
        return iter(event_list)

    # ------------------------------------------------------------------
    # Activation state
    # ------------------------------------------------------------------

    def mark_all_active(self) -> EventCollection:
        return self.activate(lambda _: True)

    def mark_all_inactive(self) -> EventCollection:
        return self.activate(lambda _: False)

    @property
    def num_active(self) -> int:
        """Number of events in the tree whose ``active`` flag is set."""
        active_flags = tree_util.tree_map(
            lambda data: data.active, self, is_leaf=is_event_data
        )
        return sum(tree_util.tree_leaves(active_flags))

    @property
    def has_active(self) -> bool:
        return 0 < self.num_active

    @property
    def has_triggered(self) -> bool:
        """True if at least one event is both active and triggered."""
        fired_flags = tree_util.tree_map(
            lambda data: data.active & data.triggered, self, is_leaf=is_event_data
        )
        return 0 < sum(tree_util.tree_leaves(fired_flags))

    @property
    def has_terminal_events(self):
        return self.terminal_events.has_events

    @property
    def has_active_terminal(self) -> bool:
        return self.terminal_events.has_triggered

    # ------------------------------------------------------------------
    # Display
    # ------------------------------------------------------------------

    def pprint(self, output=print):
        text = self._pprint_helper()
        output(text.strip())

    def _pprint_helper(self, prefix="") -> str:
        pieces = [f"{prefix}|-- \n"]
        if len(self.events) > 0:
            pieces.append(f"{prefix}    Events:\n")
            pieces.extend(f"{prefix}    |  {event}\n" for event in self.events)
        return "".join(pieces)

    def __repr__(self) -> str:
        body = f"discrete_update: {self.events} " if self.has_events else ""
        return f"{type(self).__name__}({body})"

IntegerTime

Class for managing conversion between decimal and integer time.

Source code in collimator/framework/event.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class IntegerTime:
    """Class for managing conversion between decimal and integer time.

    An integer time counts ticks of length ``time_scale``. ``from_decimal``
    divides a decimal time by that scale (via ``inv_time_scale``), and
    ``as_decimal`` multiplies it back. All configuration lives in class
    variables, so ``set_scale`` changes the conversion for every user of the
    class.
    """

    # TODO: Can we use this directly as an int?  Would need to implement __add__,
    # __sub__, etc.  Also, comparisons and floor divide.  Would make the code in
    # Simulator cleaner, but dealing with JAX tracers in `where` and the like
    # might be difficult.  See commit 043c8f757 for a previous attempt.

    #
    # Class variables
    #
    time_scale = DEFAULT_TIME_SCALE  # int -> float conversion factor
    inv_time_scale = 1 / time_scale  # float -> int conversion factor

    # Type of the integer time representation. Defaults to x64 unless explicitly disabled.
    dtype: DTypeLike = cnp.intx

    # Largest time value representable by IntegerTime.dtype
    max_int_time = cnp.iinfo(dtype).max

    # Upper bound (in decimal time units) used by `from_decimal` to clamp inputs.
    # NOTE(review): despite the name, this is created with the *integer* `dtype`
    # above, so the product is cast to an integer (presumably truncated) rather
    # than stored as a float. Truncation keeps it at or below the exact bound,
    # which likely avoids overflow when it is scaled back by inv_time_scale.
    # TODO confirm this is intentional.
    max_float_time = cnp.asarray(max_int_time * time_scale, dtype=dtype)

    #
    # Class methods
    #
    @classmethod
    def set_scale(cls, time_scale: float):
        """Set the tick length and recompute the dependent class variables.

        This affects every user of IntegerTime, because the values are stored
        on the class.
        """
        cls.time_scale = time_scale
        cls.inv_time_scale = 1 / time_scale
        cls.max_float_time = cnp.asarray(cls.max_int_time * time_scale, dtype=cls.dtype)

    @classmethod
    def set_default_scale(cls):
        """Restore the tick length to DEFAULT_TIME_SCALE."""
        cls.set_scale(DEFAULT_TIME_SCALE)

    @classmethod
    def from_decimal(cls, time: float) -> int:
        """Convert a floating-point time to an integer time.

        Only the upper end is clamped (to ``max_float_time``). The final cast
        to the integer dtype presumably truncates rather than rounds, depending
        on the ``cnp`` backend's ``asarray`` semantics. TODO confirm.
        """
        # First limit to the max value to avoid overflow with inf or very large values.
        time = cnp.minimum(time, cls.max_float_time)
        return cnp.asarray(time * cls.inv_time_scale, dtype=cls.dtype)

    @classmethod
    def as_decimal(cls, time: int) -> float:
        """Convert an integer time to a floating-point time."""
        return time * cls.time_scale

as_decimal(time) classmethod

Convert an integer time to a floating-point time.

Source code in collimator/framework/event.py
89
90
91
92
@classmethod
def as_decimal(cls, time: int) -> float:
    """Turn an integer tick count into decimal time by scaling with ``time_scale``."""
    tick_length = cls.time_scale
    return time * tick_length

from_decimal(time) classmethod

Convert a floating-point time to an integer time.

Source code in collimator/framework/event.py
82
83
84
85
86
87
@classmethod
def from_decimal(cls, time: float) -> int:
    """Turn a decimal time into an integer tick count using ``inv_time_scale``."""
    # Clamp first so that inf or huge inputs cannot overflow the integer dtype.
    bounded = cnp.minimum(time, cls.max_float_time)
    ticks = bounded * cls.inv_time_scale
    return cnp.asarray(ticks, dtype=cls.dtype)

LeafContext dataclass

Bases: ContextBase

Source code in collimator/framework/context.py
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
@dataclasses.dataclass(frozen=True)
class LeafContext(ContextBase):
    """Context for a single leaf system.

    Most accessors delegate to the wrapped ``LeafState``. Every ``with_*``
    method returns a new context via ``dataclasses.replace``.
    ``with_parameters`` additionally updates the owning system's dynamic
    parameters in place.
    """

    # State of the owning leaf system (continuous/discrete components, mode,
    # cache). Defaults to None.
    state: LeafState = None

    @property
    def system_id(self) -> Hashable:
        """ID of the leaf system that owns this context."""
        return self.owning_system.system_id

    def __getitem__(self, key: Hashable) -> LeafContext:
        """Dummy indexing for compatibility with DiagramContexts, returning self."""
        assert key == self.system_id, f"Attempting to get subcontext {key} from {self}"
        return self

    def with_subcontext(self, key: Hashable, ctx: LeafContext) -> LeafContext:
        """Dummy replacement for compatibility with DiagramContexts, returning ctx."""
        assert (
            key == self.system_id
        ), f"System ID {key} does not match leaf ID {self.system_id}"
        assert (
            key == ctx.system_id
        ), f"System ID {key} does not match leaf ID {ctx.system_id}"
        return ctx

    def __repr__(self) -> str:
        return f"{type(self).__name__}(sys={self.system_id})"

    def with_state(self, state: LeafState) -> LeafContext:
        """Return a copy holding ``state``."""
        return dataclasses.replace(self, state=state)

    @property
    def continuous_state(self) -> LeafStateComponent:
        return self.state.continuous_state

    def with_continuous_state(self, value: LeafStateComponent) -> LeafContext:
        """Return a copy whose state has the continuous component replaced."""
        return dataclasses.replace(self, state=self.state.with_continuous_state(value))

    @property
    def num_continuous_states(self) -> int:
        return self.state.num_continuous_states

    @property
    def has_continuous_state(self) -> bool:
        return self.state.has_continuous_state

    @property
    def discrete_state(self) -> LeafStateComponent:
        return self.state.discrete_state

    def with_discrete_state(self, value: LeafStateComponent) -> LeafContext:
        """Return a copy whose state has the discrete component replaced."""
        return dataclasses.replace(self, state=self.state.with_discrete_state(value))

    @property
    def num_discrete_states(self) -> int:
        return self.state.num_discrete_states

    @property
    def has_discrete_state(self) -> bool:
        return self.state.has_discrete_state

    @property
    def mode(self) -> int:
        return self.state.mode

    @property
    def has_mode(self) -> bool:
        return self.state.has_mode

    def with_mode(self, value: int) -> LeafContext:
        """Return a copy whose state has the mode replaced."""
        return dataclasses.replace(self, state=self.state.with_mode(value))

    @property
    def cache(self) -> tuple[Array]:
        return self.state.cache

    @property
    def num_cached_values(self) -> int:
        return self.state.num_cached_values

    @property
    def has_cache(self) -> bool:
        return self.state.has_cache

    def with_cached_value(self, index: int, value: Array) -> LeafContext:
        """Return a copy whose state has cache entry ``index`` set to ``value``."""
        return dataclasses.replace(
            self, state=self.state.with_cached_value(index, value)
        )

    def with_updated_parameters(self) -> ContextBase:
        """Return a copy whose parameters are re-read from the owning system.

        The value of each dynamic parameter is taken from ``param.get()``.
        """
        new_parameters = {
            name: param.get()
            for name, param in self.owning_system.dynamic_parameters.items()
        }
        return dataclasses.replace(self, parameters=new_parameters)

    def with_new_state(self) -> ContextBase:
        """Return a copy holding a fresh state from ``owning_system.create_state()``."""
        return dataclasses.replace(self, state=self.owning_system.create_state())

    def with_parameters(self, new_parameters: Mapping[str, ArrayLike]) -> ContextBase:
        """Create a copy of this context, replacing only the specified parameters.

        Each named dynamic parameter of the owning system is updated in place
        via ``set``. The new context stores the value read back with ``get``.

        Raises:
            KeyError: If a name is not a dynamic parameter of the owning system.
                All names are resolved before any ``set`` call, so in that case
                no parameter of the owning system has been modified.
        """
        parameters = {**self.parameters}
        dynamic_parameters = self.owning_system.dynamic_parameters
        # Resolve every name up front. Previously an unknown name raised only
        # after earlier parameters had already been mutated, leaving the
        # owning system partially updated.
        resolved = [
            (name, dynamic_parameters[name], value)
            for name, value in new_parameters.items()
        ]
        for name, param, value in resolved:
            param.set(value)
            parameters[name] = param.get()
        return dataclasses.replace(self, parameters=parameters)

__getitem__(key)

Dummy indexing for compatibility with DiagramContexts, returning self.

Source code in collimator/framework/context.py
217
218
219
220
def __getitem__(self, key: Hashable) -> LeafContext:
    """Return this context itself, so LeafContexts can be indexed like DiagramContexts.

    A leaf has exactly one addressable "subcontext": its own ``system_id``.
    Any other key trips an assertion.
    """
    assert key == self.system_id, f"Attempting to get subcontext {key} from {self}"
    return self

with_parameters(new_parameters)

Create a copy of this context, replacing only the specified parameters.

Source code in collimator/framework/context.py
306
307
308
309
310
311
312
313
def with_parameters(self, new_parameters: Mapping[str, ArrayLike]) -> ContextBase:
    """Create a copy of this context, replacing only the specified parameters.

    For each ``name -> value`` in `new_parameters`, the owning system's
    ``dynamic_parameters[name]`` receives ``set(value)``. Its ``get()`` result is
    stored in the copy. Parameters that are not named keep their current values.

    Note that ``set()`` is called on the system's own Parameter objects
    (presumably updating them in place), so this is not side-effect free.
    """
    merged = dict(self.parameters)
    for param_name, raw_value in new_parameters.items():
        parameter = self.owning_system.dynamic_parameters[param_name]
        parameter.set(raw_value)
        merged[param_name] = parameter.get()
    return dataclasses.replace(self, parameters=merged)

with_subcontext(key, ctx)

Dummy replacement for compatibility with DiagramContexts, returning ctx.

Source code in collimator/framework/context.py
222
223
224
225
226
227
228
229
230
def with_subcontext(self, key: Hashable, ctx: LeafContext) -> LeafContext:
    """Replace this leaf's only subcontext (itself) with `ctx`, for DiagramContext parity.

    Both this context and `ctx` must carry ``system_id == key``; otherwise an
    assertion fails. The replacement context is returned unchanged.
    """
    assert key == self.system_id, f"System ID {key} does not match leaf ID {self.system_id}"
    assert key == ctx.system_id, f"System ID {key} does not match leaf ID {ctx.system_id}"
    return ctx

LeafState dataclass

Container for state information for a leaf system.

Attributes:

Name Type Description
name str

Name of the leaf system that owns this state.

continuous_state LeafStateComponent

Continuous state of the system, i.e. the component of state that evolves in continuous time. If the system has no continuous state, this will be None.

discrete_state LeafStateComponent

Discrete state of the system, i.e. one or more components of state that do not change continuously with time (not necessarily discrete-valued). If the system has no discrete state, this will be None.

mode int

An integer value indicating the current "mode", "stage", or discrete-valued state component of the system. Used for finite state machines or multi-stage hybrid systems. If the system has no mode, this will be None.

cache tuple[LeafStateComponent]

The current values of sample-and-hold outputs from the system. In a pure discrete system these would not be state components (just results of feedthrough computations), but in a hybrid or multirate system they act as discrete state from the perspective of continuous or asynchronous discrete components of the system. Hence, they are stored in the state, but are maintained separately from the normal internal state of the system.

Notes

(1) This class is immutable. To modify a LeafState, use the with_* methods.

(2) The type annotations for state components are LeafStateComponent, which is a union of array, tuple, and named tuple. The most common case is arrays, but this allows for more flexibility in defining state components, e.g. a second-order system can define a named tuple of generalized coordinates and velocities rather than concatenating into a single array.

Source code in collimator/framework/state.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
@dataclasses.dataclass(frozen=True)
class LeafState:
    """Container for state information for a leaf system.

    Attributes:
        name (str):
            Name of the leaf system that owns this state.
        continuous_state (LeafStateComponent):
            Continuous state of the system, i.e. the component of state that evolves in
            continuous time. If the system has no continuous state, this will be None.
        discrete_state (LeafStateComponent):
            Discrete state of the system, i.e. one or more components of state that do
            not change continuously with time (not necessarily discrete-_valued_). If
            the system has no discrete state, this will be None.
        mode (int):
            An integer value indicating the current "mode", "stage", or discrete-valued
            state component of the system.  Used for finite state machines or
            multi-stage hybrid systems.  If the system has no mode, this will be None.
        cache (tuple[LeafStateComponent]):
            The current values of sample-and-hold outputs from the system.  In a pure
            discrete system these would not be state components (just results of
            feedthrough computations), but in a hybrid or multirate system they act as
            discrete state from the perspective of continuous or asynchronous discrete
            components of the system.  Hence, they are stored in the state, but are
            maintained separately from the normal internal state of the system.

    Notes:
        (1) This class is immutable.  To modify a LeafState, use the `with_*` methods.

        (2) The type annotations for state components are LeafStateComponent, which is
        a union of array, tuple, and named tuple. The most common case is arrays, but
        this allows for more flexibility in defining state components, e.g. a
        second-order system can define a named tuple of generalized coordinates and
        velocities rather than concatenating into a single array.

        (3) All `has_*` / `num_*` queries are read-only properties.
    """

    name: str = None
    continuous_state: LeafStateComponent = None
    discrete_state: LeafStateComponent = None
    mode: int = None
    cache: tuple[Array] = None

    def __repr__(self) -> str:
        states = []
        if self.continuous_state is not None:
            states.append(f"xc={self.continuous_state}")
        if self.discrete_state is not None:
            states.append(f"xd={self.discrete_state}")
        if self.mode is not None:
            states.append(f"s={self.mode}")
        return f"{type(self).__name__}({', '.join(states)})"

    def with_continuous_state(self, value: LeafStateComponent) -> LeafState:
        """Create a copy of this LeafState with the continuous state replaced."""
        if value is not None and self.continuous_state is not None:
            # Coerce each leaf of the new value to the shape of the existing leaf
            # so the PyTree structure of the state is preserved.
            value = tree_util.tree_map(self._reshape_like, value, self.continuous_state)

        return dataclasses.replace(self, continuous_state=value)

    def _component_size(self, component: LeafStateComponent) -> int:
        """Size of a state component: 0 for None, len() for tuples, `.size` otherwise."""
        if component is None:
            return 0
        if isinstance(component, tuple):
            # For tuple/namedtuple components this counts top-level entries,
            # not the total number of array elements they contain.
            return len(component)
        return component.size

    def _reshape_like(self, new_value: Array, current_value: Array) -> Array:
        """Helper function for tree-mapped type conversions.

        Ensures that the new components are array-like and have the same shape as
        the existing state to preserve PyTree structure.
        """
        return reshape(new_value, current_value.shape)

    @property
    def num_continuous_states(self) -> int:
        return self._component_size(self.continuous_state)

    @property
    def has_continuous_state(self) -> bool:
        return self.num_continuous_states > 0

    def with_discrete_state(self, value: LeafStateComponent) -> LeafState:
        """Create a copy of this LeafState with the discrete state replaced."""
        if value is not None and self.discrete_state is not None:
            # Same shape-preserving coercion as for the continuous state.
            value = tree_util.tree_map(self._reshape_like, value, self.discrete_state)

        return dataclasses.replace(self, discrete_state=value)

    @property
    def num_discrete_states(self) -> int:
        return self._component_size(self.discrete_state)

    @property
    def has_discrete_state(self) -> bool:
        return self.num_discrete_states > 0

    def with_mode(self, value: int) -> LeafState:
        """Create a copy of this LeafState with the mode replaced."""
        return dataclasses.replace(self, mode=value)

    @property
    def has_mode(self) -> bool:
        return self.mode is not None

    def with_cached_value(self, index: int, value: Array) -> LeafState:
        """Create a copy of this LeafState with the specified cache value replaced."""
        cache = list(self.cache)
        cache[index] = value
        return dataclasses.replace(self, cache=tuple(cache))

    @property
    def has_cache(self) -> bool:
        """True if this state carries a cache tuple (possibly empty)."""
        return self.cache is not None

    @property
    def num_cached_values(self) -> int:
        """Number of cached sample-and-hold values (0 if there is no cache)."""
        return 0 if self.cache is None else len(self.cache)

with_cached_value(index, value)

Create a copy of this LeafState with the specified cache value replaced.

Source code in collimator/framework/state.py
156
157
158
159
160
def with_cached_value(self, index: int, value: Array) -> LeafState:
    """Return a copy of this LeafState where cache slot `index` holds `value`.

    The cache is a tuple, so it is rebuilt rather than mutated. All other fields
    are carried over unchanged.
    """
    entries = [*self.cache]
    entries[index] = value
    return dataclasses.replace(self, cache=tuple(entries))

with_continuous_state(value)

Create a copy of this LeafState with the continuous state replaced.

Source code in collimator/framework/state.py
102
103
104
105
106
107
def with_continuous_state(self, value: LeafStateComponent) -> LeafState:
    """Return a copy of this LeafState with `value` as the continuous state.

    When both the new value and the current continuous state are present, each
    leaf of `value` is reshaped to match the existing leaf (via
    ``self._reshape_like``) so the PyTree structure is preserved.
    """
    current = self.continuous_state
    if value is None or current is None:
        return dataclasses.replace(self, continuous_state=value)
    reshaped = tree_util.tree_map(self._reshape_like, value, current)
    return dataclasses.replace(self, continuous_state=reshaped)

with_discrete_state(value)

Create a copy of this LeafState with the discrete state replaced.

Source code in collimator/framework/state.py
133
134
135
136
137
138
def with_discrete_state(self, value: LeafStateComponent) -> LeafState:
    """Return a copy of this LeafState with `value` as the discrete state.

    When both the new value and the current discrete state are present, each
    leaf of `value` is reshaped to match the existing leaf (via
    ``self._reshape_like``) so the PyTree structure is preserved.
    """
    current = self.discrete_state
    if value is None or current is None:
        return dataclasses.replace(self, discrete_state=value)
    reshaped = tree_util.tree_map(self._reshape_like, value, current)
    return dataclasses.replace(self, discrete_state=reshaped)

with_mode(value)

Create a copy of this LeafState with the mode replaced.

Source code in collimator/framework/state.py
148
149
150
def with_mode(self, value: int) -> LeafState:
    """Return a copy of this LeafState whose `mode` is `value`; other fields are unchanged."""
    updated = dataclasses.replace(self, mode=value)
    return updated

LeafSystem dataclass

Bases: SystemBase

Basic building block for dynamical systems.

A LeafSystem is a minimal component of a system model in collimator, containing no subsystems. Inputs, outputs, state, parameters, updates, etc. can be added to the block using the various declare_* methods. The built-in blocks in collimator.library are all subclasses of LeafSystem, as are any custom blocks defined by the user.

Source code in collimator/framework/leaf_system.py
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
class LeafSystem(SystemBase, metaclass=InitializeParameterResolver):
    """Basic building block for dynamical systems.

    A LeafSystem is a minimal component of a system model in collimator, containing no
    subsystems.  Inputs, outputs, state, parameters, updates, etc. can be added to the
    block using the various `declare_*` methods.  The built-in blocks in
    collimator.library are all subclasses of LeafSystem, as are any custom blocks defined
    by the user."""

    # SystemBase is a dataclass, so we need to call __post_init__ explicitly
    def __post_init__(self):
        super().__post_init__()
        logger.debug(f"Initializing {self.name} [{self.system_id}]")

        # ---- Continuous state -------------------------------------------------
        # Template for the continuous state (shape, dtype, and default initial
        # value used when the context is built). Usually an array, but any PyTree
        # works, as long as the ODE returns a PyTree of the same structure.
        # None means "no continuous state".
        self._default_continuous_state: LeafStateComponent = None
        self._mass_matrix: Array = None
        self._continuous_state_output_port_idx: int = None
        # Callback for the continuous-state time derivatives; assigned by
        # `declare_continuous_state`.
        self.ode_callback: SystemCallback = None

        # ---- Discrete state ---------------------------------------------------
        # Template/default for the discrete state; any PyTree whose structure the
        # update functions reproduce. None means "no discrete state".
        self._default_discrete_state: LeafStateComponent = None

        # ---- Mode ("state machine" stage) ---------------------------------------
        # Integer mode component, if any. Transitions between modes are declared
        # through `declare_zero_crossing` (its `start_mode` / `end_mode` args).
        self._default_mode: int = None
        self._mode_output_port_idx: int = None

        # ---- Cache ------------------------------------------------------------
        # Templates for sample-and-hold outputs; a None entry means the value is
        # inferred from upstream during static analysis.
        self._default_cache: List[LeafStateComponent] = []

        # ---- Events -----------------------------------------------------------
        # start_mode -> [(end_mode, event), ...]; kept only for debugging, no
        # system logic reads it.
        self.transition_map: dict[int, List[Tuple[int, ZeroCrossingEvent]]] = {}
        # Fixed-rate updates, each with its own period/offset; created by
        # `declare_periodic_update`.
        self._state_update_events: List[DiscreteUpdateEvent] = []
        # Guard-triggered updates (guard, optional reset map, start/end mode);
        # created by `declare_zero_crossing`.
        self._zero_crossing_events: List[ZeroCrossingEvent] = []

    def initialize(self, **parameters):
        """Initialization hook invoked during context creation; a no-op by default.

        Subclasses may override it. An override's signature should list every
        declared parameter; arguments that are Parameter instances arrive already
        resolved.

        Do not call this directly. It is invoked implicitly after ``__init__``
        with the resolved parameters.
        """

    @property
    def has_feedthrough_side_effects(self) -> bool:
        # Leaf feedthrough outputs are assumed pure by default. Subclasses whose
        # outputs are computed through `io_callback` should override this to
        # return True (background in `SystemBase.has_feedthrough_side_effects`).
        side_effects = False
        return side_effects

    @property
    def has_ode_side_effects(self) -> bool:
        # A LeafSystem's ODE is assumed pure; Diagram systems implement their
        # own determination of this flag.
        side_effects = False
        return side_effects

    @property
    def has_continuous_state(self) -> bool:
        # A continuous state exists once a default template has been declared.
        template = self._default_continuous_state
        return template is not None

    @property
    def has_discrete_state(self) -> bool:
        # A discrete state exists once a default template has been declared.
        template = self._default_discrete_state
        return template is not None

    @property
    def has_zero_crossing_events(self) -> bool:
        # True as soon as at least one zero-crossing event has been registered.
        return bool(self._zero_crossing_events)

    #
    # Event handling
    #
    def wrap_callback(
        self, callback: Callable, collect_inputs: bool | list[int] = True
    ) -> Callable:
        """Adapt a block-level function to the context-level callback interface.

        The supplied function is written as
        `callback(time, state, *inputs, **params) -> result`.
        The returned wrapper has the signature `callback(context) -> result`, which
        is what the event-handling machinery expects. Declaration helpers such as
        `declare_periodic_update` use this internally. Users may also call it
        directly, e.g. to build a callback for `declare_output_port`.

        The context and state are immutable by contract. Even where Python cannot
        enforce that (say, a state component that is a list), the callback must not
        modify them in place; doing so can cause unexpected behavior or JAX tracer
        errors.

        Args:
            callback (Callable):
                Pure function with the block-level signature shown above.
            collect_inputs (bool | list[int]):
                True (default) evaluates every input port before each call; False
                evaluates none, which is useful when the result depends only on the
                state and reduces the amount of JIT-compiled work. A list of integer
                port indices evaluates just those ports.

        Returns:
            Callable:
                Wrapper with signature `callback(context) -> result`.
        """

        def _wrapped_callback(context: ContextBase) -> LeafStateComponent:
            # None asks `collect_inputs` for all ports; an empty list asks for none.
            port_indices = (
                collect_inputs
                if not isinstance(collect_inputs, bool)
                else (None if collect_inputs else [])
            )

            inputs = self.collect_inputs(context, port_indices)
            leaf_context: LeafContext = context[self.system_id]
            return callback(
                context.time,
                leaf_context.state,
                *inputs,
                **leaf_context.parameters,
            )

        return _wrapped_callback

    def _passthrough(self, context: ContextBase) -> LeafState:
        """No-op event callback: hand back this system's current state unchanged."""
        own_context = context[self.system_id]
        return own_context.state

    @property
    def state_update_events(self) -> FlatEventCollection:
        # Snapshot the periodic update events as an immutable collection.
        events = tuple(self._state_update_events)
        return FlatEventCollection(events)

    @property
    def zero_crossing_events(self) -> LeafEventCollection:
        # Every event starts out marked active. `determine_active_guards` then
        # narrows the set according to the current mode/stage of the system.
        collection = LeafEventCollection(tuple(self._zero_crossing_events))
        return collection.mark_all_active()

    # Inherits docstring from SystemBase
    def eval_zero_crossing_updates(
        self,
        context: ContextBase,
        events: LeafEventCollection,
    ) -> LeafState:
        # Apply this block's zero-crossing events one after another and return
        # the resulting ("plus") leaf state. Only the leaf state is returned: the
        # threaded context is deliberately dropped, so other blocks still see the
        # "minus" state when their own updates are processed.
        local_events = events[self.system_id]
        state = context[self.system_id].state

        logger.debug(f"Eval update events for {self.name}")
        logger.debug(f"local events: {local_events}")

        working_context = context
        for event in local_events:
            # Per the original design, `handle` evaluates conditionally on
            # event_data.active.
            state = event.handle(working_context)
            # Feed the updated state forward so the next event in this block
            # sees it.
            updated_leaf = working_context[self.system_id].with_state(state)
            working_context = working_context.with_subcontext(
                self.system_id, updated_leaf
            )

        return state

    # Inherits docstring from SystemBase
    def determine_active_guards(self, root_context: ContextBase) -> LeafEventCollection:
        # Mode of this system in the given root context.
        current_mode = root_context[self.system_id].mode

        def _activate_for_mode(event: ZeroCrossingEvent) -> ZeroCrossingEvent:
            # Events that are not tied to a mode transition stay as they are
            # (active).
            if event.active_mode is None:
                return event
            # Mode-bound events are active only while the system is in that mode.
            # `cond` keeps the choice traceable.
            return cond(
                current_mode == event.active_mode,
                lambda e: e.mark_active(),
                lambda e: e.mark_inactive(),
                event,
            )

        guards = tuple(map(_activate_for_mode, self.zero_crossing_events))
        zero_crossing_events = LeafEventCollection(guards)

        logger.debug(f"Zero-crossing events for {self.name}: {zero_crossing_events}")
        return zero_crossing_events

    @property
    def _flat_callbacks(self) -> List[OutputPort]:
        """All callbacks of this system; a leaf has no nested systems to flatten."""
        own_callbacks = self.callbacks
        return own_callbacks

    def declare_cache(
        self,
        callback: Callable,
        period: float | Parameter = None,
        offset: float | Parameter = 0.0,
        name: str = None,
        prerequisites_of_calc: List[DependencyTicket] = None,
        default_value: Array = None,
        requires_inputs: bool = True,
    ) -> int:
        """Declare a stored (cached) computation for the system.

        `callback` uses the block-level signature
            `callback(time, state, *inputs, **parameters) -> value`
        and is wrapped (through `wrap_callback`) into a context-level function
            `callback(context) -> state`
        that writes the result into this system's slot of `state.cache`.

        If `period` is given, the same wrapped function also drives a periodic
        update event that refreshes the cached value. Other calculations (e.g.
        sample-and-hold output ports) can then depend on that value.

        Args:
            callback (Callable):
                Function computing the value to be cached.
            period (float | Parameter, optional):
                Period of the refresh event. Defaults to None, meaning no event
                is created.
            offset (float | Parameter, optional):
                Offset of the refresh event. Defaults to 0.0; ignored when
                `period` is None.
            name (str, optional):
                Name of the resulting callback. Defaults to ``cache_<cache index>``.
            prerequisites_of_calc (List[DependencyTicket], optional):
                Dependency tickets for the computation. When None, this becomes
                ``[DependencyTicket.u]`` if `requires_inputs` is True and
                ``[DependencyTicket.nothing]`` otherwise.
            default_value (Array, optional):
                Default/template value of the result, if known. Defaults to None.
            requires_inputs (bool, optional):
                Whether input ports are evaluated before `callback` runs. Setting
                this to False where possible reduces compile time. Defaults to True.

        Returns:
            int: Index of the new callback in `system.callbacks`. The cache index
                can be recovered from `system.callbacks[callback_index].cache_index`.
        """
        # Both indices are taken before anything is appended, so they refer to
        # the entries created below.
        callback_index = len(self.callbacks)
        cache_index = len(self._default_cache)
        self._default_cache.append(default_value)

        # Depend on the inputs only when the callback reads them; this helps
        # avoid unnecessary flagging of algebraic loops.
        if prerequisites_of_calc is None:
            ticket = DependencyTicket.u if requires_inputs else DependencyTicket.nothing
            prerequisites_of_calc = [ticket]

        def _store_in_cache(
            time: Scalar, state: LeafState, *inputs, **parameters
        ) -> LeafState:
            value = callback(time, state, *inputs, **parameters)
            return state.with_cached_value(cache_index, value)

        context_callback = self.wrap_callback(
            _store_in_cache, collect_inputs=requires_inputs
        )

        refresh_event = None
        if period is not None:
            # Periodic event that recomputes and stores the cached value.
            refresh_event = DiscreteUpdateEvent(
                system_id=self.system_id,
                event_data=PeriodicEventData(
                    period=period, offset=offset, active=False
                ),
                name=f"{self.name}:cache_update_{cache_index}_",
                callback=context_callback,
                passthrough=self._passthrough,
            )

        self.callbacks.append(
            SystemCallback(
                callback=context_callback,
                system=self,
                callback_index=callback_index,
                name=f"cache_{cache_index}" if name is None else name,
                prerequisites_of_calc=prerequisites_of_calc,
                event=refresh_event,
                default_value=default_value,
                cache_index=cache_index,
            )
        )

        return callback_index

    # NOTE: we can only declare one continuous state per system because each
    # call will overwrite self._default_continuous_state
    def declare_continuous_state(
        self,
        shape: ShapeLike = None,
        default_value: Array = None,
        dtype: DTypeLike = None,
        ode: Callable = None,
        mass_matrix: Array = None,
        as_array: bool = True,
        requires_inputs: bool = True,
        prerequisites_of_calc: List[DependencyTicket] = None,
    ):
        """Declare a continuous state component for the system.

        A `SystemCallback` with no callable attached yet is created for the ODE,
        appended to `self.callbacks`, and stored as `self.ode_callback`. When either
        `shape` or `default_value` is supplied, the state is configured immediately
        by forwarding every argument to `configure_continuous_state`; otherwise the
        remaining arguments are ignored here and configuration must happen later.

        Returns:
            int: The index of the ODE callback in `self.callbacks`.
        """
        self.ode_callback = SystemCallback(
            callback=None,
            system=self,
            callback_index=len(self.callbacks),
            name=f"{self.name}_ode",
            prerequisites_of_calc=prerequisites_of_calc,
        )
        self.callbacks.append(self.ode_callback)
        ode_index = len(self.callbacks) - 1

        # FIXME: this is to preserve some backward compatibility while we decouple
        # declaration from configuration. Declaration should not have to call
        # configuration.
        nothing_to_configure = default_value is None and shape is None
        if not nothing_to_configure:
            self.configure_continuous_state(
                ode_index,
                shape=shape,
                default_value=default_value,
                dtype=dtype,
                ode=ode,
                mass_matrix=mass_matrix,
                as_array=as_array,
                requires_inputs=requires_inputs,
                prerequisites_of_calc=prerequisites_of_calc,
            )

        return ode_index

    def configure_continuous_state(
        self,
        callback_idx: int,
        shape: ShapeLike = None,
        default_value: Array = None,
        dtype: DTypeLike = None,
        ode: Callable = None,
        mass_matrix: Array = None,
        as_array: bool = True,
        requires_inputs: bool = True,
        prerequisites_of_calc: List[DependencyTicket] = None,
    ):
        """Configure a continuous state component for the system.

        The `ode` callback computes the time derivative of the continuous state based on the
        current time, state, and any additional inputs. If `ode` is not provided, a default
        callback returning zeros shaped like the continuous state is used. If provided, the
        `ode` callback should have the signature `ode(time, state, *inputs, **params) -> xcdot`.

        Side effects: sets the default continuous state, keeps the continuous-state
        output port (if declared) in sync with it, attaches the wrapped ODE to
        `self.ode_callback`, overrides `self.eval_time_derivatives`, and stores the
        (possibly None) mass matrix.

        Args:
            callback_idx (int):
                The index of the ODE callback in the system's callback list. Not read
                by this method: the callback configured is `self.ode_callback`.
            shape (ShapeLike, optional):
                The shape of the continuous state vector. Defaults to None.
            default_value (Array, optional):
                The initial value of the continuous state vector. Defaults to None.
            dtype (DTypeLike, optional):
                The data type of the continuous state vector. Defaults to None.
            ode (Callable, optional):
                The callback for computing the time derivative of the continuous state.
                Should have the signature:
                    `ode(time, state, *inputs, **parameters) -> xcdot`.
                Defaults to None.
            mass_matrix (array-like, optional):
                The mass matrix for the continuous state. Defaults to None. If
                provided, it may be any array-like (e.g. nested lists) and must be
                either an (n, n) matrix or a length-n vector (taken as the diagonal),
                where n is the size of the continuous state.  Using a mass matrix
                different from the identity in any LeafSystem will require the use of
                a compatible continuous-time solver (currently only BDF is supported).
                Currently mass matrices are also only supported for scalar- or
                vector-valued continuous states (i.e. no matrices or other
                PyTree-structured states).  An identity mass matrix is discarded
                (stored as None) so explicit solvers can still be used.
            as_array (bool, optional):
                If True, treat the default_value as an array-like (cast if necessary).
                Otherwise, it will be stored as the default state without modification.
            requires_inputs (bool, optional):
                If True (and `prerequisites_of_calc` is None), the inputs are included
                in the default prerequisites of the ODE computation.
            prerequisites_of_calc (List[DependencyTicket], optional):
                The dependency tickets for the ODE computation. Defaults to None, in
                which case the assumption is a dependency on either (time, continuous
                state) if `requires_inputs` is False, otherwise (time, continuous state,
                inputs).

        Raises:
            AssertionError:
                If `ode` is None while `as_array` is False, or if a mass matrix is
                given for a non-array, non-scalar/vector state, or with a shape that
                is inconsistent with the continuous state. `utils.make_array` may
                also reject the call when neither shape nor default_value is given
                (behavior of that helper, not checked here).

        Notes:
            (1) Only one of `shape` and `default_value` should be provided. If `default_value`
            is provided, it will be used as the initial value of the continuous state. If
            `shape` is provided, the initial value will be a zero vector of the given shape
            and specified dtype.
        """

        if prerequisites_of_calc is None:
            prerequisites_of_calc = [DependencyTicket.time, DependencyTicket.xc]
            if requires_inputs:
                prerequisites_of_calc.append(DependencyTicket.u)

        if as_array:
            default_value = utils.make_array(default_value, dtype=dtype, shape=shape)

        logger.debug(f"In block {self.name} [{self.system_id}]: {default_value=}")

        # Tree-map the default value to ensure that it is an array-like with the
        # correct shape and dtype. This is necessary because the default value
        # may be a list, tuple, or other PyTree-structured object.
        default_value = tree_util.tree_map(cnp.asarray, default_value)

        self._default_continuous_state = default_value
        if self._continuous_state_output_port_idx is not None:
            # Keep the continuous-state output port's default in sync with the state.
            port = self.output_ports[self._continuous_state_output_port_idx]
            port.default_value = default_value
            self._default_cache[port.cache_index] = default_value

        if ode is None:
            # If no ODE is specified, return a zero vector of the same size as the
            # continuous state. This will break if the continuous state is
            # a named tuple, in which case a custom ODE must be provided.
            assert as_array, "Must provide custom ODE for non-array continuous state"

            def ode(time, state, *inputs, **parameters):
                return cnp.zeros_like(default_value)

        # Wrap the ode function to accept a context and return the time derivatives.
        ode = self.wrap_callback(ode)

        # Declare the time derivative function as a system callback so that its
        # dependencies can be tracked in the system dependency graph
        self.ode_callback._callback = ode
        self.ode_callback.prerequisites_of_calc = prerequisites_of_calc

        # Override the default `eval_time_derivatives` to use the wrapped ODE function
        self.eval_time_derivatives = self.ode_callback.eval

        if mass_matrix is not None:
            # Check that the state is a vector or scalar
            assert as_array, "Mass matrix only supported for array-valued states"
            assert (
                len(default_value.shape) <= 1
            ), "Mass matrix only supported for scalar or vector continuous states"

            # Convert before inspecting `.shape` so that plain array-likes (e.g.
            # nested lists) are accepted instead of failing with AttributeError.
            mass_matrix = np.asarray(mass_matrix)
            n = default_value.size
            assert mass_matrix.shape in ((n, n), (n,)), (
                "Mass matrix must be either a square matrix or vector of the same "
                f"size as the continuous state, but got {mass_matrix.shape} for "
                f"continuous state of shape {default_value.shape}."
            )
            if mass_matrix.ndim == 1:
                # A vector is interpreted as the diagonal of the mass matrix.
                mass_matrix = np.diag(mass_matrix)

            # If we end up with an identity matrix, we can just ignore the mass
            # matrix and use the default mass matrix (which is None).  This will
            # allow us to continue using explicit ODE solvers.
            if np.allclose(mass_matrix, np.eye(n)):
                mass_matrix = None

        self._mass_matrix = mass_matrix

    @property
    def mass_matrix(self) -> Array:
        # When this is called, an array return value is expected, so we can safely
        # return the mass matrix as an array, even if the internal value is None.
        if self._default_continuous_state is None:
            return None

        if self._mass_matrix is not None:
            return self._mass_matrix

        # Currently only scalar- or vector-valued continuous states are supported,
        # so check that the continuous state (or all tree leaves if tree-structured)
        # is a scalar or vector, and return corresponding identity matrices.
        xc_leaves = tree_util.tree_leaves(self._default_continuous_state)
        if not all(len(xc.shape) <= 1 for xc in xc_leaves):
            raise ValueError(
                "Mass matrix DAEs are only supported when the continuous state is "
                f"scalar- or vector-valued.  System {self.name} has non-vector "
                "continuous state with default value "
                f"{self._default_continuous_state}."
            )

        # Now we are guaranteed that the continuous state is a scalar or vector, so
        # we can return the corresponding (tree-structured) identity matrix.
        return jax.tree.map(lambda x: np.eye(x.size), self._default_continuous_state)

    @property
    def has_mass_matrix(self) -> bool:
        """Whether the system has a nontrivial (non-identity) mass matrix.

        Continuous-state configuration stores None in place of an identity mass
        matrix, so any stored (non-None) matrix is nontrivial by construction.
        """
        stored_matrix = self._mass_matrix
        return stored_matrix is not None

    # FIXME: this doesn't support multiple discrete states as the docstring
    # suggests.
    def declare_discrete_state(
        self,
        shape: ShapeLike = None,
        default_value: Array | Parameter = None,
        dtype: DTypeLike = None,
        as_array: bool = True,
    ):
        """Declare the discrete state component for the system.

        The discrete state is a component of the system's state that can be updated
        at specific events, such as zero-crossings or periodic updates.

        Only a single discrete state is stored: each call replaces the default value
        set by any previous call. To hold several quantities, use a PyTree-structured
        default value (with `as_array=False`).

        The declared discrete state is initialized with either the provided default
        value or zeros of the correct shape and dtype.

        Args:
            shape (ShapeLike, optional):
                The shape of the discrete state. Defaults to None.
            default_value (Array, optional):
                The initial value of the discrete state. Defaults to None.
            dtype (DTypeLike, optional):
                The data type of the discrete state. Defaults to None.
            as_array (bool, optional):
                If True, treat the default_value as an array-like (cast if necessary).
                Otherwise, it will be stored as the default state without modification
                (apart from converting each PyTree leaf with `cnp.asarray`).

        Raises:
            AssertionError:
                If as_array is True and neither shape nor default_value is provided.

        Notes:
            (1) Only one of `shape` and `default_value` should be provided. If
            `default_value` is provided, it will be used as the initial value of the
            discrete state. If `shape` is provided, the initial value will be a
            zero vector of the given shape and specified dtype.

            (2) Use `declare_periodic_update` to declare an update event that
            modifies the discrete state at a recurring interval.
        """
        if as_array:
            default_value = utils.make_array(default_value, dtype=dtype, shape=shape)

        # Tree-map the default value to ensure that it is an array-like with the
        # correct shape and dtype. This is necessary because the default value
        # may be a list, tuple, or other PyTree-structured object.
        default_value = tree_util.tree_map(cnp.asarray, default_value)

        self._default_discrete_state = default_value

    def configure_discrete_state_default_value(
        self, default_value: Array, as_array: bool = True
    ):
        """Replace the default value of the previously declared discrete state.

        When `as_array` is True, `default_value` is cast with `utils.make_array`
        using the dtype and shape of the current default discrete state. Every
        leaf is then converted with `cnp.asarray`, and the result is checked with
        `_check_values_compatible` against the current default before replacing it.

        Args:
            default_value (Array):
                The new default value of the discrete state.
            as_array (bool, optional):
                Whether to cast `default_value` to match the current default.
                Defaults to True.
        """
        previous = self._default_discrete_state
        if as_array:
            default_value = utils.make_array(
                default_value, dtype=previous.dtype, shape=previous.shape
            )

        # The value may be a list, tuple or other PyTree; normalize every leaf.
        normalized = tree_util.tree_map(cnp.asarray, default_value)

        _check_values_compatible(previous, normalized)

        self._default_discrete_state = normalized

    #
    # I/O declaration
    #
    def declare_output_port(
        self,
        callback: Callable = None,
        period: float = None,
        offset: float = 0.0,
        name: str = None,
        prerequisites_of_calc: List[DependencyTicket] = None,
        default_value: Array = None,
        requires_inputs: bool | list[int] = True,
    ) -> int:
        """Declare an output port in the LeafSystem.

        The callback uses the block-level signature
            `callback(time, state, *inputs, **parameters) -> value`
        and is adapted (via `configure_output_port`) to the signature expected by
        SystemBase.declare_output_port:
            `callback(context) -> value`

        Args:
            callback (Callable):
                The callback function defining the output port.
            period (float, optional):
                If not None, the port acts as a "sample-and-hold": the callback
                defines a periodic update event that refreshes a cached value, and the
                port returns that cached value. Typically this should match the update
                period of some associated update event in the system. Defaults to None.
            offset (float, optional):
                The offset of the periodic update event. Defaults to 0.0.  Ignored
                unless `period` is not None.
            name (str, optional):
                The name of the output port. Defaults to None.
            default_value (Array, optional):
                The default value of the output port, if known. Defaults to None.
            requires_inputs (bool | list[int], optional):
                If True, the callback will eval input ports to gather input values.
                This adds a bit to compile time, so setting to False where possible
                is recommended. Can also be specified as a list of integer port indices.
                Defaults to True (collect all inputs).
            prerequisites_of_calc (List[DependencyTicket], optional):
                The dependency tickets for the output port computation.  Defaults to
                None, in which case the assumption is a dependency on either (nothing)
                if `requires_inputs` is False otherwise (inputs).

        Returns:
            int: The index of the declared output port.
        """
        if default_value is not None:
            default_value = cnp.array(default_value)

        sample_and_hold = period is not None
        if sample_and_hold:
            # The periodic update event will write the port value into state.cache;
            # reserve that cache slot now, seeded with the default value.
            cache_index = len(self._default_cache)
            self._default_cache.append(default_value)
        else:
            cache_index = None

        port_index = super().declare_output_port(
            callback, name=name, cache_index=cache_index
        )

        self.configure_output_port(
            port_index,
            callback,
            period=period,
            offset=offset,
            prerequisites_of_calc=prerequisites_of_calc,
            default_value=default_value,
            requires_inputs=requires_inputs,
        )
        return port_index

    def configure_output_port(
        self,
        port_index: int,
        callback: Callable,
        period: float = None,
        offset: float = 0.0,
        prerequisites_of_calc: List[DependencyTicket] = None,
        default_value: Array = None,
        requires_inputs: bool | list[int] = True,
    ):
        """Configure an output port in the LeafSystem.

        See `declare_output_port` for a description of the arguments.

        If `period` is None, the port value is computed directly by `callback`,
        wrapped to the `callback(context)` signature. Otherwise the port is
        "sample-and-hold": a periodic `DiscreteUpdateEvent` evaluates `callback` and
        stores the result in `state.cache[cache_index]`, and the port returns that
        cached value. The port's existing cache slot is reused if it has one;
        otherwise a new slot (seeded with `default_value`) is appended.

        Args:
            port_index (int):
                The index of the output port to configure.
            requires_inputs (bool | list[int], optional):
                Forwarded to `wrap_callback` as `collect_inputs`. Accepts the same
                values as `declare_output_port` (a bool or a list of input port
                indices), which forwards its argument here unchanged.

        Returns:
            None
        """
        if default_value is not None:
            default_value = cnp.array(default_value)

        # To help avoid unnecessary flagging of algebraic loops, trim the inputs as a
        # default prereq if the output callback doesn't use them
        if prerequisites_of_calc is None:
            if requires_inputs:
                prerequisites_of_calc = [DependencyTicket.u]
            else:
                prerequisites_of_calc = [DependencyTicket.nothing]

        if period is None:
            event = None
            _output_callback = self.wrap_callback(
                callback, collect_inputs=requires_inputs
            )
            cache_index = None

        else:
            # The output port will be of "sample-and-hold" type, so we have to declare a
            # periodic event to update the value.  The callback will be used to define the
            # update event, and the output callback will simply return the stored value.

            # This is the index that this port value will have in state.cache
            cache_index = self.output_ports[port_index].cache_index
            if cache_index is None:
                cache_index = len(self._default_cache)
                self._default_cache.append(default_value)

            def _output_callback(context: ContextBase) -> Array:
                state = context[self.system_id].state
                return state.cache[cache_index]

            def _update_callback(
                time: Scalar, state: LeafState, *inputs, **parameters
            ) -> LeafState:
                output = callback(time, state, *inputs, **parameters)
                return state.with_cached_value(cache_index, output)

            _update_callback = self.wrap_callback(
                _update_callback, collect_inputs=requires_inputs
            )

            # Create the associated update event
            event = DiscreteUpdateEvent(
                system_id=self.system_id,
                event_data=PeriodicEventData(
                    period=period, offset=offset, active=False
                ),
                name=f"{self.name}:output_{cache_index}",
                callback=_update_callback,
                passthrough=self._passthrough,
            )

            # Note that in this case the "prerequisites of calc" will correspond to the
            # prerequisites of the update event, not the literal output callback itself.
            # However, these can be used to determine dependencies for the update event
            # via the output port.

        super().configure_output_port(
            port_index,
            _output_callback,
            prerequisites_of_calc=prerequisites_of_calc,
            default_value=default_value,
            event=event,
            cache_index=cache_index,
        )

    def configure_continuous_state_default_value(
        self, callback_idx: int, default_value: Array, as_array: bool = True
    ):
        """Replace the default value of the already-configured continuous state.

        When `as_array` is True, `default_value` is cast with `utils.make_array`
        using the dtype and shape of the current default continuous state. Every
        leaf is then converted with `cnp.asarray`, and the result is checked with
        `_check_values_compatible` against the current default before being stored.
        If a continuous-state output port exists, its default value and cache slot
        are updated too.

        Args:
            callback_idx (int):
                Index of the ODE callback. Not read by this method.
            default_value (Array):
                The new default value of the continuous state.
            as_array (bool, optional):
                Whether to cast `default_value` to match the current default.
                Defaults to True.
        """
        previous = self._default_continuous_state
        if as_array:
            default_value = utils.make_array(
                default_value, dtype=previous.dtype, shape=previous.shape
            )

        # The value may be a list, tuple or other PyTree; normalize every leaf.
        normalized = tree_util.tree_map(cnp.asarray, default_value)

        _check_values_compatible(previous, normalized)

        self._default_continuous_state = normalized

        state_port_idx = self._continuous_state_output_port_idx
        if state_port_idx is None:
            return
        state_port = self.output_ports[state_port_idx]
        state_port.default_value = normalized
        self._default_cache[state_port.cache_index] = normalized

    def configure_output_port_default_value(
        self,
        port_index: int,
        default_value: Array,
    ):
        """Replace the cached default value of a sample-and-hold output port.

        Only ports with an associated update event use a cached default. For a port
        without one, a warning is logged and nothing changes. Otherwise the new value
        is converted with `cnp.array`, checked with `_check_values_compatible`
        against the current cached default, and stored in its place.

        Raises:
            ValueError: If the port has an event but no cache index.
        """
        target = self.output_ports[port_index]
        if target.event is not None:
            new_value = cnp.array(default_value)
            slot = target.cache_index

            if slot is None:
                raise ValueError(
                    "Output port does not have a cache index, so default value cannot be set"
                )

            _check_values_compatible(self._default_cache[slot], new_value)
            self._default_cache[slot] = new_value
            return

        logger.warning(
            "period is None so default_value is not used for port %d in block %s",
            port_index,
            self.name,
        )

    def declare_continuous_state_output(
        self,
        name: str = None,
    ) -> int:
        """Declare an output port that returns the system's full continuous state.

        The port depends only on the continuous state (no inputs) and uses the
        current default continuous state as its default value. At most one such
        port may be declared per system.

        Args:
            name (str, optional):
                The name of the output port. Defaults to None (autogenerate name).

        Returns:
            int: The index of the new output port.

        Raises:
            ValueError: If the continuous state output port was already declared.
        """
        already_declared = self._continuous_state_output_port_idx is not None
        if already_declared:
            raise ValueError("Continuous state output port already declared")

        def _callback(time: Scalar, state: LeafState, *inputs, **parameters):
            return state.continuous_state

        port_options = dict(
            name=name,
            prerequisites_of_calc=[DependencyTicket.xc],
            default_value=self._default_continuous_state,
            requires_inputs=False,
        )
        self._continuous_state_output_port_idx = self.declare_output_port(
            _callback, **port_options
        )
        return self._continuous_state_output_port_idx

    def declare_mode_output(self, name: str = None) -> int:
        """Declare an output port that returns the system's discrete mode.

        The port reads the "mode" (or "stage") component of the state, depends
        only on the mode (no inputs), and uses the current default mode as its
        default value.

        Args:
            name (str, optional):
                The name of the output port. Defaults to None.

        Returns:
            int:
                The index of the declared mode output port.
        """

        def _callback(time: Scalar, state: LeafState, *inputs, **parameters):
            return state.mode

        port_options = dict(
            name=name,
            prerequisites_of_calc=[DependencyTicket.mode],
            default_value=self._default_mode,
            requires_inputs=False,
        )
        self._mode_output_port_idx = self.declare_output_port(
            _callback, **port_options
        )
        return self._mode_output_port_idx

    #
    # Event declaration
    #
    def declare_periodic_update(
        self,
        callback: Callable = None,
        period: Scalar | Parameter = None,
        offset: Scalar | Parameter = None,
        enable_tracing: bool = None,
    ):
        self._state_update_events.append(None)
        event_idx = len(self._state_update_events) - 1

        # FIXME: this is to preserve some backward compatibility while we decouple
        # declaration from configuration. Declaration should not have to call
        # configuration.
        if callback is not None:
            self.configure_periodic_update(
                event_idx,
                callback,
                period,
                offset,
                enable_tracing=enable_tracing,
            )
        return event_idx

    def configure_periodic_update(
        self,
        event_index: int,
        callback: Callable,
        period: Scalar | Parameter,
        offset: Scalar | Parameter,
        enable_tracing: bool = None,
    ):
        """Fill a reserved slot with a periodic discrete-state update event.

        The event fires at regular intervals defined by `period` and `offset`. The
        callback has the signature `callback(time, state, *inputs, **params) ->
        xd_plus` and computes the "plus" value of the discrete state from the
        "minus" values of all state components and inputs; the result replaces the
        system's discrete state.

        Args:
            event_index (int):
                The index of the event slot to configure.
            callback (Callable):
                The callback function defining the update.
            period (Scalar):
                The period at which the update event occurs.
            offset (Scalar):
                The offset at which the first occurrence of the event is triggered.
            enable_tracing (bool, optional):
                Whether tracing is enabled for this event. None (the default) is
                treated as True.
        """
        _wrapped_callback = self.wrap_callback(callback)

        def _callback(context: ContextBase) -> LeafState:
            next_discrete = _wrapped_callback(context)
            own_state = context[self.system_id].state
            return own_state.with_discrete_state(next_discrete)

        tracing = True if enable_tracing is None else enable_tracing
        timing = PeriodicEventData(period=period, offset=offset, active=False)

        self._state_update_events[event_index] = DiscreteUpdateEvent(
            system_id=self.system_id,
            name=f"{self.name}:periodic_update",
            event_data=timing,
            callback=_callback,
            passthrough=self._passthrough,
            enable_tracing=tracing,
        )

    def declare_default_mode(self, mode: int):
        """Set the initial discrete mode of the system.

        This must be called before declaring zero-crossing events that use
        `start_mode`/`end_mode` (see `declare_zero_crossing`).
        """
        initial_mode = mode
        self._default_mode = initial_mode

    def configure_default_mode(self, mode: int):
        """Set the default mode and propagate it to the mode output port, if any.

        Args:
            mode (int): The new default (initial) mode of the system.
        """
        self._default_mode = mode
        # Compare against None explicitly: port index 0 is a valid index, and a
        # truthiness check would wrongly skip it.
        if self._mode_output_port_idx is not None:
            self.configure_output_port_default_value(self._mode_output_port_idx, mode)

    def declare_zero_crossing(
        self,
        guard: Callable,
        reset_map: Callable = None,
        start_mode: int = None,
        end_mode: int = None,
        direction: str = "crosses_zero",
        terminal: bool = False,
        name: str = None,
        enable_tracing: bool = None,
    ):
        """Declare an event triggered by a zero-crossing of a guard function.

        Optionally, the system can also transition between discrete modes.
        If `start_mode` and `end_mode` are specified, the system will transition
        from `start_mode` to `end_mode` when the event is triggered according to `guard`.
        This event will be active conditionally on `state.mode == start_mode` and when
        triggered will result in applying the reset map. In addition, the mode will be
        updated to `end_mode`.

        If `start_mode` and `end_mode` are not specified, the event will always be active
        and will not result in a mode transition.

        The guard function should have the signature:
            `guard(time, state, *inputs, **parameters) -> float`

        and the reset map should have the signature of an unrestricted update:
            `reset_map(time, state, *inputs, **parameters) -> state`

        Args:
            guard (Callable):
                The guard function which triggers updates on zero crossing.
            reset_map (Callable, optional):
                The reset map which is applied when the event is triggered. If None
                (default), no reset is applied.
            start_mode (int, optional):
                The mode or stage of the system in which the guard will be
                actively monitored. If None (default), the event will always be
                active.
            end_mode (int, optional):
                The mode or stage of the system to which the system will transition
                when the event is triggered. Must be given together with
                `start_mode` (both as ints), though it can be the same as
                start_mode.
            direction (str, optional):
                The direction of the zero crossing. Options are "crosses_zero"
                (default), "positive_then_non_positive", "negative_then_non_negative",
                and "edge_detection".  All except edge detection operate on continuous
                signals; edge detection operates on boolean signals and looks for a
                jump from False to True or vice versa.
            terminal (bool, optional):
                If True, the event will halt simulation if and when the zero-crossing
                occurs. If this event is triggered the reset map will still be applied
                as usual prior to termination. Defaults to False.
            name (str, optional):
                The name of the event. Defaults to None.
            enable_tracing (bool, optional):
                Whether tracing is enabled for this event. None (the default) is
                treated as True.

        Raises:
            AssertionError:
                If `start_mode` or `end_mode` is given but the system has no default
                mode, or if they are not both of type `int`.

        Notes:
            By default the system state does not have a "mode" component, so in
            order to declare "state transitions" with non-null start and end modes,
            the user must first call `declare_default_mode` to set the default mode
            to be some integer (initial condition for the system).
        """

        logger.debug(
            f"Declaring transition for {self.name} with guard {guard} and reset map {reset_map}"
        )

        if enable_tracing is None:
            enable_tracing = True

        if start_mode is not None or end_mode is not None:
            assert (
                self._default_mode is not None
            ), "System has no mode: call `declare_default_mode` before transitions."
            assert isinstance(start_mode, int) and isinstance(end_mode, int)

        # Wrap the reset map with a mode update if necessary
        def _reset_and_update_mode(
            time: Scalar, state: LeafState, *inputs, **parameters
        ) -> LeafState:
            if reset_map is not None:
                state = reset_map(time, state, *inputs, **parameters)

            # If the start and end modes are declared, update the mode. Log only
            # here so the "Updating mode" message is not emitted for events that
            # never change the mode.
            if start_mode is not None:
                logger.debug(f"Updating mode from {state.mode} to {end_mode}")
                state = state.with_mode(end_mode)

            return state

        _wrapped_guard = self.wrap_callback(guard)
        _wrapped_reset = _wrap_reset_map(
            self, _reset_and_update_mode, _wrapped_guard, terminal
        )

        event = ZeroCrossingEvent(
            system_id=self.system_id,
            guard=_wrapped_guard,
            reset_map=_wrapped_reset,
            passthrough=self._passthrough,
            direction=direction,
            is_terminal=terminal,
            name=name,
            event_data=ZeroCrossingEventData(active=True, triggered=False),
            enable_tracing=enable_tracing,
            active_mode=start_mode,
        )

        event_index = len(self._zero_crossing_events)
        self._zero_crossing_events.append(event)

        # Record the transition in the transition map (for debugging or analysis)
        if start_mode is not None:
            if start_mode not in self.transition_map:
                self.transition_map[start_mode] = []
            self.transition_map[start_mode].append((event_index, event))

    #
    # Initialization
    #
    @property
    def context_factory(self) -> LeafContextFactory:
        """A new `LeafContextFactory` bound to this system (built on every access)."""
        factory = LeafContextFactory(self)
        return factory

    @property
    def dependency_graph_factory(self) -> LeafDependencyGraphFactory:
        """A new `LeafDependencyGraphFactory` bound to this system (built on every access)."""
        factory = LeafDependencyGraphFactory(self)
        return factory

    def create_state(self) -> LeafState:
        """Build the default `LeafState` for this system.

        This is a hook used during context creation. Users normally do not call it
        directly; `system.create_context()` is all that is needed for
        initialization. The defaults are first refreshed with
        `reset_default_values`, using the current dynamic parameters. The state is
        then assembled from the default continuous state, discrete state, mode and
        a tuple snapshot of the default cache.
        """
        # Refresh the defaults before reading them below.
        self.reset_default_values(**self.dynamic_parameters)

        name = self.name
        xc = self._default_continuous_state
        xd = self._default_discrete_state
        mode = self._default_mode
        cache_snapshot = tuple(self._default_cache)

        return LeafState(
            name=name,
            continuous_state=xc,
            discrete_state=xd,
            mode=mode,
            cache=cache_snapshot,
        )

    def initialize_static_data(self, context: ContextBase):
        """Infer missing default values for cached computations.

        The loop visits every callback that owns a slot in the state cache
        (`cache_index` is not None) and whose entry in `self._default_cache` is
        still None. Each such callback is evaluated once against `context` to get
        a value of the right data type. That value is stored as the new default
        and written into the context that is returned.

        A callback whose evaluation raises `UpstreamEvalError` is skipped: its
        default stays None and a debug message is logged. Only that evaluation is
        guarded; other errors propagate.

        Args:
            context: The context used to evaluate the callbacks.

        Returns:
            The context, updated with any newly inferred cached values.
        """
        cached_callbacks: list[SystemCallback] = [
            cb for cb in self.callbacks if cb.cache_index is not None
        ]

        for callback in cached_callbacks:
            i = callback.cache_index
            if self._default_cache[i] is not None:
                # A default is already known (declared, or inferred earlier).
                continue

            if isinstance(callback, OutputPort):
                # Evaluate the port's update _event_ rather than the output-port
                # return function. The event returns a value of the right data
                # type for the output port, provided the port is connected.
                _eval = callback.event.callback
            else:
                # For any other cached callback, evaluating the callback itself
                # returns the correct data type.
                _eval = callback.eval

            # Only the evaluation is guarded: it is the step that may need
            # upstream values that are not available yet.
            try:
                state: LeafState = _eval(context)
            except UpstreamEvalError:
                logger.debug(
                    "%s.initialize_static_data: UpstreamEvalError. "
                    "Continuing without default value initialization.",
                    self.name,
                )
                continue

            y = state.cache[i]
            self._default_cache[i] = y
            local_context = context[self.system_id].with_cached_value(i, y)
            context = context.with_subcontext(self.system_id, local_context)

        return context

    def _create_dependency_cache(self) -> dict[int, CallbackTracer]:
        """Map each callback's `callback_index` to a fresh `CallbackTracer`.

        The tracers carry only the callback's dependency ticket. They are used to
        trace local input -> output dependencies and store no computed values.
        """
        return {
            cb.callback_index: CallbackTracer(ticket=cb.ticket)
            for cb in self.callbacks
        }

    # Inherits docstring from SystemBase.get_feedthrough
    def get_feedthrough(self) -> List[Tuple[int, int]]:
        # NOTE: This implementation is basically a direct port of the Drake algorithm
        #
        # No docstring on purpose: it is inherited from SystemBase.get_feedthrough.
        #
        # Returns (input_port.index, output_port.index) pairs where a change in the
        # input invalidates the output (direct feedthrough). The computed list is
        # memoized in `self.feedthrough_pairs`.
        #
        # Raises ValueError if `self.dependency_graph` has not been created yet.

        if self.dependency_graph is None:
            raise ValueError("Must create dependency graph first.")

        # If we already did this or it was set manually, return the stored value
        if self.feedthrough_pairs is not None:
            return self.feedthrough_pairs

        feedthrough = []  # Confirmed feedthrough pairs (input, output)

        # First collect all possible feedthrough pairs
        # NOTE(review): later code indexes self.input_ports / self.output_ports by
        # `port.index`, which assumes each index matches the port's list
        # position. TODO confirm.
        unknown: Set[Tuple[int, int]] = set()
        for iport in self.input_ports:
            for oport in self.output_ports:
                unknown.add((iport.index, oport.index))

        # No input ports or no output ports: nothing can feed through. This early
        # exit returns a fresh empty list and does not memoize it in
        # `self.feedthrough_pairs`.
        if len(unknown) == 0:
            return feedthrough

        # Create a local context and "cache".  The cache here just contains CallbackTracer
        # objects that can be used to trace dependencies through the system, but
        # otherwise don't store any actual values.  This is different from any "cached"
        # computations that might be stored in the state for reuse by multiple ports or
        # downstream calculations within the system.
        #
        # This cache will only contain local sources - this is fine since we're just
        # testing local input -> output paths for this system.
        cache = self._create_dependency_cache()

        # Iterate over a snapshot, because resolved pairs are removed from
        # `unknown` inside the loop.
        original_unknown = unknown.copy()
        for pair in original_unknown:
            u, v = pair
            output_port = self.output_ports[v]
            input_port = self.input_ports[u]

            # If output prerequisites are unspecified, this tells us nothing.
            # The pair stays in `unknown` and is counted as feedthrough at the end.
            if DependencyTicket.all_sources in output_port.prerequisites_of_calc:
                continue

            # Determine feedthrough dependency via cache invalidation: start with
            # the output's tracer marked up to date...
            cache = _mark_up_to_date(cache, output_port.callback_index)

            # Notify subscribers of a value change in the input, invalidating all
            # downstream cache values
            input_tracker = self.dependency_graph[input_port.ticket]
            cache = input_tracker.notify_subscribers(
                cache, self.dependency_graph, local_only=True
            )

            # If the output cache is now out of date, this is a feedthrough path
            if cache[output_port.callback_index].is_out_of_date:
                feedthrough.append(pair)

            # Regardless of the result of the caching, the pair is no longer unknown
            unknown.remove(pair)

            # Reset the output cache to out-of-date in case other inputs also
            # feed through to this output.
            cache = _mark_out_of_date(cache, output_port.callback_index)

        # Logged before the conservative additions below, so this message lists
        # only the pairs confirmed by tracing.
        logger.debug(f"{self.name} feedthrough pairs: {feedthrough}")

        # Conservatively assume everything still unknown is feedthrough
        for pair in unknown:
            feedthrough.append(pair)

        self.feedthrough_pairs = feedthrough
        return self.feedthrough_pairs

    def reset_default_values(self, **dynamic_parameters):
        """Hook to recompute default values from the dynamic parameters.

        It covers the defaults for continuous/discrete states, ports and the mode.
        `create_state()` calls it. It is also used to reset states in ensemble
        simulations and optimization through the context method
        `with_new_state()`. This base version does nothing; subclasses may
        override it.

        Only values may change here. Dtypes and shapes are fixed after
        initialization because the diagram may already have been jax-compiled.
        """

configure_continuous_state(callback_idx, shape=None, default_value=None, dtype=None, ode=None, mass_matrix=None, as_array=True, requires_inputs=True, prerequisites_of_calc=None)

Configure a continuous state component for the system.

The ode callback computes the time derivative of the continuous state based on the current time, state, and any additional inputs. If ode is not provided, a default zero vector of the same size as the continuous state is used. If provided, the ode callback should have the signature ode(time, state, *inputs, **params) -> xcdot.

Parameters:

Name Type Description Default
callback_idx int

The index of the callback in the system's callback list.

required
shape ShapeLike

The shape of the continuous state vector. Defaults to None.

None
default_value Array

The initial value of the continuous state vector. Defaults to None.

None
dtype DTypeLike

The data type of the continuous state vector. Defaults to None.

None
ode Callable

The callback for computing the time derivative of the continuous state. Should have the signature: ode(time, state, *inputs, **parameters) -> xcdot. Defaults to None.

None
mass_matrix Array

The mass matrix for the continuous state. Defaults to None. If provided, it must be either a square matrix or a vector (holding the diagonal) whose size matches the continuous state. Using a mass matrix different from the identity in any LeafSystem will require the use of a compatible continuous-time solver (currently only BDF is supported). Currently mass matrices are also only supported for scalar- or vector-valued continuous states (i.e. no matrices or other PyTree-structured states).

None
as_array bool

If True, treat the default_value as an array-like (cast if necessary). Otherwise, it will be stored as the default state without modification.

True
requires_inputs bool

If True, indicates that the ODE computation requires inputs.

True
prerequisites_of_calc List[DependencyTicket]

The dependency tickets for the ODE computation. Defaults to None, in which case the assumption is a dependency on either (time, continuous state) if requires_inputs is False, otherwise (time, continuous state, inputs).

None

Raises:

Type Description
AssertionError

If neither shape nor default_value is provided, or if the mass matrix is inconsistent with the continuous state.

Notes

(1) Only one of shape and default_value should be provided. If default_value is provided, it will be used as the initial value of the continuous state. If shape is provided, the initial value will be a zero vector of the given shape and specified dtype.

Source code in collimator/framework/leaf_system.py
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
def configure_continuous_state(
    self,
    callback_idx: int,
    shape: ShapeLike = None,
    default_value: Array = None,
    dtype: DTypeLike = None,
    ode: Callable = None,
    mass_matrix: Array = None,
    as_array: bool = True,
    requires_inputs: bool = True,
    prerequisites_of_calc: List[DependencyTicket] = None,
):
    """Configure a continuous state component for the system.

    The `ode` callback computes the time derivative of the continuous state based on the
    current time, state, and any additional inputs. If `ode` is not provided, a default
    zero vector of the same size as the continuous state is used. If provided, the `ode`
    callback should have the signature `ode(time, state, *inputs, **params) -> xcdot`.

    Args:
        callback_idx (int):
            The index of the callback in the system's callback list. This
            implementation does not read it; it always configures
            `self.ode_callback`.
        shape (ShapeLike, optional):
            The shape of the continuous state vector. Defaults to None.
        default_value (Array, optional):
            The initial value of the continuous state vector. Defaults to None.
        dtype (DTypeLike, optional):
            The data type of the continuous state vector. Defaults to None.
        ode (Callable, optional):
            The callback for computing the time derivative of the continuous state.
            Should have the signature:
                `ode(time, state, *inputs, **parameters) -> xcdot`.
            Defaults to None.
        mass_matrix (Array, optional):
            The mass matrix for the continuous state. Defaults to None. If
            provided, it must be either an (n, n) matrix or a length-n vector
            holding the diagonal, where n is the size of the continuous state.
            Any array-like (e.g. a nested list) is accepted and converted with
            `np.asarray`. A mass matrix that is numerically the identity is
            discarded (stored as None) so explicit solvers can still be used.
            Any other mass matrix requires a compatible continuous-time solver
            (currently only BDF is supported). Mass matrices are also only
            supported for scalar- or vector-valued continuous states (i.e. no
            matrices or other PyTree-structured states).
        as_array (bool, optional):
            If True, treat the default_value as an array-like (cast if necessary).
            Otherwise, it will be stored as the default state without modification.
        requires_inputs (bool, optional):
            If True, indicates that the ODE computation requires inputs. Here it
            only affects the default `prerequisites_of_calc`.
        prerequisites_of_calc (List[DependencyTicket], optional):
            The dependency tickets for the ODE computation. Defaults to None. In
            that case the dependencies are (time, continuous state) if
            `requires_inputs` is False, otherwise (time, continuous state,
            inputs).

    Raises:
        AssertionError:
            If neither shape nor default_value is provided, or if the mass matrix
            is inconsistent with the continuous state.

    Notes:
        (1) Only one of `shape` and `default_value` should be provided. If `default_value`
        is provided, it will be used as the initial value of the continuous state. If
        `shape` is provided, the initial value will be a zero vector of the given shape
        and specified dtype.
    """

    if prerequisites_of_calc is None:
        prerequisites_of_calc = [DependencyTicket.time, DependencyTicket.xc]
        if requires_inputs:
            prerequisites_of_calc.append(DependencyTicket.u)

    if as_array:
        default_value = utils.make_array(default_value, dtype=dtype, shape=shape)

    logger.debug(f"In block {self.name} [{self.system_id}]: {default_value=}")

    # Tree-map the default value to ensure that it is an array-like with the
    # correct shape and dtype. This is necessary because the default value
    # may be a list, tuple, or other PyTree-structured object.
    default_value = tree_util.tree_map(cnp.asarray, default_value)

    self._default_continuous_state = default_value
    if self._continuous_state_output_port_idx is not None:
        port = self.output_ports[self._continuous_state_output_port_idx]
        port.default_value = default_value
        # Only write through to the default cache if the port owns a cache slot.
        # Indexing the list with None would raise TypeError.
        if port.cache_index is not None:
            self._default_cache[port.cache_index] = default_value

    if ode is None:
        # If no ODE is specified, return a zero vector of the same size as the
        # continuous state. This will break if the continuous state is
        # a named tuple, in which case a custom ODE must be provided.
        assert as_array, "Must provide custom ODE for non-array continuous state"

        def ode(time, state, *inputs, **parameters):
            return cnp.zeros_like(default_value)

    # Wrap the ode function to accept a context and return the time derivatives.
    # NOTE(review): unlike `configure_output_port` and `declare_cache`, this wrap
    # does not pass `collect_inputs=requires_inputs`. Confirm that is intended.
    ode = self.wrap_callback(ode)

    # Declare the time derivative function as a system callback so that its
    # dependencies can be tracked in the system dependency graph
    self.ode_callback._callback = ode
    self.ode_callback.prerequisites_of_calc = prerequisites_of_calc

    # Override the default `eval_time_derivatives` to use the wrapped ODE function
    self.eval_time_derivatives = self.ode_callback.eval

    if mass_matrix is not None:
        # Check that the state is a vector or scalar
        assert as_array, "Mass matrix only supported for array-valued states"
        assert (
            len(default_value.shape) <= 1
        ), "Mass matrix only supported for scalar or vector continuous states"
        n = default_value.size

        # Normalize array-likes (lists, jax arrays, ...) to a NumPy array before
        # inspecting the shape.
        mass_matrix = np.asarray(mass_matrix)
        assert mass_matrix.shape in ((n, n), (n,)), (
            "Mass matrix must be either a square matrix or vector of the same "
            f"size as the continuous state, but got {mass_matrix.shape} for "
            f"continuous state of shape {default_value.shape}."
        )
        if mass_matrix.ndim == 1:
            # A vector is interpreted as the diagonal of the mass matrix.
            mass_matrix = np.diag(mass_matrix)

        # If we end up with an identity matrix, we can just ignore the mass
        # matrix and use the default mass matrix (which is None).  This will
        # allow us to continue using explicit ODE solvers.
        if np.allclose(mass_matrix, np.eye(n)):
            mass_matrix = None

    self._mass_matrix = mass_matrix

configure_output_port(port_index, callback, period=None, offset=0.0, prerequisites_of_calc=None, default_value=None, requires_inputs=True)

Configure an output port in the LeafSystem.

See declare_output_port for a description of the arguments.

Parameters:

Name Type Description Default
port_index int

The index of the output port to configure.

required

Returns:

Type Description

None

Source code in collimator/framework/leaf_system.py
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
def configure_output_port(
    self,
    port_index: int,
    callback: Callable,
    period: float = None,
    offset: float = 0.0,
    prerequisites_of_calc: List[DependencyTicket] = None,
    default_value: Array = None,
    requires_inputs: bool = True,
):
    """Configure an output port in the LeafSystem.

    The arguments are described in `declare_output_port`.

    There are two kinds of port:
    - `period is None`: the port evaluates `callback` directly, through
      `wrap_callback`.
    - Otherwise, a "sample-and-hold" port. `callback` drives a periodic
      `DiscreteUpdateEvent` that stores its result in the state cache, and the
      port output just reads that cached value.

    Args:
        port_index (int):
            The index of the output port to configure.

    Returns:
        None
    """
    if default_value is not None:
        default_value = cnp.array(default_value)

    # Default dependencies: only depend on the inputs when the callback uses them.
    # This helps avoid unnecessary flagging of algebraic loops.
    if prerequisites_of_calc is None:
        prerequisites_of_calc = [
            DependencyTicket.u if requires_inputs else DependencyTicket.nothing
        ]

    event = None
    cache_index = None

    if period is None:
        _output_callback = self.wrap_callback(
            callback, collect_inputs=requires_inputs
        )
    else:
        # Reuse the port's cache slot if it already has one; otherwise allocate a
        # new slot seeded with `default_value`.
        # NOTE(review): when a slot already exists, `default_value` is not written
        # into `self._default_cache`, so the earlier default is kept. Confirm this
        # is intended.
        cache_index = self.output_ports[port_index].cache_index
        if cache_index is None:
            cache_index = len(self._default_cache)
            self._default_cache.append(default_value)

        def _output_callback(context: ContextBase) -> Array:
            # The held value is read straight from this system's state cache.
            return context[self.system_id].state.cache[cache_index]

        def _update_callback(
            time: Scalar, state: LeafState, *inputs, **parameters
        ) -> LeafState:
            sampled = callback(time, state, *inputs, **parameters)
            return state.with_cached_value(cache_index, sampled)

        _update_callback = self.wrap_callback(
            _update_callback, collect_inputs=requires_inputs
        )

        # In this mode, `prerequisites_of_calc` describes the update event, not
        # the literal output callback. The output port can still use them to
        # determine dependencies for the update event.
        event = DiscreteUpdateEvent(
            system_id=self.system_id,
            event_data=PeriodicEventData(
                period=period, offset=offset, active=False
            ),
            name=f"{self.name}:output_{cache_index}",
            callback=_update_callback,
            passthrough=self._passthrough,
        )

    super().configure_output_port(
        port_index,
        _output_callback,
        prerequisites_of_calc=prerequisites_of_calc,
        default_value=default_value,
        event=event,
        cache_index=cache_index,
    )

configure_periodic_update(event_index, callback, period, offset, enable_tracing=None)

Configure an existing periodic update event.

The event will be triggered at regular intervals defined by the period and offset parameters. The callback should have the signature callback(time, state, *inputs, **params) -> xd_plus, where xd_plus is the updated value of the discrete state.

This callback should be written to compute the "plus" value of the discrete state component given the "minus" values of all state components and inputs.

Parameters:

Name Type Description Default
event_index int

The index of the event to configure.

required
callback Callable

The callback function defining the update.

required
period Scalar

The period at which the update event occurs.

required
offset Scalar

The offset at which the first occurrence of the event is triggered.

required
enable_tracing bool

If True, enable tracing for this event. Defaults to None.

None
Source code in collimator/framework/leaf_system.py
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
def configure_periodic_update(
    self,
    event_index: int,
    callback: Callable,
    period: Scalar | Parameter,
    offset: Scalar | Parameter,
    enable_tracing: bool = None,
):
    """Configure an existing periodic update event.

    The event fires at regular intervals set by `period` and `offset`. `callback`
    must have the signature `callback(time, state, *inputs, **params) -> xd_plus`.
    It computes the "plus" value of the discrete state from the "minus" values of
    all state components and inputs. The returned value replaces the discrete
    state of this system.

    Args:
        event_index (int):
            The slot in `self._state_update_events` that receives the new event.
        callback (Callable):
            The callback function defining the update.
        period (Scalar):
            The period at which the update event occurs.
        offset (Scalar):
            The offset at which the first occurrence of the event is triggered.
        enable_tracing (bool, optional):
            If True, enable tracing for this event. Defaults to None, which is
            treated as True.
    """
    tracing = True if enable_tracing is None else enable_tracing

    compute_xd_plus = self.wrap_callback(callback)

    def _callback(context: ContextBase) -> LeafState:
        xd_plus = compute_xd_plus(context)
        leaf_state = context[self.system_id].state
        return leaf_state.with_discrete_state(xd_plus)

    self._state_update_events[event_index] = DiscreteUpdateEvent(
        system_id=self.system_id,
        name=f"{self.name}:periodic_update",
        event_data=PeriodicEventData(period=period, offset=offset, active=False),
        callback=_callback,
        passthrough=self._passthrough,
        enable_tracing=tracing,
    )

declare_cache(callback, period=None, offset=0.0, name=None, prerequisites_of_calc=None, default_value=None, requires_inputs=True)

Declare a stored computation for the system.

This method accepts a callback function with the block-level signature callback(time, state, *inputs, **parameters) -> value and wraps it to have the signature callback(context) -> value.

This callback can optionally be used to define a periodic update event that refreshes the cached value. Other calculations (e.g. sample-and-hold output ports) can then depend on the cached value.

Parameters:

Name Type Description Default
callback Callable

The callback function defining the cached computation.

required
period float

If not None, the callback function will be used to define a periodic update event that refreshes the value. Defaults to None.

None
offset float

The offset of the periodic update event. Defaults to 0.0. Will be ignored unless period is not None.

0.0
name str

The name of the cached value. Defaults to None.

None
default_value Array

The default value of the result, if known. Defaults to None.

None
requires_inputs bool

If True, the callback will eval input ports to gather input values. This will add a bit to compile time, so setting to False where possible is recommended. Defaults to True.

True
prerequisites_of_calc List[DependencyTicket]

The dependency tickets for the computation. Defaults to None, in which case the default is to assume dependency on either (inputs) if requires_inputs is True, or (nothing) otherwise.

None

Returns:

Name Type Description
int int

The index of the callback in system.callbacks. The cache index can be recovered from system.callbacks[callback_index].cache_index.

Source code in collimator/framework/leaf_system.py
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
def declare_cache(
    self,
    callback: Callable,
    period: float | Parameter = None,
    offset: float | Parameter = 0.0,
    name: str = None,
    prerequisites_of_calc: List[DependencyTicket] = None,
    default_value: Array = None,
    requires_inputs: bool = True,
) -> int:
    """Declare a stored computation for the system.

    `callback` has the block-level signature
        `callback(time, state, *inputs, **parameters) -> value`
    and is wrapped into
        `callback(context) -> value`
    where the wrapped version returns the state with the value written into its
    cache slot.

    If `period` is given, the same callback also drives a periodic update event
    that refreshes the cached value. Other calculations (e.g. sample-and-hold
    output ports) can then depend on the cached value.

    Args:
        callback (Callable):
            The callback function defining the cached computation.
        period (float, optional):
            If not None, the callback defines a periodic update event that
            refreshes the value. Defaults to None.
        offset (float, optional):
            The offset of the periodic update event. Defaults to 0.0. Ignored
            unless `period` is not None.
        name (str, optional):
            The name of the cached value. Defaults to None, which becomes
            `"cache_<cache_index>"`.
        prerequisites_of_calc (List[DependencyTicket], optional):
            The dependency tickets for the computation. Defaults to None. In that
            case the dependency is (inputs) if `requires_inputs` is True, or
            (nothing) otherwise.
        default_value (Array, optional):
            The default value of the result, if known. Defaults to None.
        requires_inputs (bool, optional):
            If True, the callback will eval input ports to gather input values.
            This adds a bit to compile time, so setting it to False where
            possible is recommended. Defaults to True.

    Returns:
        int: The index of the callback in `system.callbacks`. The cache index can
            be recovered from `system.callbacks[callback_index].cache_index`.
    """
    # Position of the new SystemCallback in the system's callback list.
    callback_index = len(self.callbacks)

    # Position of the value in state.cache; reserve it now with the default.
    cache_index = len(self._default_cache)
    self._default_cache.append(default_value)

    # Drop the inputs from the default prerequisites when the callback does not
    # read them. This helps avoid unnecessary flagging of algebraic loops.
    if prerequisites_of_calc is None:
        prerequisites_of_calc = [
            DependencyTicket.u if requires_inputs else DependencyTicket.nothing
        ]

    def _update_callback(
        time: Scalar, state: LeafState, *inputs, **parameters
    ) -> LeafState:
        value = callback(time, state, *inputs, **parameters)
        return state.with_cached_value(cache_index, value)

    wrapped_update = self.wrap_callback(
        _update_callback, collect_inputs=requires_inputs
    )

    event = None
    if period is not None:
        # A periodic event keeps the cached value fresh.
        event = DiscreteUpdateEvent(
            system_id=self.system_id,
            event_data=PeriodicEventData(
                period=period, offset=offset, active=False
            ),
            name=f"{self.name}:cache_update_{cache_index}_",
            callback=wrapped_update,
            passthrough=self._passthrough,
        )

    if name is None:
        name = f"cache_{cache_index}"

    self.callbacks.append(
        SystemCallback(
            callback=wrapped_update,
            system=self,
            callback_index=callback_index,
            name=name,
            prerequisites_of_calc=prerequisites_of_calc,
            event=event,
            default_value=default_value,
            cache_index=cache_index,
        )
    )
    return callback_index

declare_continuous_state(shape=None, default_value=None, dtype=None, ode=None, mass_matrix=None, as_array=True, requires_inputs=True, prerequisites_of_calc=None)

Declare a continuous state component for the system.

Source code in collimator/framework/leaf_system.py
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
def declare_continuous_state(
    self,
    shape: ShapeLike = None,
    default_value: Array = None,
    dtype: DTypeLike = None,
    ode: Callable = None,
    mass_matrix: Array = None,
    as_array: bool = True,
    requires_inputs: bool = True,
    prerequisites_of_calc: List[DependencyTicket] = None,
):
    """Declare a continuous state component for the system.

    This creates the ODE `SystemCallback` (with no callback function attached
    yet), stores it as `self.ode_callback` and appends it to `self.callbacks`.
    If `shape` or `default_value` is given, every argument is also forwarded
    right away to `configure_continuous_state`. Otherwise configuration is left
    to a later call. See `configure_continuous_state` for the meaning of the
    arguments.

    Returns:
        int: The index of the ODE callback in `self.callbacks`.
    """
    self.ode_callback = SystemCallback(
        callback=None,
        system=self,
        callback_index=len(self.callbacks),
        name=f"{self.name}_ode",
        prerequisites_of_calc=prerequisites_of_calc,
    )
    self.callbacks.append(self.ode_callback)
    ode_index = len(self.callbacks) - 1

    # FIXME: configuring from here preserves backward compatibility while
    # declaration is decoupled from configuration. Declaration should not have
    # to call configuration.
    if shape is None and default_value is None:
        return ode_index

    self.configure_continuous_state(
        ode_index,
        shape=shape,
        default_value=default_value,
        dtype=dtype,
        ode=ode,
        mass_matrix=mass_matrix,
        as_array=as_array,
        requires_inputs=requires_inputs,
        prerequisites_of_calc=prerequisites_of_calc,
    )
    return ode_index

declare_continuous_state_output(name=None)

Declare a continuous state output port in the system.

This method creates a new block-level output port which returns the full continuous state of the system.

Parameters:

Name Type Description Default
name str

The name of the output port. Defaults to None (autogenerate name).

None

Returns:

Name Type Description
int int

The index of the new output port.

Source code in collimator/framework/leaf_system.py
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
def declare_continuous_state_output(
    self,
    name: str = None,
) -> int:
    """Declare an output port that emits the system's full continuous state.

    The port callback simply returns `state.continuous_state`, depends only on
    the continuous-state ticket (`DependencyTicket.xc`), and does not gather
    any inputs. Only one such port may exist per system.

    Args:
        name (str, optional):
            Name of the output port. Defaults to None (autogenerate name).

    Returns:
        int: The index of the new output port.

    Raises:
        ValueError: If a continuous state output port was already declared.
    """
    already_declared = self._continuous_state_output_port_idx is not None
    if already_declared:
        raise ValueError("Continuous state output port already declared")

    def _callback(time: Scalar, state: LeafState, *inputs, **parameters):
        return state.continuous_state

    port_index = self.declare_output_port(
        _callback,
        name=name,
        prerequisites_of_calc=[DependencyTicket.xc],
        default_value=self._default_continuous_state,
        requires_inputs=False,
    )
    self._continuous_state_output_port_idx = port_index
    return port_index

declare_discrete_state(shape=None, default_value=None, dtype=None, as_array=True)

Declare a new discrete state component for the system.

The discrete state is a component of the system's state that can be updated at specific events, such as zero-crossings or periodic updates. Multiple discrete states can be declared, and each is associated with a unique index. The index is used to access and update the corresponding discrete state in the system's context during event handling.

The declared discrete state is initialized with either the provided default value or zeros of the correct shape and dtype.

Parameters:

Name Type Description Default
shape ShapeLike

The shape of the discrete state. Defaults to None.

None
default_value Array

The initial value of the discrete state. Defaults to None.

None
dtype DTypeLike

The data type of the discrete state. Defaults to None.

None
as_array bool

If True, treat the default_value as an array-like (cast if necessary). Otherwise, it will be stored as the default state without modification.

True

Raises:

Type Description
AssertionError

If as_array is True and neither shape nor default_value is provided.

Notes

(1) Only one of shape and default_value should be provided. If default_value is provided, it will be used as the initial value of the discrete state. If shape is provided, the initial value will be a zero vector of the given shape and specified dtype.

(2) Use declare_periodic_update to declare an update event that modifies the discrete state at a recurring interval.

Source code in collimator/framework/leaf_system.py
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
def declare_discrete_state(
    self,
    shape: ShapeLike = None,
    default_value: Array | Parameter = None,
    dtype: DTypeLike = None,
    as_array: bool = True,
):
    """Declare the discrete state component for the system.

    The discrete state is a component of the system's state that can be updated
    at specific events, such as zero-crossings or periodic updates. The default
    value may be a single array or a PyTree (e.g. list or tuple) of arrays.

    The declared discrete state is initialized with either the provided default
    value or zeros of the correct shape and dtype.

    Args:
        shape (ShapeLike, optional):
            The shape of the discrete state. Defaults to None.
        default_value (Array | Parameter, optional):
            The initial value of the discrete state. Defaults to None.
        dtype (DTypeLike, optional):
            The data type of the discrete state. Defaults to None.
        as_array (bool, optional):
            If True, treat the default_value as an array-like (cast if necessary).
            Otherwise, it will be stored as the default state without modification.

    Raises:
        AssertionError:
            If as_array is True and neither shape nor default_value is provided.

    Notes:
        (1) Only one of `shape` and `default_value` should be provided. If
        `default_value` is provided, it will be used as the initial value of the
        discrete state. If `shape` is provided, the initial value will be a
        zero vector of the given shape and specified dtype.

        (2) Use `declare_periodic_update` to declare an update event that
        modifies the discrete state at a recurring interval.

        (3) Each call overwrites `self._default_discrete_state`; this method
        does not return an index.
    """
    if as_array:
        default_value = utils.make_array(default_value, dtype=dtype, shape=shape)

    # Tree-map the default value so every leaf becomes an array. This is needed
    # because the default value may be a list, tuple, or other PyTree-structured
    # object rather than a single array.
    default_value = tree_util.tree_map(cnp.asarray, default_value)

    self._default_discrete_state = default_value

declare_mode_output(name=None)

Declare a mode output port in the system.

This method creates a new block-level output port which returns the component of the system's state corresponding to the discrete "mode" or "stage".

Parameters:

Name Type Description Default
name str

The name of the output port. Defaults to None.

None

Returns:

Name Type Description
int int

The index of the declared mode output port.

Source code in collimator/framework/leaf_system.py
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
def declare_mode_output(self, name: str = None) -> int:
    """Declare an output port that emits the system's discrete mode.

    The port callback returns `state.mode`, depends only on the mode ticket
    (`DependencyTicket.mode`), uses `self._default_mode` as its default value,
    and does not gather any inputs.

    Args:
        name (str, optional):
            The name of the output port. Defaults to None.

    Returns:
        int:
            The index of the declared mode output port.
    """

    def _callback(time: Scalar, state: LeafState, *inputs, **parameters):
        return state.mode

    port_index = self.declare_output_port(
        _callback,
        name=name,
        prerequisites_of_calc=[DependencyTicket.mode],
        default_value=self._default_mode,
        requires_inputs=False,
    )
    self._mode_output_port_idx = port_index
    return port_index

declare_output_port(callback=None, period=None, offset=0.0, name=None, prerequisites_of_calc=None, default_value=None, requires_inputs=True)

Declare an output port in the LeafSystem.

This method accepts a callback function with the block-level signature callback(time, state, *inputs, **parameters) -> value and wraps it to the signature expected by SystemBase.declare_output_port: callback(context) -> value

Parameters:

Name Type Description Default
callback Callable

The callback function defining the output port.

None
period float

If not None, the port will act as a "sample-and-hold", with the callback function used to define a periodic update event that refreshes the value that will be returned by the port. Typically this should match the update period of some associated update event in the system. Defaults to None.

None
offset float

The offset of the periodic update event. Defaults to 0.0. Ignored when period is None.

0.0
name str

The name of the output port. Defaults to None.

None
default_value Array

The default value of the output port, if known. Defaults to None.

None
requires_inputs bool | list[int]

If True, the callback will eval input ports to gather input values. This will add a bit to compile time, so setting to False where possible is recommended. Can also be specified as a list of integer port indices. Defaults to True (collect all inputs).

True
prerequisites_of_calc List[DependencyTicket]

The dependency tickets for the output port computation. Defaults to None, in which case the port is assumed to depend on nothing when requires_inputs is False, and on the inputs otherwise.

None

Returns:

Name Type Description
int int

The index of the declared output port.

Source code in collimator/framework/leaf_system.py
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
def declare_output_port(
    self,
    callback: Callable = None,
    period: float = None,
    offset: float = 0.0,
    name: str = None,
    prerequisites_of_calc: List[DependencyTicket] = None,
    default_value: Array = None,
    requires_inputs: bool | list[int] = True,
) -> int:
    """Declare an output port in the LeafSystem.

    The callback is written against the block-level signature
        `callback(time, state, *inputs, **parameters) -> value`
    and is adapted to the context-level signature expected by
    SystemBase.declare_output_port:
        `callback(context) -> value`

    Args:
        callback (Callable):
            The callback function defining the output port.
        period (float, optional):
            When given, the port becomes "sample-and-hold": the callback
            defines a periodic update event that refreshes a stored value, and
            the port returns that stored value. Typically this matches the
            period of an associated update event in the system.
            Defaults to None.
        offset (float, optional):
            Offset of the periodic update event. Defaults to 0.0. Ignored when
            `period` is None.
        name (str, optional):
            The name of the output port. Defaults to None.
        default_value (Array, optional):
            The default value of the output port, if known. Defaults to None.
        requires_inputs (bool | list[int], optional):
            If True, the callback evaluates input ports to gather input values.
            This adds to compile time, so False is recommended where possible.
            May also be a list of integer port indices.
            Defaults to True (collect all inputs).
        prerequisites_of_calc (List[DependencyTicket], optional):
            Dependency tickets for the output computation. Defaults to None,
            in which case the port is assumed to depend on nothing when
            `requires_inputs` is False and on the inputs otherwise.

    Returns:
        int: The index of the declared output port.
    """
    default_value = None if default_value is None else cnp.array(default_value)

    # Sample-and-hold ports keep their value in the state cache; reserve a
    # slot for it here. The periodic update that refreshes the slot is set up
    # by `configure_output_port`.
    if period is None:
        cache_index = None
    else:
        cache_index = len(self._default_cache)
        self._default_cache.append(default_value)

    port_index = super().declare_output_port(
        callback, name=name, cache_index=cache_index
    )

    self.configure_output_port(
        port_index,
        callback,
        period=period,
        offset=offset,
        prerequisites_of_calc=prerequisites_of_calc,
        default_value=default_value,
        requires_inputs=requires_inputs,
    )

    return port_index

declare_zero_crossing(guard, reset_map=None, start_mode=None, end_mode=None, direction='crosses_zero', terminal=False, name=None, enable_tracing=None)

Declare an event triggered by a zero-crossing of a guard function.

Optionally, the system can also transition between discrete modes. If start_mode and end_mode are specified, the system will transition from start_mode to end_mode when the event is triggered according to guard. This event will be active conditionally on state.mode == start_mode and when triggered will result in applying the reset map. In addition, the mode will be updated to end_mode.

If start_mode and end_mode are not specified, the event will always be active and will not result in a mode transition.

The guard function should have the signature

guard(time, state, *inputs, **parameters) -> float

and the reset map should have the signature of an unrestricted update

reset_map(time, state, *inputs, **parameters) -> state

Parameters:

Name Type Description Default
guard Callable

The guard function which triggers updates on zero crossing.

required
reset_map Callable

The reset map which is applied when the event is triggered. If None (default), no reset is applied.

None
start_mode int

The mode or stage of the system in which the guard will be actively monitored. If None (default), the event will always be active.

None
end_mode int

The mode or stage of the system to which the system will transition when the event is triggered. start_mode and end_mode must be given together: if either is provided, both must be integers (they may be equal).

None
direction str

The direction of the zero crossing. Options are "crosses_zero" (default), "positive_then_non_positive", "negative_then_non_negative", and "edge_detection". All except edge detection operate on continuous signals; edge detection operates on boolean signals and looks for a jump from False to True or vice versa.

'crosses_zero'
terminal bool

If True, the event will halt simulation if and when the zero-crossing occurs. If this event is triggered the reset map will still be applied as usual prior to termination. Defaults to False.

False
name str

The name of the event. Defaults to None.

None
enable_tracing bool

If True, enable tracing for this event. Defaults to None, which is treated as True.

None
Notes

By default the system state does not have a "mode" component, so in order to declare "state transitions" with non-null start and end modes, the user must first call declare_default_mode to set the default mode to be some integer (initial condition for the system).

Source code in collimator/framework/leaf_system.py
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
def declare_zero_crossing(
    self,
    guard: Callable,
    reset_map: Callable = None,
    start_mode: int = None,
    end_mode: int = None,
    direction: str = "crosses_zero",
    terminal: bool = False,
    name: str = None,
    enable_tracing: bool = None,
):
    """Declare an event triggered by a zero-crossing of a guard function.

    Optionally, the system can also transition between discrete modes.
    If `start_mode` and `end_mode` are specified, the system will transition
    from `start_mode` to `end_mode` when the event is triggered according to `guard`.
    This event will be active conditionally on `state.mode == start_mode` and when
    triggered will result in applying the reset map. In addition, the mode will be
    updated to `end_mode`.

    If `start_mode` and `end_mode` are not specified, the event will always be active
    and will not result in a mode transition.

    The guard function should have the signature:
        `guard(time, state, *inputs, **parameters) -> float`

    and the reset map should have the signature of an unrestricted update:
        `reset_map(time, state, *inputs, **parameters) -> state`

    Args:
        guard (Callable):
            The guard function which triggers updates on zero crossing.
        reset_map (Callable, optional):
            The reset map which is applied when the event is triggered. If None
            (default), no reset is applied.
        start_mode (int, optional):
            The mode or stage of the system in which the guard will be
            actively monitored. If None (default), the event will always be
            active.
        end_mode (int, optional):
            The mode or stage of the system to which the system will transition
            when the event is triggered. `start_mode` and `end_mode` must be
            given together: if either is provided, both must be ints (they may
            be equal).
        direction (str, optional):
            The direction of the zero crossing. Options are "crosses_zero"
            (default), "positive_then_non_positive", "negative_then_non_negative",
            and "edge_detection".  All except edge detection operate on continuous
            signals; edge detection operates on boolean signals and looks for a
            jump from False to True or vice versa.
        terminal (bool, optional):
            If True, the event will halt simulation if and when the zero-crossing
            occurs. If this event is triggered the reset map will still be applied
            as usual prior to termination. Defaults to False.
        name (str, optional):
            The name of the event. Defaults to None.
        enable_tracing (bool, optional):
            If True, enable tracing for this event. Defaults to None, which is
            treated as True.

    Raises:
        AssertionError:
            If `start_mode` or `end_mode` is given but the system has no default
            mode, or if either of them is not an int.

    Notes:
        By default the system state does not have a "mode" component, so in
        order to declare "state transitions" with non-null start and end modes,
        the user must first call `declare_default_mode` to set the default mode
        to be some integer (initial condition for the system).
    """

    logger.debug(
        f"Declaring transition for {self.name} with guard {guard} and reset map {reset_map}"
    )

    if enable_tracing is None:
        enable_tracing = True

    if start_mode is not None or end_mode is not None:
        assert (
            self._default_mode is not None
        ), "System has no mode: call `declare_default_mode` before transitions."
        assert isinstance(start_mode, int) and isinstance(end_mode, int)

    # Wrap the reset map with a mode update if necessary
    def _reset_and_update_mode(
        time: Scalar, state: LeafState, *inputs, **parameters
    ) -> LeafState:
        if reset_map is not None:
            state = reset_map(time, state, *inputs, **parameters)

        # Only log and apply a mode change when a transition was declared;
        # otherwise there is no target mode to report.
        if start_mode is not None:
            logger.debug(f"Updating mode from {state.mode} to {end_mode}")
            state = state.with_mode(end_mode)

        return state

    _wrapped_guard = self.wrap_callback(guard)
    _wrapped_reset = _wrap_reset_map(
        self, _reset_and_update_mode, _wrapped_guard, terminal
    )

    event = ZeroCrossingEvent(
        system_id=self.system_id,
        guard=_wrapped_guard,
        reset_map=_wrapped_reset,
        passthrough=self._passthrough,
        direction=direction,
        is_terminal=terminal,
        name=name,
        event_data=ZeroCrossingEventData(active=True, triggered=False),
        enable_tracing=enable_tracing,
        active_mode=start_mode,
    )

    event_index = len(self._zero_crossing_events)
    self._zero_crossing_events.append(event)

    # Record the transition in the transition map (for debugging or analysis)
    if start_mode is not None:
        if start_mode not in self.transition_map:
            self.transition_map[start_mode] = []
        self.transition_map[start_mode].append((event_index, event))

initialize(**parameters)

Hook for initializing a system. Called during context creation.

If the parameters are instances of Parameter, they will be resolved. If implemented, the function signature should contain all the declared parameters.

This function should not be called directly. It will be called implicitly after __init__ with the resolved parameters.

Source code in collimator/framework/leaf_system.py
205
206
207
208
209
210
211
212
213
214
215
def initialize(self, **parameters):
    """Optional initialization hook, invoked while the context is being created.

    Parameters given as `Parameter` instances are resolved before this hook
    sees them. Subclasses that override it should list every declared
    parameter in the signature.

    Do not call this directly: it runs implicitly after `__init__`, receiving
    the resolved parameters. The base implementation does nothing.
    """
    return None

reset_default_values(**dynamic_parameters)

This function is used to reset default values for continuous/discrete states, ports and mode based on dynamic parameters. It is called in create_state() and used to reset states in ensemble sims and optimization with the context method with_new_state().

Note that dtypes and shapes can't be changed after initialization because the diagram may already have been jax-compiled. Only values may change.

Source code in collimator/framework/leaf_system.py
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
def reset_default_values(self, **dynamic_parameters):
    """Recompute default values from dynamic parameters.

    Intended to refresh the default continuous/discrete states, port values
    and mode from `dynamic_parameters`. It is invoked from `create_state()`
    and is used to reset states for ensemble simulations and optimization via
    the context method `with_new_state()`.

    Only values may change here: dtypes and shapes are fixed after
    initialization because the diagram may already be jax-compiled.
    The base implementation does nothing.
    """
    return None

wrap_callback(callback, collect_inputs=True)

Wrap an update function to unpack local variables and block inputs.

The callback should have the signature callback(time, state, *inputs, **params) -> result and will be wrapped to have the signature callback(context) -> result, as expected by the event handling logic.

This is used internally for declaration methods like declare_periodic_update so that users can write more intuitive block-level update functions without worrying about the "context", and have them automatically wrapped to have the right interface. It can also be called directly by users to wrap their own update functions, for example to create a callback function for declare_output_port.

The context and state are strictly immutable, so the callback should not attempt to change any values in the context or state. Even in cases where it is impossible to enforce this (e.g. a state component is a list, which is always mutable in Python), the callback should be careful to avoid direct modification of the context or state, which may lead to unexpected behavior or JAX tracer errors.

Parameters:

Name Type Description Default
callback Callable

The (pure) function to be wrapped. See above for expected signature.

required
collect_inputs bool

If True, the callback will eval input ports to gather input values. Normally this should be True, but it can be set to False if the return value depends only on the state but not inputs, for instance. This helps reduce the number of expressions that need to be JIT compiled. Can also be specified as a list of integer port indices. Default is True (collect all inputs).

True

Returns:

Name Type Description
Callable Callable

The wrapped function, with signature callback(context) -> result.

Source code in collimator/framework/leaf_system.py
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
def wrap_callback(
    self, callback: Callable, collect_inputs: bool | list[int] = True
) -> Callable:
    """Adapt a block-level function to the context-level callback interface.

    `callback` is expected to have the signature
    `callback(time, state, *inputs, **params) -> result`; the returned
    function has the signature `callback(context) -> result`, which is what
    the event handling logic consumes.

    Declaration helpers such as `declare_periodic_update` use this internally,
    so users can write intuitive block-level update functions without dealing
    with the "context" directly. It may also be called by users, e.g. to build
    a callback for `declare_output_port`.

    Context and state are strictly immutable, and the callback must not try to
    change them. Even where immutability cannot be _enforced_ (for instance a
    state component that is a Python list), avoid modifying the context or
    state in place; doing so may lead to unexpected behavior or JAX tracer
    errors.

    Args:
        callback (Callable):
            The (pure) function to wrap; see above for the expected signature.
        collect_inputs (bool):
            True gathers every input port value, False gathers none, and a
            list of integer port indices gathers only those ports. Skipping
            inputs that the result does not depend on reduces the number of
            expressions that have to be JIT compiled. Default is True.

    Returns:
        Callable:
            The wrapped function, with signature `callback(context) -> result`.
    """
    # Resolve which ports to read once: None means "all ports", [] means none,
    # and an explicit list is passed through unchanged.
    if isinstance(collect_inputs, bool):
        selected_ports = None if collect_inputs else []
    else:
        selected_ports = collect_inputs

    def _wrapped_callback(context: ContextBase) -> LeafStateComponent:
        input_values = self.collect_inputs(context, selected_ports)
        local_context: LeafContext = context[self.system_id]
        local_state = local_context.state
        local_params = local_context.parameters
        return callback(context.time, local_state, *input_values, **local_params)

    return _wrapped_callback

Parameter dataclass

Source code in collimator/framework/parameter.py
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
@dataclasses.dataclass
class Parameter:
    """A model/block parameter whose value is resolved lazily and cached.

    The definition of the parameter lives in ``value``. It may be a literal
    (array-like, string, tuple), another ``Parameter``, a ``ParameterExpr``
    built by the arithmetic/comparison operators defined below, or — when
    ``is_python_expr`` is set — a Python expression string evaluated in
    ``py_namespace``. Resolved values are computed and cached by
    ``ParameterCache``; calling ``set`` replaces the definition and
    invalidates the cached value of this parameter and of its dependents.
    """

    value: Union[ParameterExpr, "Parameter", ArrayLike, str, tuple]

    # shape & dtype are set at init time when constructing the parameter,
    # they are not necessarily the actual value's shape and dtype
    dtype: DTypeLike = None
    shape: ShapeLike = None
    as_array: bool = False

    # name is used by reference submodels, model parameters and init script
    # variables so that they can be referred to in other fields
    # (we need this for serialization).
    name: str = None

    # For complex parameter values, we can specify a Python expression as string
    # This is useful for expressions like "np.eye(p)" where p is a parameter.
    is_python_expr: bool = False
    py_namespace: dict = None

    is_static: bool = False  # TODO: staticness should be propagated to dependents
    system: "SystemBase" = None

    def get(self):
        """Return the resolved (cached) value of this parameter.

        When ``as_array`` is set and the resolved value is not already an
        ``Array``, it is converted with ``utils.make_array`` using the declared
        ``dtype`` and ``shape``.
        """
        value = ParameterCache.get(self)
        if self.as_array and not isinstance(value, Array):
            value = utils.make_array(value, self.dtype, self.shape)
        return value

    def set(self, value: Union["Parameter", ArrayLike, str, tuple]):
        """Replace this parameter's definition and invalidate cached values."""
        ParameterCache.replace(self, value)

    @property
    def static_dependents(self):
        """Static parameters that (transitively) depend on this parameter."""
        return ParameterCache.static_dependents(self)

    @property
    def is_dirty(self):
        """True if the cached value is stale and will be recomputed on get()."""
        return ParameterCache.__is_dirty__[self]

    @classmethod
    def unwrap(cls, value):
        """Get the underlying value of raw arrays and Parameter objects alike."""
        if value is None:
            return None
        if isinstance(value, (Array, bool, int, float, complex)):
            return value
        if isinstance(value, (np.ndarray, np.number)):
            if np.issubdtype(value.dtype, np.number):
                return value
            if value.shape == ():
                # 0-d non-numeric array (e.g. object array): unwrap its item.
                return cls.unwrap(value.item())
            return Parameter(value).get()
        if isinstance(value, Parameter):
            return value.get()
        if isinstance(value, list):
            return [cls.unwrap(val) for val in value]
        if isinstance(value, tuple):
            return tuple(cls.unwrap(val) for val in value)
        if isinstance(value, dict):
            return {key: cls.unwrap(val) for key, val in value.items()}
        # Fallback for unhandled types: forward to __compute__
        return Parameter(value).get()

    def __post_init__(self):
        # Register this parameter in the dependency graph, then record it as a
        # dependent of every Parameter referenced by its definition so that
        # changing those parameters invalidates this one.
        ParameterCache.__dependents__[self] = set()

        if isinstance(self.value, Parameter):
            ParameterCache.add_dependent(self.value, self)
        if isinstance(self.value, ParameterExpr):
            for val in self.value:
                if isinstance(val, Parameter):
                    ParameterCache.add_dependent(val, self)
        if isinstance(self.value, (list, tuple)):
            _add_dependents(self.value, self)

        _record_parameter_creation(self)

    def __add__(self, other):
        return _op(Ops.ADD, self, other)

    def __radd__(self, other):
        return _op(Ops.ADD, other, self)

    def __sub__(self, other):
        return _op(Ops.SUB, self, other)

    def __rsub__(self, other):
        return _op(Ops.SUB, other, self)

    def __mul__(self, other):
        return _op(Ops.MUL, self, other)

    def __rmul__(self, other):
        return _op(Ops.MUL, other, self)

    def __truediv__(self, other):
        return _op(Ops.DIV, self, other)

    def __rtruediv__(self, other):
        return _op(Ops.DIV, other, self)

    def __floordiv__(self, other):
        return _op(Ops.FLOORDIV, self, other)

    def __rfloordiv__(self, other):
        return _op(Ops.FLOORDIV, other, self)

    def __mod__(self, other):
        return _op(Ops.MOD, self, other)

    def __rmod__(self, other):
        return _op(Ops.MOD, other, self)

    def __pow__(self, other):
        return _op(Ops.POW, self, other)

    def __rpow__(self, other):
        return _op(Ops.POW, other, self)

    def __neg__(self):
        p = Parameter(value=ParameterExpr([Ops.NEG, self]))
        ParameterCache.add_dependent(self, p)
        return p

    def __pos__(self):
        p = Parameter(value=ParameterExpr([Ops.POS, self]))
        ParameterCache.add_dependent(self, p)
        return p

    def __abs__(self):
        p = Parameter(value=ParameterExpr([Ops.ABS, self]))
        ParameterCache.add_dependent(self, p)
        return p

    # Comparisons build lazy expressions (new Parameters), they do not return
    # booleans.
    def __eq__(self, other):
        return _op(Ops.EQ, self, other)

    def __ne__(self, other):
        return _op(Ops.NE, self, other)

    def __lt__(self, other):
        return _op(Ops.LT, self, other)

    def __le__(self, other):
        return _op(Ops.LE, self, other)

    def __gt__(self, other):
        return _op(Ops.GT, self, other)

    def __ge__(self, other):
        return _op(Ops.GE, self, other)

    def __del__(self):
        ParameterCache.remove(self)

    def __hash__(self):
        # Identity-based hashing: parameters are keys in ParameterCache dicts.
        return id(self)

    def __str__(self):
        # Calling str() on a Parameter object is confusing. What's the intent?
        # 1. Serializing to a valid Python expression?
        # 2. Is it for logs? For debugging?
        # 3. Is it part of building a wider expression (like a list of parameters)?
        # 4. Evaluating the actual value of a string parameter?
        # Here, we support 2 & 4. We'll likely have to change this when we want support
        # for non-literal string parameters in the UI.

        expr, _ = self.value_as_api_param(
            allow_param_name=True,
            allow_string_literal=True,
        )
        return expr

    def __matmul__(self, other):
        return _op(Ops.MATMUL, self, other)

    def __int__(self):
        # Python requires __int__ to return an actual ``int``: returning a
        # NumPy scalar such as ``np.int32`` makes ``int(param)`` raise
        # TypeError. Cast through the declared dtype first, then convert.
        value = self.get()
        if self.dtype is not None:
            value = np.asarray(value, dtype=self.dtype)
        return int(value)

    def __float__(self):
        # Same contract as __int__: must return a Python ``float``, so cast
        # through the declared dtype (e.g. float32 rounding) and convert.
        value = self.get()
        if self.dtype is not None:
            value = np.asarray(value, dtype=self.dtype)
        return float(value)

    # FIXME: this is not working as expected - it will break some tests
    # def __bool__(self):
    #     return bool(self.get())

    def __complex__(self):
        return complex(self.get())

    def value_as_api_param(
        self, allow_param_name=True, allow_string_literal=True
    ) -> tuple[str, bool]:
        """Returns an API-compatible expression[1] that defines this parameter

        What we return depends on the caller's context, since it depends on
        whether we are serializing for a model, submodel or block parameter.

        The boolean is the value of 'is_string' (means "string literal" or
        "do not call eval").

        [1] The returned string can be serialized to JSON, but it is not an
            already escaped JSON string!

        Args:
            allow_param_name: Set to false for (sub)model parameters. Optional.
                If true, and the value is defined by a name, just the name will
                be returned.
            allow_string_literal: Set to false for (sub)model parameters. Optional.
                If true, and the value is a string, then the string will be
                returned and 'is_string' will be returned as True.
        """
        if self.name is not None and allow_param_name:
            return self.name, False

        if self.is_python_expr and isinstance(self.value, str):
            return self.value, False

        if allow_string_literal and isinstance(self.value, str):
            return self.value, True

        return _value_as_str(self.value), False

    def __repr__(self):
        ex, _ = self.value_as_api_param(allow_string_literal=False)
        # Keep very long expressions readable by eliding the middle.
        if len(ex) > 100:
            ex = ex[:50] + "..." + ex[-50:]

        system_name = self.system.name if self.system is not None else None
        return (
            f"Parameter(name={self.name}, value={ex}, "
            f"value_type={type(self.value).__name__}, "
            f"is_static={self.is_static}, "
            f"is_python_expr={self.is_python_expr}, "
            f"id={self.__hash__()}, "
            f"system={system_name})"
        )

unwrap(value) classmethod

Get the underlying value of raw arrays and Parameter objects alike.

Source code in collimator/framework/parameter.py
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
@classmethod
def unwrap(cls, value):
    """Return the plain value behind ``value``, resolving any Parameters.

    ``None``, ``Array`` instances and Python scalars are returned unchanged.
    NumPy arrays/scalars with a numeric dtype are returned unchanged; other
    NumPy values are unwrapped item-wise (0-d) or resolved through a temporary
    ``Parameter``. ``Parameter`` objects are resolved with ``get()``; dicts,
    lists and tuples are unwrapped element-wise. Anything else is resolved via
    a temporary ``Parameter``.
    """
    if value is None:
        return None
    if isinstance(value, (Array, bool, int, float, complex)):
        return value
    if isinstance(value, (np.ndarray, np.number)):
        if np.issubdtype(value.dtype, np.number):
            return value
        # Non-numeric (e.g. object) arrays: a 0-d one wraps a single item.
        return (
            Parameter.unwrap(value.item())
            if value.shape == ()
            else Parameter(value).get()
        )
    if isinstance(value, Parameter):
        return value.get()
    if isinstance(value, dict):
        return {k: cls.unwrap(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        items = [cls.unwrap(v) for v in value]
        return items if isinstance(value, list) else tuple(items)
    # Unhandled types are resolved through the parameter cache.
    return Parameter(value).get()

value_as_api_param(allow_param_name=True, allow_string_literal=True)

Returns an API-compatible expression[1] that defines this parameter

What we return depends on the caller's context, since it depends on whether we are serializing for a model, submodel or block parameter.

The boolean is the value of 'is_string' (means "string literal" or "do not call eval").

[1] The returned string can be serialized to JSON, but it is not an already escaped JSON string!

Parameters:

Name Type Description Default
allow_param_name

Set to false for (sub)model parameters. Optional. If true, and the value is defined by a name, just the name will be returned.

True
allow_string_literal

Set to false for (sub)model parameters. Optional. If true, and the value is a string, then the string will be returned and 'is_string' will be returned as True.

True
Source code in collimator/framework/parameter.py
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
def value_as_api_param(
    self, allow_param_name=True, allow_string_literal=True
) -> tuple[str, bool]:
    """Serialize this parameter's definition as an API expression string.

    The result depends on the caller's context (model, submodel or block
    parameter). The second element of the returned tuple is 'is_string':
    True means "string literal, do not call eval".

    Note that the returned string is JSON-serializable but is NOT itself an
    already-escaped JSON string.

    Args:
        allow_param_name: Set to false for (sub)model parameters. Optional.
            When true and the parameter has a name, only the name is returned.
        allow_string_literal: Set to false for (sub)model parameters. Optional.
            When true and the value is a plain string, it is returned as-is
            with 'is_string' set to True.
    """
    # A named parameter is referenced by its name when permitted.
    if self.name is not None and allow_param_name:
        return self.name, False

    raw = self.value
    if isinstance(raw, str):
        # Python expressions are emitted verbatim and must be evaluated.
        if self.is_python_expr:
            return raw, False
        # Otherwise the string may be passed through as a literal.
        if allow_string_literal:
            return raw, True

    return _value_as_str(raw), False

ParameterCache

Source code in collimator/framework/parameter.py
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
class ParameterCache:
    """Process-wide cache of resolved ``Parameter`` values.

    All state lives in class attributes and is shared by every parameter:

    - ``__dependents__`` maps a parameter to the set of parameters whose
      values are derived from it (these must be invalidated when it changes).
    - ``__cache__`` holds the last computed value of each parameter.
    - ``__is_dirty__`` flags parameters whose cached value is stale; unknown
      parameters default to dirty.
    """

    __dependents__: dict["Parameter", set["Parameter"]] = {}
    __cache__: dict["Parameter", ArrayLike] = {}
    __is_dirty__ = defaultdict(lambda: True)

    @classmethod
    def get(cls, param: "Parameter") -> ArrayLike:
        """Return the resolved value of ``param``, recomputing it if stale."""
        if cls.__is_dirty__[param]:
            cls.__cache__[param] = cls.__compute__(param)
            cls.__is_dirty__[param] = False

        return cls.__cache__[param]

    @classmethod
    def replace(cls, param: "Parameter", value: ArrayLike):
        """Set ``param.value`` and invalidate ``param`` and its dependents."""
        # TODO: update dependencies of this parameter
        param.value = value
        cls.__invalidate__(param)

    @classmethod
    def remove(cls, param: "Parameter"):
        """Drop every trace of ``param`` from the cache and dependency graph."""
        for dependents in cls.__dependents__.values():
            if param in dependents:
                dependents.remove(param)

        if param in cls.__dependents__:
            del cls.__dependents__[param]
        if param in cls.__cache__:
            del cls.__cache__[param]
        if param in cls.__is_dirty__:
            del cls.__is_dirty__[param]

    @classmethod
    def add_dependent(cls, param: "Parameter", dependent: "Parameter"):
        """Record that ``dependent`` is derived from ``param``.

        Afterwards, invalidating ``param`` also invalidates ``dependent``.
        """
        cls.__dependents__[param].add(dependent)

    @classmethod
    def get_dependents(cls, param: "Parameter"):
        """Return the set of direct dependents of ``param``."""
        return cls.__dependents__[param]

    @classmethod
    def print_dependents(cls, param: "Parameter", indent=0):
        """Prints the dependents tree of a parameter"""
        indent_str = "|" + "--" * indent if indent > 0 else ""
        print(indent_str + repr(param))
        for dependent in cls.__dependents__[param]:
            cls.print_dependents(dependent, indent + 1)

    @classmethod
    def static_dependents(cls, param: "Parameter"):
        """Return all static parameters that transitively depend on ``param``.

        Each node is visited once, so shared sub-graphs are not re-walked and
        dependency cycles terminate instead of recursing forever.
        """
        result = set()
        visited = set()
        stack = list(cls.__dependents__[param])
        while stack:
            dependent = stack.pop()
            if dependent in visited:
                continue
            visited.add(dependent)
            if dependent.is_static:
                result.add(dependent)
            stack.extend(cls.__dependents__[dependent])
        return result

    @classmethod
    def __invalidate__(cls, param: "Parameter"):
        """Mark ``param`` and every transitive dependent as stale.

        Iterative with a visited set: diamond-shaped dependency graphs are
        walked once per node and cycles cannot cause infinite recursion.
        """
        visited = set()
        stack = [param]
        while stack:
            current = stack.pop()
            if current in visited:
                continue
            visited.add(current)
            cls.__cache__[current] = None
            cls.__is_dirty__[current] = True
            stack.extend(cls.__dependents__[current])

    @classmethod
    def __compute__(cls, param: "Parameter"):
        """Compute the value of ``param`` from its definition (no caching)."""
        if isinstance(param.value, ParameterExpr):
            # Evaluate the flat expression strictly left to right: operands
            # (Parameters or array-likes) are folded into `acc` with the most
            # recent binary op; unary ops (NEG/POS/ABS) are applied immediately
            # to the element that follows them.
            acc = None
            right_value = None
            op = None
            i = 0

            while i < len(param.value):
                val = param.value[i]

                if isinstance(val, Parameter):
                    right_value = val.get()
                elif isinstance(val, ArrayLikeTypes):
                    right_value = val
                elif isinstance(val, Ops):
                    if val in (Ops.NEG, Ops.POS, Ops.ABS):
                        if i + 1 >= len(param.value):
                            raise ParameterError(
                                param, message="Invalid parameter value"
                            )
                        if isinstance(param.value[i + 1], Parameter):
                            right_value = __OPS_FN__[val](param.value[i + 1].get())
                        elif isinstance(param.value[i + 1], ArrayLikeTypes):
                            right_value = __OPS_FN__[val](param.value[i + 1])
                        else:
                            raise ParameterError(
                                param,
                                message=f"Invalid value in parameter list: {param.value[i + 1]} of type {type(param.value[i + 1])}",
                            )
                        # The operand was consumed by the unary op.
                        i += 1
                    else:
                        op = val
                else:
                    raise ParameterError(
                        param,
                        message=f"Invalid value in parameter list: {val} of type {type(val)}",
                    )

                if acc is not None and right_value is not None and op is not None:
                    acc = __OPS_FN__[op](acc, right_value)
                    op = None
                    right_value = None
                elif right_value is not None:
                    acc = right_value
                    right_value = None
                i += 1

            if acc is not None:
                return acc
            if right_value is not None:
                return right_value
            raise ParameterError(param, message="Invalid parameter value")

        if isinstance(param.value, Parameter):
            return cls.__compute__(param.value)

        if isinstance(param.value, tuple):
            t = _compute_list(param.value, is_tuple=True)
            return t

        if isinstance(param.value, list):
            t = _compute_list(param.value, is_tuple=False)
            return t

        if isinstance(param.value, dict):
            return {key: Parameter.unwrap(val) for key, val in param.value.items()}

        if isinstance(param.value, np.ndarray):
            vals = _resolve_array_param_value(param)
            return np.array(vals, dtype=param.value.dtype)

        if isinstance(param.value, Array):
            vals = _resolve_array_param_value(param)
            # Preserve JAX weak typing: don't pin a dtype on weakly-typed arrays.
            if param.value.weak_type:
                return jnp.array(vals)
            return jnp.array(vals, dtype=param.value.dtype)

        if isinstance(param.value, np.number):
            if isinstance(param.value.item(), Parameter):
                return type(param.value)(cls.__compute__(param.value.item()))
            return param.value

        if isinstance(param.value, str) and param.is_python_expr:
            # SECURITY NOTE: the expression string is passed to eval() with
            # full Python semantics; only evaluate model files from trusted
            # sources.
            _, resolved_parameters = resolve_parameters(param.value, param.py_namespace)
            return eval(
                param.value,
                param.py_namespace,
                {**param.py_namespace, **resolved_parameters},
            )

        return param.value

print_dependents(param, indent=0) classmethod

Prints the dependents tree of a parameter

Source code in collimator/framework/parameter.py
421
422
423
424
425
426
427
@classmethod
def print_dependents(cls, param: "Parameter", indent=0):
    """Print ``param`` and, recursively, its dependents as an indented tree.

    Nested levels are prefixed with ``"|"`` followed by two dashes per level.
    """
    prefix = "" if not indent > 0 else "|" + "--" * indent
    print(f"{prefix}{param!r}")
    child_indent = indent + 1
    for child in cls.__dependents__[param]:
        cls.print_dependents(child, child_indent)

ShapeMismatchError

Bases: StaticError

Block parameters or inputs/outputs have mismatched shapes.

Source code in collimator/framework/error.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
class ShapeMismatchError(StaticError):
    """Block parameters or inputs/outputs have mismatched shapes.

    Args:
        expected_shape: The expected shape, if known.
        actual_shape: The shape that was actually found, if known.
        **kwargs: Forwarded unchanged to ``StaticError.__init__``.
    """

    def __init__(self, expected_shape=None, actual_shape=None, **kwargs):
        super().__init__(**kwargs)
        self.expected_shape = expected_shape
        self.actual_shape = actual_shape

    def __str__(self):
        # Compare against None explicitly: a scalar shape ``()`` is falsy but
        # is still meaningful information that belongs in the message.
        if self.expected_shape is not None or self.actual_shape is not None:
            return (
                f"Shape mismatch: expected {self.expected_shape}, "
                f"got {self.actual_shape}" + self._context_info()
            )
        return f"Shape mismatch{self._context_info()}"

StaticError

Bases: CollimatorError

Wraps a Python exception to record the offending block id. The original exception is found in the '__cause__' field.

See collimator.framework.context_factory._check_types for use.

This is called 'static' (as opposed to say 'runtime') meaning this is for wrapping errors detected prior to running a simulation.

Source code in collimator/framework/error.py
161
162
163
164
165
166
167
168
169
170
class StaticError(CollimatorError):
    """Error for problems detected before a simulation starts running.

    It wraps an underlying Python exception in order to record the offending
    block id; the original exception is available in the '__cause__' field.

    See collimator.framework.context_factory._check_types for usage. "Static"
    is used in contrast to errors raised at simulation runtime.
    """

SystemBase dataclass

Basic building block for simulation in collimator.

NOTE: Type hints in SystemBase indicate the union between what would be returned by a LeafSystem and a Diagram. See type hints of the subclasses for the specific argument and return types.

Source code in collimator/framework/system_base.py
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
@dataclasses.dataclass
class SystemBase:
    """Basic building block for simulation in collimator.

    NOTE: Type hints in SystemBase indicate the union between what would be returned
    by a LeafSystem and a Diagram. See type hints of the subclasses for the specific
    argument and return types.
    """

    # Generated unique ID for this system
    system_id: Hashable = dataclasses.field(default_factory=next_system_id, init=False)
    name: str = None  # Human-readable name for this system (optional but never None)
    ui_id: str = None  # UUID of the block when loaded from JSON (optional, may be None)

    # Immediate parent of the current system (can only be a Diagram).
    # If None, _this_ is the root system.
    parent: Diagram = None

    def __post_init__(self):
        """Set up per-instance bookkeeping after the generated dataclass `__init__`.

        Assigns a default `name` when none was given, and creates the empty
        containers (callbacks, port indices, parameters, caches) that the
        declaration methods (`declare_input_port`, `declare_output_port`, ...)
        append to later.
        """
        # Fall back to an auto-generated name built from the class name and the
        # unique system ID, so that `name` is never None after construction.
        if self.name is None:
            self.name = f"{type(self).__name__}_{self.system_id}_"

        # All "cache sources" for this system. Typically these will correspond to
        # input ports, output ports, time derivative calculations, and any custom
        # "cached data" declared by the user (e.g. see ModelicaFMU block).
        self.callbacks: List[SystemCallback] = []

        # Index into `self.callbacks` for each port. For instance, input port `i` can
        # be retrieved by `self.callbacks[self.input_port_indices[i]]`. The
        # `input_ports` and `output_ports` properties give more convenient access.
        self.input_port_indices: List[int] = []
        self.output_port_indices: List[int] = []

        # Override this or set manually to provide a custom characteristic time scale
        # for the system. At the moment this is only used for zero-crossing isolation
        # in the simulator.
        self.characteristic_time = 1.0

        # A dependency graph for the system, mapping prerequisites of each calculation.
        # `None` indicates that the dependency graph has not been constructed yet.  If
        # accessed via the `dependency_graph` property, it will be constructed
        # automatically as necessary.
        self._dependency_graph: DependencyGraph = None

        # Static parameters are not jax-traceable (they stay on the system rather
        # than being moved into the context like the dynamic parameters below).
        self._static_parameters: dict[str, Parameter] = {}

        # If not empty, this defines the shape and data type of the numeric parameters
        # in the LeafSystem. This value will be used to initialize the context, so it
        # will also serve as the initial value unless explicitly overridden. In the
        # simplest cases, parameters could be stored as attributes of the LeafSystem,
        # but declaring them has the advantage of moving the values to the context,
        # allowing them to be traced by JAX rather than stored as static data. This
        # means they can be differentiated, vmapped, or otherwise modified without
        # re-compiling the simulation.
        self._dynamic_parameters: dict[str, Array] = {}

        # List of (input port index, output port index) pairs that are feedthrough.
        # `None` indicates that the feedthrough is unknown for this system.
        # This will be computed automatically using the dependency graph
        # during algebraic loop detection unless it is set manually.
        # To manually set feedthrough, either declare this explicitly or
        # override `get_feedthrough`.
        self.feedthrough_pairs: List[Tuple[int, int]] = None

        # Pre-sorted list of all output update events for this system.  This will
        # be created when the associated property (`cache_update_events`) is first
        # accessed.  This should only need to be done for the root system.
        self._cache_update_events: EventCollection = None

        # Cached lists of i/o port SystemCallbacks. Do not read these directly; use
        # the `input_ports` / `output_ports` properties, which refresh them.
        self._cached_input_ports: List[SystemCallback] = []
        self._cached_output_ports: List[SystemCallback] = []

        # Cache mechanism for numpy backend (used by `collect_inputs`).
        # NOTE(review): the cache receives `self`; confirm whether it reads any of
        # the attributes above before moving this line earlier.
        self._basic_output_cache: BasicOutputCache = BasicOutputCache(self)

    def __hash__(self) -> int:
        """Hash the system by its unique, generated `system_id`.

        `dataclasses.dataclass` sets `__hash__` to None for mutable dataclasses
        with `eq=True` unless the class defines it explicitly, so this explicit
        definition is what keeps systems hashable.

        Returns:
            int: hash of `system_id`.
        """
        return hash(self.system_id)

    def pprint(self, output=print, fancy=True) -> None:
        """Pretty-print the system and its hierarchy.

        Args:
            output (Callable, optional): function that receives the rendered text
                (stripped of leading/trailing whitespace). Defaults to the builtin
                `print`.
            fancy (bool, optional): forwarded to `_pprint_helper` to choose between
                the fancy and the plain rendering. Defaults to True.

        Returns:
            None: the rendered text is handed to `output`, not returned.
        """
        output(self._pprint_helper(fancy=fancy).strip())

    def _pprint_helper(self, prefix="", fancy=True) -> str:
        """Render this system as text for `pprint`.

        Args:
            prefix (str): text placed in front of the rendered output.
            fancy (bool): when True delegate to `pprint_fancy`; otherwise emit a
                single plain `|-- name(id=...)` line.

        Returns:
            str: the rendered text.
        """
        if not fancy:
            return f"{prefix}|-- {self.name}(id={self.system_id})\n"
        return pprint_fancy(prefix, self)

    def post_simulation_finalize(self) -> None:
        """Hook run once the simulation has completed.

        The base implementation does nothing. It exists only so that special
        blocks which hold resources (for example open files) can override it to
        clean them up.
        """
        return None

    @property
    def static_parameters(self) -> dict[str, Parameter]:
        """Static (not JAX-traceable) parameters of this system, keyed by name."""
        params = self._static_parameters
        return params

    @static_parameters.setter
    def static_parameters(self, value):
        # Replace the whole mapping; no copying or validation is performed here.
        self._static_parameters = value

    @property
    def dynamic_parameters(self) -> dict[str, Parameter]:
        """Numeric parameters whose values are moved into the context.

        Because they live in the context they can be traced by JAX
        (differentiated, vmapped, ...) without re-compiling the simulation.
        """
        params = self._dynamic_parameters
        return params

    @property
    def parameters(self) -> dict[str, Parameter]:
        """All parameters of the system, static and dynamic, in a new dict.

        If a name appears in both groups, the dynamic entry takes precedence.
        """
        merged = dict(self.static_parameters)
        merged.update(self.dynamic_parameters)
        return merged

    #
    # Simulation interface
    #
    @property
    @abc.abstractmethod
    def has_feedthrough_side_effects(self) -> bool:
        """Check if the system includes any feedthrough calls to `io_callback`."""
        # This is a tricky one to explain and is almost always False except for a
        # PythonScript block that is not JAX traced.  Basically, if the output of
        # the system is used as an ODE right-hand-side, will it fail in the case where
        # the ODE solver defines a custom VJP?  This happens in diffrax, so for example
        # if a PythonScript block is used to compute the ODE right-hand-side, it will
        # fail with "Effects not supported in `custom_vjp`"
        pass

    @property
    @abc.abstractmethod
    def has_ode_side_effects(self) -> bool:
        """Check if the ODE RHS for the system includes any calls to `io_callback`."""
        # This flag indicates that the system `has_feedthrough_side_effects` AND that
        # signal is used as an ODE right-hand-side.  This is used to determine whether
        # a JAX ODE solver can be used to integrate the system.
        pass

    @property
    @abc.abstractmethod
    def has_continuous_state(self) -> bool:
        """Return True if the system has any continuous state."""
        pass

    @property
    @abc.abstractmethod
    def has_discrete_state(self) -> bool:
        """Return True if the system has any discrete state."""
        pass

    @property
    @abc.abstractmethod
    def has_zero_crossing_events(self) -> bool:
        """Return True if the system has any zero-crossing events."""
        pass

    def eval_time_derivatives(self, context: ContextBase) -> StateComponent:
        """Compute the continuous-time derivatives of this system.

        The derivatives are evaluated from the _root_ context and must share the
        PyTree structure of the continuous state.

        Custom implementations may override this, but for LeafSystems the usual
        route is `declare_continuous_state`, whose callback computes the
        derivatives. Diagrams assemble their derivatives automatically from the
        callbacks of every child that has continuous state. This base version
        simply reports that there is no continuous state.

        Args:
            context (ContextBase): root context of this system

        Returns:
            StateComponent:
                Continuous time derivatives for this system, or None if the system
                has no continuous state.
        """
        return None

    @property
    @abc.abstractmethod
    def mass_matrix(self) -> StateComponent:
        """Mass matrix of the system.

        The result is PyTree-structured, with each leaf an (n, n) array. It is
        consumed by implicit integrators (at present only BDF).
        """
        ...

    @property
    @abc.abstractmethod
    def has_mass_matrix(self) -> bool:
        """Whether any component of the system has a nontrivial mass matrix."""
        ...

    @abc.abstractmethod
    def eval_zero_crossing_updates(
        self,
        context: ContextBase,
        events: EventCollection,
    ) -> State:
        """Evaluate reset maps associated with zero-crossing events.

        Args:
            context (ContextBase):
                The context for the system, containing the current state and parameters.
            events (EventCollection):
                The collection of events to be evaluated (for example zero-crossing or
                periodic events for this system).

        Returns:
            State: The complete state with all updates applied.

        Notes:
            (1) Following the Drake definition, "unrestricted" updates are allowed to
            modify any component of the state: continuous, discrete, or mode.  These
            updates are evaluated in the order in which they were declared, so it is
            _possible_ (but should be strictly avoided) for multiple events to modify the
            same state component at the same time.

            Each update computes its results given the _current_ state of the system
            (the "minus" values) and returns the _updated_ state (the "plus" values).
            The update functions cannot access any information about the "plus" values
            of their own state or the state of any other block.  This could change in
            the future, but for now it ensures consistency with Drake's discrete
            semantics.

            More specifically, since all unrestricted updates can modify the entire state,
            any time there are multiple unrestricted updates, the resulting states are
            ALWAYS in conflict.  For example, suppose a system has two unrestricted
            updates, `event1` and `event2`.  At time t_n, `event1` is active and `event2`
            is inactive.  First, `event1` is evaluated, and the state is updated.  Then
            `event2` is evaluated, but the state is not updated.  Which one is valid?
            Obviously, the `event1` return is valid, but how do we communicate this to JAX?
            The situation is more complicated if both `event1` and `event2` happen to be
            active.  In this case the states have to be "merged" somehow.  In the worst
            case, these two will modify the same components of the state in different ways.

            The implementation updates the state in a local copy of the context (since
            both the context and the state are immutable).  This allows multiple
            unrestricted updates, but leaves open the possibility of multiple active
            updates modifying the state in conflicting ways.  This should strictly be
            avoided by the implementer of the LeafSystem.  If it is at all unclear how to
            do this, it may be better to split the system into multiple blocks to be safe.

            (2) The events are evaluated conditionally on being marked "active"
            (indicating that their guard function triggered), so the entire event
            collection can be passed without filtering to active events. This is necessary
            to make the function calls work with JAX tracing, which does not allow for
            variable-sized arguments or returns.
        """
        pass

    def handle_discrete_update(
        self, events: EventCollection, context: ContextBase
    ) -> ContextBase:
        """Compute and apply active discrete updates.

        Given the _root_ context, evaluate the discrete updates, which must have the
        same PyTree structure as the discrete states of this system. This should be
        a pure function, so that it does not modify any aspect of the context in-place
        (even though it is difficult to strictly prevent this in Python).

        This will evaluate the set of events that result from declaring state or output
        update events on systems using `LeafSystem.declare_periodic_update` and
        `LeafSystem.declare_output_port` with an associated periodic update rate.

        This is intended for internal use by the simulator and should not normally need
        to be invoked directly by users. Events are evaluated conditionally on being
        marked "active", so the entire event collection can be passed without filtering
        to active events. This is necessary to make the function calls work with JAX
        tracing, which does not allow for variable-sized arguments or returns.

        For a discrete system updating at a particular rate, the update rule for a
        particular block is:

        ```
        x[n+1] = f(t[n], x[n], u[n])
        y[n]   = g(t[n], x[n], u[n])
        ```

        Additionally, the value y[n] is held constant until the next update from the
        point of view of other continuous-time or asynchronous discrete-time blocks.

        Because each output `y[n]` may in general depend on the input `u[n]` evaluated
        _at the same time_, the composite discrete update function represents a
        system of equations.  However, since algebraic loops are prohibited, the events
        can be ordered and executed sequentially to ensure that the updates are applied
        in the correct order.  This is implemented in
        `SystemBase.sorted_event_callbacks`.

        Multirate systems work in the same way, except that the events are evaluated
        conditionally on whether the current time corresponds to an update time for each
        event.

        Args:
            events (EventCollection): collection of discrete update events
            context (ContextBase): root context for this system

        Returns:
            ContextBase:
                updated context with all active updates applied to the discrete state
        """
        # Lazy %-style arguments (as in `handle_zero_crossings`): the message is only
        # formatted when debug logging is enabled, which matters on the simulator's
        # update path.
        logger.debug(
            "Handling %s discrete update events at t=%s",
            events.num_events,
            context.time,
        )
        if events.has_events:
            # Sequentially apply all updates.  `context` is rebound after every
            # event, so each event is handled against the root context that already
            # contains the results of the events before it.
            for event in events:
                system_id = event.system_id
                state = event.handle(context)
                local_context = context[system_id].with_state(state)
                context = context.with_subcontext(system_id, local_context)

        return context

    def handle_zero_crossings(
        self, events: EventCollection, context: ContextBase
    ) -> ContextBase:
        """Evaluate the zero-crossing reset maps and apply them to the context.

        Meant for the simulator's internal use rather than for end users. The full
        event collection may be passed unfiltered: whether each event applies is
        decided by its "active" flag during evaluation. JAX tracing requires this,
        because it cannot handle arguments or returns whose size varies.

        Args:
            events (EventCollection): collection of zero-crossing events
            context (ContextBase): root context for this system

        Returns:
            ContextBase: updated context with all active zero-crossing events applied
        """
        logger.debug(
            "Handling %d state transition events at t=%s",
            events.num_events,
            context.time,
        )
        if events.num_events <= 0:
            # Nothing to do: hand back the context unchanged.
            return context

        # Evaluate the reset maps for the whole collection, then install the
        # resulting state in the context.
        updated_state = self.eval_zero_crossing_updates(context, events)
        return context.with_state(updated_state)

    @property
    def periodic_events(self) -> FlatEventCollection:
        """Events that happen at some regular interval.

        These are the sample-and-hold refresh events (`cache_update_events`)
        followed by the discrete state update events (`state_update_events`).
        """
        cache_events = self.cache_update_events
        state_events = self.state_update_events
        return cache_events + state_events

    @property
    @abc.abstractmethod
    def state_update_events(self) -> FlatEventCollection:
        """Events that update the discrete state of the system."""
        pass

    @property
    def cache_update_events(self) -> FlatEventCollection:
        """Events that refresh sample-and-hold outputs of the system.

        On first access, the events attached to `sorted_event_callbacks` (in that
        sorted order) are collected and memoized in `_cache_update_events`. Each
        access returns a new `FlatEventCollection` built from a tuple copy of that list.
        """
        if self._cache_update_events is None:
            ordered = []
            for callback in self.sorted_event_callbacks:
                if callback.event is not None:
                    ordered.append(callback.event)
            self._cache_update_events = ordered

        return FlatEventCollection(tuple(self._cache_update_events))

    @property
    @abc.abstractmethod
    def _flat_callbacks(self) -> List[SystemCallback]:
        """Get a flat list of callbacks for this and all child systems."""
        pass

    @property
    def sorted_event_callbacks(self) -> List[SystemCallback]:
        """Return the callbacks of `_flat_callbacks` sorted in execution order.

        Every callback is returned, including those without an event. Callers
        such as `cache_update_events` filter for the ones that carry an event.
        """
        # Every SystemCallback owns a tracker in the dependency graph. Sorting the
        # trackers gives the execution order, and each tracker refers back to its
        # callback through `cache_source`.
        unsorted_trackers = [callback.tracker for callback in self._flat_callbacks]
        ordered_trackers = sort_trackers(unsorted_trackers)
        return [tracker.cache_source for tracker in ordered_trackers]

    @property
    @abc.abstractmethod
    def zero_crossing_events(self) -> EventCollection:
        """Events that are triggered by a "guard" function and may induce a "reset" map."""
        pass

    @abc.abstractmethod
    def determine_active_guards(self, context: ContextBase) -> EventCollection:
        """Flag which zero-crossing guards are active.

        Decides, from the current system mode and any other relevant conditions,
        which zero-crossing events should be monitored. Blocks can override this
        to activate or deactivate guards individually, as a StateMachine-like
        block would. By default every guard is active at this stage. The exception
        is a zero-crossing event declared with a non-default `start_mode`: its
        guard is active only while the system is in that mode.

        Example: a system switches from mode A to mode B when guard g_AB fires and
        back again when g_BA fires. In mode A only g_AB is activated, and in mode B
        only g_BA. Zero-crossing events with no starting mode are treated as
        always active.

        Args:
            context (ContextBase):
                The root context containing the overall state and parameters.

        Returns:
            EventCollection:
                A collection of zero-crossing events with active/inactive status
                updated based on the current system mode and other conditions.
        """
        ...

    #
    # I/O ports
    #
    @property
    def input_ports(self) -> List[InputPort]:
        """Input ports of this system, in declaration order.

        The returned list is a cached object. It is refreshed in place whenever
        its length no longer matches `input_port_indices`.
        """
        cached = self._cached_input_ports
        if len(cached) != len(self.input_port_indices):
            # Slice assignment keeps the same list object; the right-hand side is
            # fully built before the cache is touched.
            cached[:] = [self.callbacks[i] for i in self.input_port_indices]
        return cached

    def get_input_port(self, name: str) -> tuple[InputPort, int]:
        """Look up an input port by name.

        Returns:
            tuple[InputPort, int]: the first port with that name, and its index in
            `input_ports`.

        Raises:
            ValueError: if no input port has the given name.
        """
        matches = (
            (port, index)
            for index, port in enumerate(self.input_ports)
            if port.name == name
        )
        found = next(matches, None)
        if found is not None:
            return found

        available = [port.name for port in self.input_ports]
        raise ValueError(
            f"System {self.name} has no input port named {name}. "
            f"Available ports: {available}"
        )

    @property
    def num_input_ports(self) -> int:
        """Number of input ports declared on this system."""
        indices = self.input_port_indices
        return len(indices)

    @property
    def output_ports(self) -> List[OutputPort]:
        """Output ports of this system, in declaration order.

        The returned list is a cached object. It is refreshed in place whenever
        its length no longer matches `output_port_indices`.
        """
        cached = self._cached_output_ports
        if len(cached) != len(self.output_port_indices):
            # Slice assignment keeps the same list object; the right-hand side is
            # fully built before the cache is touched.
            cached[:] = [self.callbacks[i] for i in self.output_port_indices]
        return cached

    def get_output_port(self, name: str) -> OutputPort:
        """Retrieve a specific output port by name.

        Args:
            name (str): name of the output port to look up.

        Returns:
            OutputPort: the first output port with that name.

        Raises:
            ValueError: if no output port has the given name. The message lists the
                available port names, matching `get_input_port`.
        """
        for port in self.output_ports:
            if port.name == name:
                return port
        raise ValueError(
            f"System {self.name} has no output port named {name}. "
            f"Available ports: {[p.name for p in self.output_ports]}"
        )

    @property
    def num_output_ports(self) -> int:
        """Number of output ports declared on this system."""
        indices = self.output_port_indices
        return len(indices)

    def eval_input(self, context: ContextBase, port_index: int = 0) -> Array:
        """Evaluate one input port of this system.

        The port's callback is invoked, and it "pulls" the value from the
        upstream output port.

        Args:
            context (ContextBase): root context for this system
            port_index (int, optional): index into `self.input_ports`, for example
                the value returned by `declare_input_port`. Defaults to 0.

        Returns:
            Array: current input values
        """
        port = self.input_ports[port_index]
        return port.eval(context)

    def collect_inputs(
        self, context: ContextBase, port_indices: list[int] = None
    ) -> List[Array]:
        """Collect all current inputs for this system.

        Args:
            context (ContextBase): root context for this system
            port_indices (List[int], optional): list of input port indices to collect.
                If None (default), will return values from all ports.  Otherwise will
                return a list of length(num_input_ports), where the values are None for
                ports not in the list.  An explicitly empty sequence (list, tuple,
                range, ...) returns an empty list.

        Returns:
            List[Array]: list of all current input values
        """
        if port_indices is None:
            port_indices = range(self.num_input_ports)
        elif len(port_indices) == 0:
            # Some blocks are hard-coded to have no inputs, so we should make
            # sure that a list full of None is not returned in that case. This
            # happens if the callback signature is (time, state, **parameters)
            # instead of the more general (time, state, *inputs, **parameters).
            # Any empty sequence counts here, not only a literal `[]`.
            return []

        # Constant-time membership test for the per-port loop below.
        requested = set(port_indices)

        # Right now this seems to be the best place to add the cache.
        # FIXME: Caching outputs could work better if we knew which ones
        # to target specifically (the more expensive ones).
        inputs = self._basic_output_cache.get(context)

        # If inputs is not None but is not consistent with the requested
        # port_indices, we should recompute the inputs.  This is only an issue
        # when using the NumPy-backend caching.
        if inputs is not None and any(inputs[i] is None for i in port_indices):
            inputs = None

        if inputs is None:
            # Ports are evaluated in index order; unrequested ports stay None.
            inputs = [
                self.eval_input(context, i) if i in requested else None
                for i in range(self.num_input_ports)
            ]
            self._basic_output_cache.set(context, inputs)

        return inputs

    def invalidate_output_caches(self):
        """Invalidate the basic (NumPy-backend) cache used by `collect_inputs`."""
        cache = self._basic_output_cache
        cache.invalidate()

    def _eval_input_port(self, context: ContextBase, port_index: int) -> Array:
        """Evaluate the upstream source of an input port, given the _root_ context.

        This is an internal helper that the input port callbacks (created in
        `declare_input_port`) refer to. Users and developers should normally call
        `eval_input`, which invokes that callback.

        Args:
            context (ContextBase): root context for this system
            port_index (int): index of the input port to evaluate on the target system

        Returns:
            Array: current input values

        Raises:
            UpstreamEvalError: if this system is the root, so there is no parent to
                resolve the connection.
        """
        port_locator = self.input_ports[port_index].locator

        if self.parent is not None:
            # `eval_subsystem_input_port` is only defined for Diagrams. A non-root
            # system always has a Diagram parent, so the call is safe here.
            return self.parent.eval_subsystem_input_port(context, port_locator)

        # This is currently the root system.  Typically root input ports should not be
        # evaluated, but we can get here during subsystem construction (e.g. type
        # inference).  In that case, we should just defer evaluation and rely on the
        # graph analysis to determine that everything is connected correctly.
        # See https://collimator.atlassian.net/browse/WC-51.
        # This should not happen during simulation or root context construction.
        logger.debug(
            "    ---> %s is the root system, deferring evaluation of %s[%s]",
            self.name,
            port_locator[0].name,
            port_locator[1],
        )
        raise UpstreamEvalError(port_locator=(self, "in", port_index))

    #
    # Declaration utilities
    #
    def _next_input_port_name(self, name: str | None = None) -> str:
        """Return `name` if given, else an auto-generated name for the next input port.

        Generated names have the form "in_<n>", where n is the current number of
        input ports.
        """
        if name is None:
            return f"in_{self.num_input_ports}"
        assert name != ""
        return name

    def _next_output_port_name(self, name: str | None = None) -> str:
        """Return `name` if given, else an auto-generated name for the next output port.

        Generated names have the form "out_<n>", where n is the current number of
        output ports.
        """
        if name is None:
            return f"out_{self.num_output_ports}"
        assert name != ""
        return name

    def declare_input_port(
        self,
        name: str = None,
        prerequisites_of_calc: List[DependencyTicket] = None,
    ) -> int:
        """Add an input port to the system.

        Returns the index of the new port within `input_ports` (equivalently, into
        `input_port_indices`).  Note that this is different from the index of the
        port in `callbacks`.  Typically it will make more sense to retrieve the port
        via `system.input_ports[port_index]` than to index `system.callbacks`
        directly.

        Args:
            name (str, optional): name of the new port. Defaults to None, which will
                use the default naming scheme for the system (e.g. "in_0")
            prerequisites_of_calc (List[DependencyTicket], optional): list of
                dependencies for the callback function. Defaults to None.

        Returns:
            int: port index of the newly created port in `input_ports`

        Raises:
            AssertionError: (when assertions are enabled) if the system already has
                an input port with the resulting name.
        """
        port_index = self.num_input_ports
        port_name = self._next_input_port_name(name)

        # Port names must be unique within the system.  This is checked once, up
        # front, before the port object is created.
        for existing in self.input_ports:
            assert (
                existing.name != port_name
            ), f"System {self.name} already has an input port named {existing.name}"

        def _callback(context: ContextBase) -> Array:
            # Given the root context, evaluate the input port using the helper function.
            # `port_index` is captured by this closure, so each port evaluates itself.
            return self._eval_input_port(context, port_index)

        callback_index = len(self.callbacks)
        port = InputPort(
            callback=_callback,
            system=self,
            callback_index=callback_index,
            name=port_name,
            index=port_index,
            prerequisites_of_calc=prerequisites_of_calc,
        )

        assert isinstance(port, InputPort)
        assert port.system is self
        assert port.name != ""

        self.input_port_indices.append(callback_index)
        self.callbacks.append(port)
        # Force `input_ports` to rebuild its cached list on next access.
        self._cached_input_ports.clear()

        return port_index

    def declare_output_port(
        self,
        callback: Callable,
        name: str = None,
        prerequisites_of_calc: List[DependencyTicket] = None,
        default_value: Array = None,
        event: DiscreteUpdateEvent = None,
        cache_index: int = None,
    ) -> int:
        """Add an output port to the system.

        This output port could represent any function of the context available to
        the system, so a callback function is required.  This function should have
        the form
            `callback(context: ContextBase) -> Array`
        SystemBase implementations have some specific convenience wrappers, e.g.:
            `LeafSystem.declare_continuous_state_output`
            `Diagram.export_output`

        Common cases are:
        - Feedthrough blocks: gather inputs and return some function of the
            inputs (e.g. a gain)
        - Stateful blocks: use LeafSystem.declare_(...)_state_output_port to
            return the value of a particular state
        - Diagrams: create and export a diagram-level port to the parent system using
            the callback function associated with the system-level port

        Returns the index of the new port within `output_ports` (equivalently, into
        `output_port_indices`).  Note that this is different from the index of the
        port in `callbacks`.  Typically it will make more sense to retrieve the port
        via `system.output_ports[port_index]`.

        Args:
            callback (Callable): computes the value of the output port given
                the root context.
            name (str, optional): name of the new port. Defaults to None, which will
                use the default naming scheme for the system (e.g. "out_0")
            prerequisites_of_calc (List[DependencyTicket], optional): list of
                dependencies for the callback function. Defaults to None, which will
                use the default dependencies for the system (all sources).  This may
                conservatively flag the system as having algebraic loops, so it is
                better to be specific here when possible.  This is done automatically
                in the wrapper functions like `LeafSystem.declare_(...)_output_port`
            default_value (Array, optional): A default array-like value used to seed
                the context and perform type inference, when this is known up front.
                Defaults to None, which will use information propagation through the
                graph along with type promotion to determine an appropriate value.
            event (DiscreteUpdateEvent, optional): A discrete update event associated
                with this output port that will periodically refresh the value that
                will be returned by the callback function. This makes the port act as
                a sample-and-hold rather than a direct function evaluation.
            cache_index (int, optional): Index into the cache state component
                corresponding to the output port result, if the output port is of
                periodically-updated sample-and-hold type.

        Returns:
            int: port index of the newly created port

        Raises:
            AssertionError: (when assertions are enabled) if the system already has
                an output port with the resulting name.
        """
        port_index = self.num_output_ports
        port_name = self._next_output_port_name(name)

        # Port names must be unique within the system.  This is checked once, up
        # front, before the port object is created.
        for existing in self.output_ports:
            assert (
                existing.name != port_name
            ), f"System {self.name} already has an output port named {existing.name}"

        # Conservative default: depend on every source (see docstring).
        if prerequisites_of_calc is None:
            prerequisites_of_calc = [DependencyTicket.all_sources]

        callback_index = len(self.callbacks)
        port = OutputPort(
            callback,
            system=self,
            callback_index=callback_index,
            name=port_name,
            index=port_index,
            prerequisites_of_calc=prerequisites_of_calc,
            default_value=default_value,
            event=event,
            cache_index=cache_index,
        )

        assert isinstance(port, OutputPort)
        assert port.system is self
        assert port.name != ""

        logger.debug("Adding output port %s to %s", port, self.name)
        self.output_port_indices.append(callback_index)
        self.callbacks.append(port)
        # Force `output_ports` to rebuild its cached list on next access.
        self._cached_output_ports.clear()

        logger.debug(
            "    ---> %s now has %s output ports: %s",
            self.name,
            len(self.output_ports),
            self.output_ports,
        )
        logger.debug(
            "    ---> %s now has %s cache sources: %s",
            self.name,
            len(self.callbacks),
            self.callbacks,
        )

        return port_index

    def configure_output_port(
        self,
        port_index: int,
        callback: Callable,
        prerequisites_of_calc: List[DependencyTicket] = None,
        default_value: Array = None,
        event: DiscreteUpdateEvent = None,
        cache_index: int = None,
    ):
        """Reconfigure an already-declared output port in place.

        The keyword arguments have the same meaning as in `declare_output_port`;
        the only additional argument selects which existing port to modify.

        Args:
            port_index (int): position of the target port in `self.output_ports`

        Returns:
            None
        """
        # Same default as declare_output_port: depend on every source.
        prereqs = (
            [DependencyTicket.all_sources]
            if prerequisites_of_calc is None
            else prerequisites_of_calc
        )

        target = self.output_ports[port_index]
        target.port_index = port_index
        target._callback = callback
        target.prerequisites_of_calc = prereqs
        target.default_value = default_value
        target.event = event
        target.cache_index = cache_index

        # Store the port back into its callback slot, then drop the memoized
        # output-port list so it is rebuilt on next access.
        self.callbacks[target.callback_index] = target
        self._cached_output_ports.clear()

        logger.debug(
            "    ---> %s now has %s output ports: %s",
            self.name,
            len(self.output_ports),
            self.output_ports,
        )
        logger.debug(
            "    ---> %s now has %s cache sources: %s",
            self.name,
            len(self.callbacks),
            self.callbacks,
        )

    @abc.abstractmethod
    def get_feedthrough(self) -> List[Tuple[int, int]]:
        """Report which input ports feed directly through to which output ports.

        Concrete systems typically derive this from the dependency tracking
        system, but an implementation can report a tighter set of pairs when
        automatic tracking would be too conservative about feedthrough.

        Returns:
            List[Tuple[int, int]]:
                Pairs (u, v) meaning output port v depends directly on input
                port u, i.e. there is a feedthrough path from u to v. u indexes
                the system's input port list and v its output port list.
        """
        ...

    #
    # Initialization
    #
    def create_context(self, **kwargs) -> ContextBase:
        """Build a fresh context for this system.

        The context carries all of the variable data (state, parameters, ...)
        used for simulation, analysis, or optimization. Keyword arguments are
        forwarded unchanged to the object returned by `context_factory`.

        Returns:
            ContextBase: the newly created context
        """
        factory = self.context_factory
        return factory(**kwargs)

    def check_types(self, context: ContextBase, error_collector: ErrorCollector = None):
        """Hook for system-specific static analysis; the base version does nothing."""
        return None

    @property
    @abc.abstractmethod
    def context_factory(self) -> ContextFactory:
        """Factory object for creating contexts for this system.

        Abstract: every concrete system must provide it. Declared with
        `@property` + `@abc.abstractmethod` because `abc.abstractproperty` has
        been deprecated since Python 3.3; the resulting abstract property
        behaves the same way for subclasses.

        Should not be called directly - use `system.create_context` instead.
        """

    @property
    def dependency_graph(self) -> DependencyGraph:
        """The dependency graph for this system.

        This accessor simply returns `self._dependency_graph`, which is populated
        by `create_dependency_graph`; it does not build the graph on demand, so
        `create_dependency_graph` must have been called first.
        """
        return self._dependency_graph

    @property
    @abc.abstractmethod
    def dependency_graph_factory(self) -> DependencyGraphFactory:
        """Factory object for creating dependency graphs for this system.

        Abstract: every concrete system must provide it. Declared with
        `@property` + `@abc.abstractmethod` because `abc.abstractproperty` has
        been deprecated since Python 3.3.

        Should not be called directly - use `system.create_dependency_graph` instead.
        """

    def create_dependency_graph(self):
        """Build a new dependency graph and store it on the system.

        The object returned by `dependency_graph_factory` is called with no
        arguments, and its result is kept in `self._dependency_graph`.
        """
        make_graph = self.dependency_graph_factory
        self._dependency_graph = make_graph()

    def initialize_static_data(self, context: ContextBase) -> ContextBase:
        """Hook for post-creation context setup that JAX does not trace.

        Subclasses override this to attach custom auxiliary data or to run type
        inference outside of JAX tracing (`ZeroOrderHold` is one example).
        Because it only runs while the context is being initialized, an
        implementation may mutate the context, or the system itself, directly.

        It is normally invoked only by the ContextFactory.

        Args:
            context (ContextBase): the partially initialized context for this system.

        Returns:
            ContextBase: the context; the base implementation returns it unchanged.
        """
        return context

    @property
    def ports(self) -> dict[str, PortBase]:
        """Mapping from port name to port, covering input and output ports.

        Output ports are inserted after input ports, so on a name clash the
        output port is the one kept.
        """
        every_port = self.input_ports + self.output_ports
        by_name = {}
        for p in every_port:
            by_name[p.name] = p
        return by_name

    # Convenience functions for errors and UI logs

    @property
    def name_path(self) -> list[str]:
        """Get the human-readable path to this system as a list of names.

        The path is rooted at the top-level block rather than the root diagram.
        A system with no parent (typically the root) and a direct child of the
        root both return just ``[self.name]``. Any deeper system appends its own
        name to its parent's path. This property never returns None; an unset
        name simply appears as-is in the list.
        """
        if self.parent is None:
            return [self.name]  # Likely to be 'root'
        if self.parent.parent is None:
            return [self.name]  # top-level block
        return self.parent.name_path + [self.name]

    @property
    def name_path_str(self) -> str:
        """Dot-joined form of `name_path`, e.g. ``"outer.inner.block"``."""
        segments = self.name_path
        return ".".join(segments)

    @property
    def ui_id_path(self) -> Union[list[str], None]:
        """UI node-id path to this system, or None if any id along it is unset.

        A parentless system and a top-level block (whose parent has no parent)
        return just their own id. Deeper systems extend their parent's path.
        """
        own_id = self.ui_id
        if own_id is None:
            return None
        parent = self.parent
        if parent is None or parent.parent is None:
            return [own_id]
        prefix = parent.ui_id_path
        return None if prefix is None else prefix + [own_id]

    def declare_static_parameters(self, **params):
        """Register one or more static (non-traceable) parameters on the system.

        Static parameters are not JAX-traceable, so they cannot be optimized.
        Typical examples are booleans, strings, and values that determine
        array shapes.

        Each keyword argument is a name-value pair whose value is a string,
        bool, array, or Parameter. Plain Python lists are converted with
        `np.array` before being wrapped in a Parameter.

        Typical usage:

        ```python
        class MyBlock(LeafSystem):
            def __init__(self, param1=True, param2=1.0):
                super().__init__()
                self.declare_static_parameters(param1=param1, param2=param2)
        ```

        Raises:
            BlockParameterError: if a name is already declared as a dynamic
                parameter.
        """
        for param_name, raw_value in params.items():
            if param_name in self.dynamic_parameters:
                raise BlockParameterError(
                    "Parameter already declared as dynamic parameter",
                    system=self,
                    parameter_name=param_name,
                )
            value = np.array(raw_value) if isinstance(raw_value, list) else raw_value
            self._static_parameters[param_name] = Parameter(
                value=value, system=self, is_static=True
            )

    def declare_static_parameter(self, name, value):
        """Register a single static parameter.

        Shorthand for `declare_static_parameters` when the parameter name is
        held in a variable.

        Args:
            name (str): name of the parameter
            value (Union[Array, Parameter]): value of the parameter
        """
        single = {name: value}
        self.declare_static_parameters(**single)

    def declare_dynamic_parameter(
        self,
        name: str,
        default_value: Array | Parameter = None,
        shape: ShapeLike = None,
        dtype: DTypeLike = None,
        as_array: bool = True,
    ):
        """Declare a numeric parameter for the system.

        Parameters are declared in the system and accessed through the context to
        maintain separation of data ownership. This method creates an entry in the
        system's dynamic_parameters, recording the name, default value, and dependency
        ticket for later reference.

        The default value will be used to initialize the context, so it
        will also serve as the initial value unless explicitly overridden. In the
        simplest cases, parameters could be stored as attributes of the LeafSystem,
        but declaring them has the advantage of moving the values to the context,
        allowing them to be traced by JAX rather than stored as static data. This
        means they can be differentiated, vmapped, or otherwise modified without
        re-compiling the simulation.

        Declaring a name that is already a dynamic parameter replaces the
        previous declaration.

        Args:
            name (str):
                The name of the parameter.
            default_value (Union[Array, Parameter], optional):
                The default value of the parameter. Parameters are used
                primarily internally for serialization and should not normally need
                to be used directly when implementing LeafSystems. Defaults to None.
            shape (ShapeLike, optional):
                The shape of the parameter. Defaults to None.
            dtype (DTypeLike, optional):
                The data type of the parameter. Defaults to None.
            as_array (bool, optional):
                If True, treat the default_value as an array-like (cast if necessary).
                Otherwise, it will be stored as the default state without modification.

        Raises:
            BlockParameterError:
                If `name` is already declared as a static parameter, or if
                building the parameter from `default_value`/`shape`/`dtype`
                fails (the underlying exception is chained as `__cause__`).

        Notes:
            (1) Only one of `shape` and `default_value` should be provided. If
            `default_value` is provided, it will be used as the initial value of the
            parameter. If `shape` is provided, the initial value will be a
            zero vector of the given shape and specified dtype.
        """
        if name in self.static_parameters:
            raise BlockParameterError(
                "Parameter already declared as static parameter",
                system=self,
                parameter_name=name,
            )

        try:
            if isinstance(default_value, Parameter):
                # Re-wrap an existing Parameter (mainly used for serialization).
                self._dynamic_parameters[name] = Parameter(
                    value=default_value,
                    dtype=dtype,
                    shape=shape,
                    system=self,
                    as_array=as_array,
                )
            else:
                if as_array:
                    default_value = utils.make_array(
                        default_value, dtype=dtype, shape=shape
                    )
                self._dynamic_parameters[name] = Parameter(
                    value=default_value,
                    dtype=dtype,
                    shape=shape,
                    system=self,
                )
        except Exception as e:
            # Record the original traceback through the module logger rather
            # than printing straight to stderr, so it stays visible even if the
            # chained cause is later lost (e.g. by serialization).
            logger.exception("Error declaring parameter %s on %s", name, self.name)
            raise BlockParameterError(
                "Error declaring parameter",
                system=self,
                parameter_name=name,
            ) from e

        logger.debug(
            "Adding parameter %s to %s with default: %s",
            name,
            self.name,
            default_value,
        )

    @property
    def has_dirty_static_parameters(self) -> bool:
        """True when at least one static parameter reports `is_dirty`."""
        for param in self.static_parameters.values():
            if param.is_dirty:
                return True
        return False

dependency_graph: DependencyGraph property

Retrieve the dependency graph for this system, as previously built by create_dependency_graph (it is not created on demand).

has_dirty_static_parameters: bool property

Check if any static parameters have been modified.

has_mass_matrix: bool abstractmethod property

Returns True if any component of the system has a nontrivial mass matrix.

mass_matrix: StateComponent abstractmethod property

Mass matrix for this system.

Returns PyTree-structured data where each leaf is an (n, n) array. This is used for implicit integration methods (currently only BDF).

name_path: list[str] property

Get the human-readable path to this system as a list of names, starting from the top-level block.

name_path_str: str property

Get the human-readable path to this system as a string.

ports: dict[str, PortBase] property

Dictionary of all ports in this system, indexed by name

sorted_event_callbacks: List[SystemCallback] property

Sort and return the event-related callbacks for this system.

ui_id_path: Union[list[str], None] property

Get the uuid node path to this system. None if some IDs are not set.

check_types(context, error_collector=None)

Perform any system-specific static analysis.

Source code in collimator/framework/system_base.py
926
927
928
def check_types(self, context: ContextBase, error_collector: ErrorCollector = None):
    """Hook for system-specific static analysis; the base version does nothing."""
    return None

collect_inputs(context, port_indices=None)

Collect all current inputs for this system.

Parameters:

Name Type Description Default
context ContextBase

root context for this system

required
port_indices List[int]

list of input port indices to collect. If None (default), values are returned for all ports. Otherwise, a list of length num_input_ports is returned, where the values are None for ports not in the list.

None

Returns:

Type Description
List[Array]

List[Array]: list of all current input values

Source code in collimator/framework/system_base.py
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
def collect_inputs(
    self, context: ContextBase, port_indices: list[int] = None
) -> List[Array]:
    """Collect all current inputs for this system.

    Args:
        context (ContextBase): root context for this system
        port_indices (List[int], optional): list of input port indices to collect.
            If None (default), values are collected from all ports. Otherwise the
            result has length `num_input_ports`, and ports not in the list are
            None (unless a cached evaluation already holds a value for them).
            An empty selection returns an empty list.

    Returns:
        List[Array]: list of all current input values
    """
    if port_indices is None:
        port_indices = range(self.num_input_ports)

    # Some blocks are hard-coded to have no inputs, so we should make
    # sure that a list full of None is not returned in that case. This
    # happens if the callback signature is (time, state, **parameters)
    # instead of the more general (time, state, *inputs, **parameters).
    # len() also catches empty tuples, ranges and arrays, which `== []` missed.
    if len(port_indices) == 0:
        return []

    # Right now this seems to be the best place to add the cache.
    # FIXME: Caching outputs could work better if we knew which ones
    # to target specifically (the more expensive ones).
    inputs = self._basic_output_cache.get(context)

    # If inputs is not None but is not consistent with the requested
    # port_indices, we should recompute the inputs.  This is only an issue
    # when using the NumPy-backend caching.
    if inputs is not None and any(inputs[i] is None for i in port_indices):
        inputs = None

    if inputs is None:
        # Membership is tested once per input port, so use a set.
        requested = set(port_indices)
        inputs = [
            self.eval_input(context, i) if i in requested else None
            for i in range(self.num_input_ports)
        ]
        self._basic_output_cache.set(context, inputs)

    return inputs

configure_output_port(port_index, callback, prerequisites_of_calc=None, default_value=None, event=None, cache_index=None)

Configure an output port of the system.

See declare_output_port for a description of the arguments.

Parameters:

Name Type Description Default
port_index int

index of the output port to configure

required

Returns:

Type Description

None

Source code in collimator/framework/system_base.py
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
def configure_output_port(
    self,
    port_index: int,
    callback: Callable,
    prerequisites_of_calc: List[DependencyTicket] = None,
    default_value: Array = None,
    event: DiscreteUpdateEvent = None,
    cache_index: int = None,
):
    """Reconfigure an already-declared output port in place.

    The keyword arguments have the same meaning as in `declare_output_port`;
    the only additional argument selects which existing port to modify.

    Args:
        port_index (int): position of the target port in `self.output_ports`

    Returns:
        None
    """
    # Same default as declare_output_port: depend on every source.
    prereqs = (
        [DependencyTicket.all_sources]
        if prerequisites_of_calc is None
        else prerequisites_of_calc
    )

    target = self.output_ports[port_index]
    target.port_index = port_index
    target._callback = callback
    target.prerequisites_of_calc = prereqs
    target.default_value = default_value
    target.event = event
    target.cache_index = cache_index

    # Store the port back into its callback slot, then drop the memoized
    # output-port list so it is rebuilt on next access.
    self.callbacks[target.callback_index] = target
    self._cached_output_ports.clear()

    logger.debug(
        "    ---> %s now has %s output ports: %s",
        self.name,
        len(self.output_ports),
        self.output_ports,
    )
    logger.debug(
        "    ---> %s now has %s cache sources: %s",
        self.name,
        len(self.callbacks),
        self.callbacks,
    )

context_factory()

Factory object for creating contexts for this system.

Should not be called directly - use system.create_context instead.

Source code in collimator/framework/system_base.py
930
931
932
933
934
935
936
@property
@abc.abstractmethod
def context_factory(self) -> ContextFactory:
    """Factory object for creating contexts for this system.

    Abstract: every concrete system must provide it. Declared with
    `@property` + `@abc.abstractmethod` because `abc.abstractproperty` has
    been deprecated since Python 3.3.

    Should not be called directly - use `system.create_context` instead.
    """

create_context(**kwargs)

Create a new context for this system.

The context will contain all variable information used in simulation/analysis/optimization, such as state and parameters.

Returns:

Name Type Description
ContextBase ContextBase

new context for this system

Source code in collimator/framework/system_base.py
915
916
917
918
919
920
921
922
923
924
def create_context(self, **kwargs) -> ContextBase:
    """Build a fresh context for this system.

    The context carries all of the variable data (state, parameters, ...)
    used for simulation, analysis, or optimization. Keyword arguments are
    forwarded unchanged to the object returned by `context_factory`.

    Returns:
        ContextBase: the newly created context
    """
    factory = self.context_factory
    return factory(**kwargs)

create_dependency_graph()

Create a dependency graph for this system.

Source code in collimator/framework/system_base.py
951
952
953
def create_dependency_graph(self):
    """Build a new dependency graph and store it on the system.

    The object returned by `dependency_graph_factory` is called with no
    arguments, and its result is kept in `self._dependency_graph`.
    """
    make_graph = self.dependency_graph_factory
    self._dependency_graph = make_graph()

declare_dynamic_parameter(name, default_value=None, shape=None, dtype=None, as_array=True)

Declare a numeric parameter for the system.

Parameters are declared in the system and accessed through the context to maintain separation of data ownership. This method creates an entry in the system's dynamic_parameters, recording the name, default value, and dependency ticket for later reference.

The default value will be used to initialize the context, so it will also serve as the initial value unless explicitly overridden. In the simplest cases, parameters could be stored as attributes of the LeafSystem, but declaring them has the advantage of moving the values to the context, allowing them to be traced by JAX rather than stored as static data. This means they can be differentiated, vmapped, or otherwise modified without re-compiling the simulation.

Parameters:

Name Type Description Default
name str

The name of the parameter.

required
default_value Union[Array, Parameter]

The default value of the parameter. Parameters are used primarily internally for serialization and should not normally need to be used directly when implementing LeafSystems. Defaults to None.

None
shape ShapeLike

The shape of the parameter. Defaults to None.

None
dtype DTypeLike

The data type of the parameter. Defaults to None.

None
as_array bool

If True, treat the default_value as an array-like (cast if necessary). Otherwise, it will be stored as the default state without modification.

True

Raises:

Type Description
BlockParameterError

If the parameter is already declared as a static parameter, or if constructing the parameter fails.

Notes

(1) Only one of shape and default_value should be provided. If default_value is provided, it will be used as the initial value of the parameter. If shape is provided, the initial value will be a zero vector of the given shape and specified dtype.

Source code in collimator/framework/system_base.py
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
def declare_dynamic_parameter(
    self,
    name: str,
    default_value: Array | Parameter = None,
    shape: ShapeLike = None,
    dtype: DTypeLike = None,
    as_array: bool = True,
):
    """Declare a numeric parameter for the system.

    Parameters are declared in the system and accessed through the context to
    maintain separation of data ownership. This method creates an entry in the
    system's dynamic_parameters, recording the name, default value, and dependency
    ticket for later reference.

    The default value will be used to initialize the context, so it
    will also serve as the initial value unless explicitly overridden. In the
    simplest cases, parameters could be stored as attributes of the LeafSystem,
    but declaring them has the advantage of moving the values to the context,
    allowing them to be traced by JAX rather than stored as static data. This
    means they can be differentiated, vmapped, or otherwise modified without
    re-compiling the simulation.

    Declaring a name that is already a dynamic parameter replaces the
    previous declaration.

    Args:
        name (str):
            The name of the parameter.
        default_value (Union[Array, Parameter], optional):
            The default value of the parameter. Parameters are used
            primarily internally for serialization and should not normally need
            to be used directly when implementing LeafSystems. Defaults to None.
        shape (ShapeLike, optional):
            The shape of the parameter. Defaults to None.
        dtype (DTypeLike, optional):
            The data type of the parameter. Defaults to None.
        as_array (bool, optional):
            If True, treat the default_value as an array-like (cast if necessary).
            Otherwise, it will be stored as the default state without modification.

    Raises:
        BlockParameterError:
            If `name` is already declared as a static parameter, or if
            building the parameter from `default_value`/`shape`/`dtype`
            fails (the underlying exception is chained as `__cause__`).

    Notes:
        (1) Only one of `shape` and `default_value` should be provided. If
        `default_value` is provided, it will be used as the initial value of the
        parameter. If `shape` is provided, the initial value will be a
        zero vector of the given shape and specified dtype.
    """
    if name in self.static_parameters:
        raise BlockParameterError(
            "Parameter already declared as static parameter",
            system=self,
            parameter_name=name,
        )

    try:
        if isinstance(default_value, Parameter):
            # Re-wrap an existing Parameter (mainly used for serialization).
            self._dynamic_parameters[name] = Parameter(
                value=default_value,
                dtype=dtype,
                shape=shape,
                system=self,
                as_array=as_array,
            )
        else:
            if as_array:
                default_value = utils.make_array(
                    default_value, dtype=dtype, shape=shape
                )
            self._dynamic_parameters[name] = Parameter(
                value=default_value,
                dtype=dtype,
                shape=shape,
                system=self,
            )
    except Exception as e:
        # Record the original traceback through the module logger rather
        # than printing straight to stderr, so it stays visible even if the
        # chained cause is later lost (e.g. by serialization).
        logger.exception("Error declaring parameter %s on %s", name, self.name)
        raise BlockParameterError(
            "Error declaring parameter",
            system=self,
            parameter_name=name,
        ) from e

    logger.debug(
        "Adding parameter %s to %s with default: %s",
        name,
        self.name,
        default_value,
    )

declare_input_port(name=None, prerequisites_of_calc=None)

Add an input port to the system.

Returns the corresponding index into the system input_port_indices list. Note that this is different from the callbacks index - typically it will make more sense to retrieve the port via system.input_ports[port_index].

Parameters:

Name Type Description Default
name str

name of the new port. Defaults to None, which will use the default naming scheme for the system (e.g. "u_0")

None
prerequisites_of_calc List[DependencyTicket]

list of dependencies for the callback function. Defaults to None.

None

Returns:

Name Type Description
int int

port index of the newly created port in input_ports

Source code in collimator/framework/system_base.py
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
def declare_input_port(
    self,
    name: str = None,
    prerequisites_of_calc: List[DependencyTicket] = None,
) -> int:
    """Add an input port to the system.

    Returns the corresponding index into the system's `input_port_indices`
    list. Note that this is different from the index into `callbacks` -
    typically it will make more sense to retrieve the port via
    `system.input_ports[port_index]`.

    Args:
        name (str, optional): name of the new port. Defaults to None, which will
            use the default naming scheme for the system (e.g. "u_0")
        prerequisites_of_calc (List[DependencyTicket], optional): list of
            dependencies for the callback function. Defaults to None.

    Returns:
        int: port index of the newly created port in `input_ports`
    """
    port_index = self.num_input_ports
    port_name = self._next_input_port_name(name)

    # Reject duplicate names before constructing anything. (A second,
    # identical uniqueness loop after construction was redundant.)
    for port in self.input_ports:
        assert (
            port.name != port_name
        ), f"System {self.name} already has an input port named {port.name}"

    def _callback(context: ContextBase) -> Array:
        # Given the root context, evaluate the input port using the helper function
        return self._eval_input_port(context, port_index)

    callback_index = len(self.callbacks)
    port = InputPort(
        callback=_callback,
        system=self,
        callback_index=callback_index,
        name=port_name,
        index=port_index,
        prerequisites_of_calc=prerequisites_of_calc,
    )

    assert isinstance(port, InputPort)
    assert port.system is self
    assert port.name != ""

    self.input_port_indices.append(callback_index)
    self.callbacks.append(port)
    # Invalidate the memoized input-port list so it picks up the new port.
    self._cached_input_ports.clear()

    return port_index

declare_output_port(callback, name=None, prerequisites_of_calc=None, default_value=None, event=None, cache_index=None)

Add an output port to the system.

This output port could represent any function of the context available to the system, so a callback function is required. This function should have the form `callback(context: ContextBase) -> Array`. SystemBase implementations have some specific convenience wrappers, e.g. `LeafSystem.declare_continuous_state_output` and `Diagram.export_output`.

Common cases are:

- Feedthrough blocks: gather inputs and return some function of the inputs (e.g. a gain).
- Stateful blocks: use LeafSystem.declare_(...)_state_output_port to return the value of a particular state.
- Diagrams: create and export a diagram-level port to the parent system using the callback function associated with the system-level port.

Returns the corresponding index into the system output_port_indices list. Note that this is different from the callbacks index - typically it will make more sense to retrieve the port via system.output_ports[port_index].

Parameters:

Name Type Description Default
callback Callable

computes the value of the output port given the root context.

required
name str

name of the new port. Defaults to None, which will use the default naming scheme for the system (e.g. "y_0")

None
prerequisites_of_calc List[DependencyTicket]

list of dependencies for the callback function. Defaults to None, which will use the default dependencies for the system (all sources). This may conservatively flag the system as having algebraic loops, so it is better to be specific here when possible. This is done automatically in the wrapper functions like LeafSystem.declare_(...)_output_port

None
default_value Array

A default array-like value used to seed the context and perform type inference, when this is known up front. Defaults to None, which will use information propagation through the graph along with type promotion to determine an appropriate value.

None
event DiscreteUpdateEvent

A discrete update event associated with this output port that will periodically refresh the value that will be returned by the callback function. This makes the port act as a sample-and-hold rather than a direct function evaluation.

None
cache_index int

Index into the cache state component corresponding to the output port result, if the output port is of periodically-updated sample-and-hold type.

None

Returns:

Name Type Description
int int

port index of the newly created port

Source code in collimator/framework/system_base.py
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
def declare_output_port(
    self,
    callback: Callable,
    name: str = None,
    prerequisites_of_calc: List[DependencyTicket] = None,
    default_value: Array = None,
    event: DiscreteUpdateEvent = None,
    cache_index: int = None,
) -> int:
    """Add an output port to the system.

    This output port could represent any function of the context available to
    the system, so a callback function is required.  This function should have
    the form
        `callback(context: ContextBase) -> Array`
    SystemBase implementations have some specific convenience wrappers, e.g.:
        `LeafSystem.declare_continuous_state_output`
        `Diagram.export_output`

    Common cases are:
    - Feedthrough blocks: gather inputs and return some function of the
        inputs (e.g. a gain)
    - Stateful blocks: use LeafSystem.declare_(...)_state_output_port to
        return the value of a particular state
    - Diagrams: create and export a diagram-level port to the parent system using
        the callback function associated with the system-level port

    Returns the corresponding index into the system `output_port_indices` list.
    Note that this is different from the callbacks index - typically it
    will make more sense to retrieve via `system.output_ports[port_index]`.

    Args:
        callback (Callable): computes the value of the output port given
            the root context.
        name (str, optional): name of the new port. Defaults to None, which will
            use the default naming scheme for the system (e.g. "y_0").
        prerequisites_of_calc (List[DependencyTicket], optional): list of
            dependencies for the callback function. Defaults to None, which will
            use the default dependencies for the system (all sources).  This may
            conservatively flag the system as having algebraic loops, so it is
            better to be specific here when possible.  This is done automatically
            in the wrapper functions like `LeafSystem.declare_(...)_output_port`.
        default_value (Array, optional): A default array-like value used to seed
            the context and perform type inference, when this is known up front.
            Defaults to None, which will use information propagation through the
            graph along with type promotion to determine an appropriate value.
        event (DiscreteUpdateEvent, optional): A discrete update event associated
            with this output port that will periodically refresh the value that
            will be returned by the callback function. This makes the port act as
            a sample-and-hold rather than a direct function evaluation.
        cache_index (int, optional): Index into the cache state component
            corresponding to the output port result, if the output port is of
            periodically-updated sample-and-hold type.

    Returns:
        int: port index of the newly created port

    Raises:
        AssertionError: if the system already has an output port with the
            resolved name, or if the constructed port is inconsistent (wrong
            type, wrong owning system, or empty name).
    """
    port_index = self.num_output_ports
    port_name = self._next_output_port_name(name)

    # Check that the name is unique *before* constructing anything, so a
    # collision fails fast.  (Previously this check was repeated verbatim
    # after construction; once is sufficient.)
    for existing in self.output_ports:
        assert (
            existing.name != port_name
        ), f"System {self.name} already has an output port named {existing.name}"

    if prerequisites_of_calc is None:
        # Conservative default: assume the output may depend on every source.
        prerequisites_of_calc = [DependencyTicket.all_sources]

    # The new port is appended to `self.callbacks`, so its callback index is
    # the current length of that list.
    callback_index = len(self.callbacks)
    port = OutputPort(
        callback,
        system=self,
        callback_index=callback_index,
        name=port_name,
        index=port_index,
        prerequisites_of_calc=prerequisites_of_calc,
        default_value=default_value,
        event=event,
        cache_index=cache_index,
    )

    # Sanity checks on the constructed port.
    assert isinstance(port, OutputPort)
    assert port.system is self
    assert port.name != ""

    logger.debug("Adding output port %s to %s", port, self.name)
    self.output_port_indices.append(callback_index)
    self.callbacks.append(port)
    # Invalidate the cached output-port list so it is rebuilt with the new port.
    self._cached_output_ports.clear()

    logger.debug(
        "    ---> %s now has %s output ports: %s",
        self.name,
        len(self.output_ports),
        self.output_ports,
    )
    logger.debug(
        "    ---> %s now has %s cache sources: %s",
        self.name,
        len(self.callbacks),
        self.callbacks,
    )

    return port_index

declare_static_parameter(name, value)

Declare a single static parameter for the system.

This is a convenience function for declaring a single static parameter.

Parameters:

Name Type Description Default
name str

name of the parameter

required
value Union[Array, Parameter]

value of the parameter

required
Source code in collimator/framework/system_base.py
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
def declare_static_parameter(self, name, value):
    """Register one static (non-traceable) parameter on this system.

    Thin convenience wrapper that forwards a single name/value pair to
    `declare_static_parameters`.

    Args:
        name (str): name under which the parameter is stored
        value (Union[Array, Parameter]): the parameter's value
    """
    single_entry = {name: value}
    self.declare_static_parameters(**single_entry)

declare_static_parameters(**params)

Declare a set of static parameters for the system.

These parameters are not JAX-traceable and therefore can't be optimized.

Examples of static parameters include booleans, strings, parameters used in shapes, etc.

The args should be a dict of name-value pairs, where the values are either strings, bool, arrays, or Parameters.

Typical usage:

class MyBlock(LeafSystem):
    def __init__(self, param1=True, param2=1.0):
        super().__init__()
        self.declare_static_parameters(param1=param1, param2=param2)
Source code in collimator/framework/system_base.py
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
def declare_static_parameters(self, **params):
    """Register several static parameters on this system at once.

    Static parameters are not JAX-traceable, so they cannot be optimized.
    Typical examples are booleans, strings, and values that determine shapes.

    Each keyword argument is a name-value pair whose value is a string, bool,
    array, or `Parameter`.  Plain Python lists are converted to numpy arrays
    before being wrapped.

    Typical usage:

    ```python
    class MyBlock(LeafSystem):
        def __init__(self, param1=True, param2=1.0):
            super().__init__()
            self.declare_static_parameters(param1=param1, param2=param2)
    ```

    Raises:
        BlockParameterError: if a name is already declared as a dynamic
            parameter of this system.
    """
    for param_name, param_value in params.items():
        # A name cannot be both dynamic and static.
        if param_name in self.dynamic_parameters:
            raise BlockParameterError(
                "Parameter already declared as dynamic parameter",
                system=self,
                parameter_name=param_name,
            )
        stored_value = (
            np.array(param_value) if isinstance(param_value, list) else param_value
        )
        self._static_parameters[param_name] = Parameter(
            value=stored_value,
            system=self,
            is_static=True,
        )

dependency_graph_factory()

Factory object for creating dependency graphs for this system.

Should not be called directly - use system.create_dependency_graph instead.

Source code in collimator/framework/system_base.py
943
944
945
946
947
948
949
@property
@abc.abstractmethod
def dependency_graph_factory(self) -> DependencyGraphFactory:
    """Factory object for creating dependency graphs for this system.

    Should not be called directly - use `system.create_dependency_graph` instead.
    """
    # `abc.abstractproperty` is deprecated since Python 3.3; stacking
    # `@property` over `@abc.abstractmethod` is the documented replacement and
    # is equally enforced by ABCMeta.
    pass

determine_active_guards(context) abstractmethod

Determine active guards for zero-crossing events.

This method is responsible for evaluating and determining which zero-crossing events are active based on the current system mode and other conditions. This can be overridden to flag active/inactive guards on a block-specific basis, for instance in a StateMachine-type block. By default all guards are marked active at this point unless the zero-crossing event was declared with a non-default start_mode, in which case the guard is activated conditionally on the current mode.

For example, in a system with finite state transitions, where a transition from mode A to mode B is triggered by a guard function g_AB and the inverse transition is triggered by a guard function g_BA, this function would activate g_AB if the system is in mode A and g_BA if the system is in mode B. The other guard function would be inactive. If the zero-crossing event is not associated with a starting mode, it is considered to be always active.

Parameters:

Name Type Description Default
context ContextBase

The root context containing the overall state and parameters.

required

Returns:

Name Type Description
EventCollection EventCollection

A collection of zero-crossing events with active/inactive status updated based on the current system mode and other conditions.

Source code in collimator/framework/system_base.py
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
@abc.abstractmethod
def determine_active_guards(self, context: ContextBase) -> EventCollection:
    """Decide which zero-crossing guards are currently active.

    Implementations inspect the current system mode (and any other relevant
    conditions) and flag each zero-crossing event as active or inactive.
    Override this per block when custom activation logic is needed, e.g. for a
    StateMachine-type block.  Per the framework's convention, every guard is
    active unless its zero-crossing event was declared with a non-default
    `start_mode`, in which case it is only active while the system is in that
    mode.

    Illustration: with two modes A and B, a guard g_AB triggering A -> B and a
    guard g_BA triggering B -> A, only g_AB is active in mode A and only g_BA
    is active in mode B.  Events without a starting mode are always active.

    Args:
        context (ContextBase):
            The root context holding the overall state and parameters.

    Returns:
        EventCollection:
            The zero-crossing events with their active/inactive flags updated
            to reflect the current mode and other conditions.
    """
    ...

eval_input(context, port_index=0)

Get the input for a given port.

This works by evaluating the callback function associated with the port, which will "pull" the upstream output port values.

Parameters:

Name Type Description Default
context ContextBase

root context for this system

required
port_index int

index into self.input_ports, for example the value returned by declare_input_port. Defaults to 0.

0

Returns:

Name Type Description
Array Array

current input values

Source code in collimator/framework/system_base.py
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
def eval_input(self, context: ContextBase, port_index: int = 0) -> Array:
    """Return the current value seen at one of this system's input ports.

    Evaluating an input port runs its callback, which "pulls" the value from
    whatever upstream output port it is connected to.

    Args:
        context (ContextBase): root context for this system
        port_index (int, optional): position in `self.input_ports`, e.g. the
            value returned by `declare_input_port`. Defaults to 0.

    Returns:
        Array: current input values
    """
    selected_port = self.input_ports[port_index]
    value = selected_port.eval(context)
    return value

eval_time_derivatives(context)

Evaluate the continuous time derivatives for this system.

Given the root context, evaluate the continuous time derivatives, which must have the same PyTree structure as the continuous state.

In principle, this can be overridden by custom implementations, but in general it is preferable to declare continuous states for LeafSystems using declare_continuous_state, which accepts a callback function that will be used to compute the derivatives. For Diagrams, the time derivatives are computed automatically using the callback functions for all child systems with continuous state.

Parameters:

Name Type Description Default
context ContextBase

root context of this system

required

Returns:

Name Type Description
StateComponent StateComponent

Continuous time derivatives for this system, or None if the system has no continuous state.

Source code in collimator/framework/system_base.py
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def eval_time_derivatives(self, context: ContextBase) -> StateComponent:
    """Compute the continuous-time derivatives of this system's state.

    The argument is the _root_ context; the result must share the PyTree
    structure of the continuous state.

    Subclasses may override this, but LeafSystems should normally declare
    their continuous state via `declare_continuous_state`, whose callback
    supplies the derivatives.  Diagrams assemble their derivatives from the
    callbacks of every child system that has continuous state.

    Args:
        context (ContextBase): root context of this system

    Returns:
        StateComponent:
            The time derivatives, or None when the system has no continuous
            state.
    """
    # The base implementation reports no derivatives at all.
    return None

eval_zero_crossing_updates(context, events) abstractmethod

Evaluate reset maps associated with zero-crossing events.

Parameters:

Name Type Description Default
context ContextBase

The context for the system, containing the current state and parameters.

required
events EventCollection

The collection of events to be evaluated (for example zero-crossing or periodic events for this system).

required

Returns:

Name Type Description
State State

The complete state with all updates applied.

Notes

(1) Following the Drake definition, "unrestricted" updates are allowed to modify any component of the state: continuous, discrete, or mode. These updates are evaluated in the order in which they were declared, so it is possible (but should be strictly avoided) for multiple events to modify the same state component at the same time.

Each update computes its results given the current state of the system (the "minus" values) and returns the updated state (the "plus" values). The update functions cannot access any information about the "plus" values of its own state or the state of any other block. This could change in the future, but for now it ensures consistency with Drake's discrete semantics.

More specifically, since all unrestricted updates can modify the entire state, any time there are multiple unrestricted updates, the resulting states are ALWAYS in conflict. For example, suppose a system has two unrestricted updates, event1 and event2. At time t_n, event1 is active and event2 is inactive. First, event1 is evaluated, and the state is updated. Then event2 is evaluated, but the state is not updated. Which one is valid? Obviously, the event1 return is valid, but how do we communicate this to JAX? The situation is more complicated if both event1 and event2 happen to be active. In this case the states have to be "merged" somehow. In the worst case, these two will modify the same components of the state in different ways.

The implementation updates the state in a local copy of the context (since both the context and the state are immutable). This allows multiple unrestricted updates, but leaves open the possibility of multiple active updates modifying the state in conflicting ways. This should strictly be avoided by the implementer of the LeafSystem. If it is at all unclear how to do this, it may be better to split the system into multiple blocks to be safe.

(2) The events are evaluated conditionally on being marked "active" (indicating that their guard function triggered), so the entire event collection can be passed without filtering to active events. This is necessary to make the function calls work with JAX tracing, which do not allow for variable-sized arguments or returns.

Source code in collimator/framework/system_base.py
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
@abc.abstractmethod
def eval_zero_crossing_updates(
    self,
    context: ContextBase,
    events: EventCollection,
) -> State:
    """Evaluate reset maps associated with zero-crossing events.

    Args:
        context (ContextBase):
            The context for the system, containing the current state and parameters.
        events (EventCollection):
            The collection of events to be evaluated (for example zero-crossing or
            periodic events for this system).

    Returns:
        State: The complete state with all updates applied.

    Notes:
        (1) Following the Drake definition, "unrestricted" updates are allowed to
        modify any component of the state: continuous, discrete, or mode.  These
        updates are evaluated in the order in which they were declared, so it is
        _possible_ (but should be strictly avoided) for multiple events to modify the
        same state component at the same time.

        Each update computes its results given the _current_ state of the system
        (the "minus" values) and returns the _updated_ state (the "plus" values).
        The update functions cannot access any information about the "plus" values of
        its own state or the state of any other block.  This could change in the
        future, but for now it ensures consistency with Drake's discrete semantics.

        More specifically, since all unrestricted updates can modify the entire state,
        any time there are multiple unrestricted updates, the resulting states are
        ALWAYS in conflict.  For example, suppose a system has two unrestricted
        updates, `event1` and `event2`.  At time t_n, `event1` is active and `event2`
        is inactive.  First, `event1` is evaluated, and the state is updated.  Then
        `event2` is evaluated, but the state is not updated.  Which one is valid?
        Obviously, the `event1` return is valid, but how do we communicate this to JAX?
        The situation is more complicated if both `event1` and `event2` happen to be
        active.  In this case the states have to be "merged" somehow.  In the worst
        case, these two will modify the same components of the state in different ways.

        The implementation updates the state in a local copy of the context (since
        both the context and the state are immutable).  This allows multiple
        unrestricted updates, but leaves open the possibility of multiple active
        updates modifying the state in conflicting ways.  This should strictly be
        avoided by the implementer of the LeafSystem.  If it is at all unclear how to
        do this, it may be better to split the system into multiple blocks to be safe.

        (2) The events are evaluated conditionally on being marked "active"
        (indicating that their guard function triggered), so the entire event
        collection can be passed without filtering to active events. This is necessary
        to make the function calls work with JAX tracing, which do not allow for
        variable-sized arguments or returns.
    """
    pass

get_feedthrough() abstractmethod

Determine pairs of direct feedthrough ports for this system.

By default, the algorithm relies on the dependency tracking system to determine feedthrough, but this can be overridden by implementing this method directly in a subclass, for instance if the automatic dependency tracking is too conservative in determining feedthrough.

Returns:

Type Description
List[Tuple[int, int]]

A list of tuples (u, v) indicating that output port v has a direct dependency on input port u, resulting in a feedthrough path in the system. The indices u and v correspond to the indices of the input and output ports in the system's input and output port lists.

Source code in collimator/framework/system_base.py
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
@abc.abstractmethod
def get_feedthrough(self) -> List[Tuple[int, int]]:
    """Report the direct-feedthrough (input, output) port pairs of this system.

    The framework normally derives feedthrough from dependency tracking.
    Subclasses can implement this method directly instead, for example when
    automatic tracking is too conservative.

    Returns:
        List[Tuple[int, int]]:
            Pairs (u, v) meaning output port v depends directly on input port
            u, i.e. there is a feedthrough path from u to v.  Both u and v are
            positions in the system's input and output port lists respectively.
    """
    ...

get_input_port(name)

Retrieve a specific input port by name.

Source code in collimator/framework/system_base.py
525
526
527
528
529
530
531
532
533
def get_input_port(self, name: str) -> tuple[InputPort, int]:
    """Retrieve a specific input port by name.

    Returns:
        tuple[InputPort, int]: the first matching port and its position in
            `self.input_ports`.

    Raises:
        ValueError: if no input port has the given name; the message lists the
            names that are available.
    """
    first_match = next(
        (
            (candidate, position)
            for position, candidate in enumerate(self.input_ports)
            if candidate.name == name
        ),
        None,
    )
    if first_match is not None:
        return first_match
    available_names = [candidate.name for candidate in self.input_ports]
    raise ValueError(
        f"System {self.name} has no input port named {name}. "
        f"Available ports: {available_names}"
    )

get_output_port(name)

Retrieve a specific output port by name.

Source code in collimator/framework/system_base.py
547
548
549
550
551
552
def get_output_port(self, name: str) -> OutputPort:
    """Retrieve a specific output port by name.

    Args:
        name (str): name of the output port to look up.

    Returns:
        OutputPort: the first output port whose name matches `name`.

    Raises:
        ValueError: if no output port with that name exists.  The message lists
            the available port names, in the same format `get_input_port` uses.
    """
    for port in self.output_ports:
        if port.name == name:
            return port
    # Listing the valid names makes typos much easier to diagnose.
    available = [port.name for port in self.output_ports]
    raise ValueError(
        f"System {self.name} has no output port named {name}. "
        f"Available ports: {available}"
    )

handle_discrete_update(events, context)

Compute and apply active discrete updates.

Given the root context, evaluate the discrete updates, which must have the same PyTree structure as the discrete states of this system. This should be a pure function, so that it does not modify any aspect of the context in-place (even though it is difficult to strictly prevent this in Python).

This will evaluate the set of events that result from declaring state or output update events on systems using LeafSystem.declare_periodic_update and LeafSystem.declare_output_port with an associated periodic update rate.

This is intended for internal use by the simulator and should not normally need to be invoked directly by users. Events are evaluated conditionally on being marked "active", so the entire event collection can be passed without filtering to active events. This is necessary to make the function calls work with JAX tracing, which do not allow for variable-sized arguments or returns.

For a discrete system updating at a particular rate, the update rule for a particular block is:

x[n+1] = f(t[n], x[n], u[n])
y[n]   = g(t[n], x[n], u[n])

Additionally, the value y[n] is held constant until the next update from the point of view of other continuous-time or asynchronous discrete-time blocks.

Because each output y[n] may in general depend on the input u[n] evaluated at the same time, the composite discrete update function represents a system of equations. However, since algebraic loops are prohibited, the events can be ordered and executed sequentially to ensure that the updates are applied in the correct order. This is implemented in SystemBase.sorted_event_callbacks.

Multirate systems work in the same way, except that the events are evaluated conditionally on whether the current time corresponds to an update time for each event.

Parameters:

Name Type Description Default
events EventCollection

collection of discrete update events

required
context ContextBase

root context for this system

required

Returns:

Name Type Description
ContextBase ContextBase

updated context with all active updates applied to the discrete state

Source code in collimator/framework/system_base.py
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
def handle_discrete_update(
    self, events: EventCollection, context: ContextBase
) -> ContextBase:
    """Compute and apply active discrete updates.

    Given the _root_ context, evaluate the discrete updates, which must have the
    same PyTree structure as the discrete states of this system. This should be
    a pure function, so that it does not modify any aspect of the context in-place
    (even though it is difficult to strictly prevent this in Python).

    This will evaluate the set of events that result from declaring state or output
    update events on systems using `LeafSystem.declare_periodic_update` and
    `LeafSystem.declare_output_port` with an associated periodic update rate.

    This is intended for internal use by the simulator and should not normally need
    to be invoked directly by users. Events are evaluated conditionally on being
    marked "active", so the entire event collection can be passed without filtering
    to active events. This is necessary to make the function calls work with JAX
    tracing, which do not allow for variable-sized arguments or returns.

    For a discrete system updating at a particular rate, the update rule for a
    particular block is:

    ```
    x[n+1] = f(t[n], x[n], u[n])
    y[n]   = g(t[n], x[n], u[n])
    ```

    Additionally, the value y[n] is held constant until the next update from the
    point of view of other continuous-time or asynchronous discrete-time blocks.

    Because each output `y[n]` may in general depend on the input `u[n]` evaluated
    _at the same time_, the composite discrete update function represents a
    system of equations.  However, since algebraic loops are prohibited, the events
    can be ordered and executed sequentially to ensure that the updates are applied
    in the correct order.  This is implemented in
    `SystemBase.sorted_event_callbacks`.

    Multirate systems work in the same way, except that the events are evaluated
    conditionally on whether the current time corresponds to an update time for each
    event.

    Args:
        events (EventCollection): collection of discrete update events
        context (ContextBase): root context for this system

    Returns:
        ContextBase:
            updated context with all active updates applied to the discrete state
    """
    # Lazy %-style arguments: the message is only formatted if DEBUG logging is
    # enabled (the previous f-string was formatted on every call).  This also
    # matches the style used in `handle_zero_crossings`.
    logger.debug(
        "Handling %s discrete update events at t=%s",
        events.num_events,
        context.time,
    )
    if events.has_events:
        # Sequentially apply all updates; each event sees the context produced
        # by the previous one.
        for event in events:
            system_id = event.system_id
            state = event.handle(context)
            local_context = context[system_id].with_state(state)
            context = context.with_subcontext(system_id, local_context)

    return context

handle_zero_crossings(events, context)

Compute and apply active zero-crossing events.

This is intended for internal use by the simulator and should not normally need to be invoked directly by users. Events are evaluated conditionally on being marked "active", so the entire event collection can be passed without filtering to active events. This is necessary to make the function calls work with JAX tracing, which do not allow for variable-sized arguments or returns.

Parameters:

Name Type Description Default
events EventCollection

collection of zero-crossing events

required
context ContextBase

root context for this system

required

Returns:

Name Type Description
ContextBase ContextBase

updated context with all active zero-crossing events applied

Source code in collimator/framework/system_base.py
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
def handle_zero_crossings(
    self, events: EventCollection, context: ContextBase
) -> ContextBase:
    """Evaluate and apply the active zero-crossing events.

    Meant for the simulator's internal use rather than for end users.  The
    whole event collection is passed through unfiltered; each event only takes
    effect when it is marked "active".  This keeps argument and return sizes
    fixed, which JAX tracing requires.

    Args:
        events (EventCollection): collection of zero-crossing events
        context (ContextBase): root context for this system

    Returns:
        ContextBase: updated context with all active zero-crossing events applied
    """
    logger.debug(
        "Handling %d state transition events at t=%s",
        events.num_events,
        context.time,
    )
    # Guard clause: without any events there is nothing to reset.
    if not (events.num_events > 0):
        return context

    new_state = self.eval_zero_crossing_updates(context, events)
    return context.with_state(new_state)

has_feedthrough_side_effects()

Check if the system includes any feedthrough calls to io_callback.

Source code in collimator/framework/system_base.py
221
222
223
224
225
226
227
228
229
230
@property
@abc.abstractmethod
def has_feedthrough_side_effects(self) -> bool:
    """Check if the system includes any feedthrough calls to `io_callback`."""
    # This is a tricky one to explain and is almost always False except for a
    # PythonScript block that is not JAX traced.  Basically, if the output of
    # the system is used as an ODE right-hand-side, will it fail in the case where
    # the ODE solver defines a custom VJP?  This happens in diffrax, so for example
    # if a PythonScript block is used to compute the ODE right-hand-side, it will
    # fail with "Effects not supported in `custom_vjp`"
    #
    # `abc.abstractproperty` is deprecated since Python 3.3; `@property` stacked
    # over `@abc.abstractmethod` is the documented, equivalent replacement.
    pass

has_ode_side_effects()

Check if the ODE RHS for the system includes any calls to io_callback.

Source code in collimator/framework/system_base.py
232
233
234
235
236
237
238
@property
@abc.abstractmethod
def has_ode_side_effects(self) -> bool:
    """Check if the ODE RHS for the system includes any calls to `io_callback`."""
    # This flag indicates that the system `has_feedthrough_side_effects` AND that
    # signal is used as an ODE right-hand-side.  This is used to determine whether
    # a JAX ODE solver can be used to integrate the system.
    #
    # `abc.abstractproperty` is deprecated since Python 3.3; `@property` stacked
    # over `@abc.abstractmethod` is the documented, equivalent replacement.
    pass

initialize_static_data(context)

Initialize any context data that has to be done after context creation.

Use this to define custom auxiliary data or type inference that doesn't get traced by JAX. See the ZeroOrderHold implementation for an example. Since this is only applied during context initialization, it is allowed to modify the context directly (or the system itself).

Typically this should not be called outside of the ContextFactory.

Parameters:

Name Type Description Default
context ContextBase

partially initialized context for this system.

required
Source code in collimator/framework/system_base.py
955
956
957
958
959
960
961
962
963
964
965
966
967
968
def initialize_static_data(self, context: ContextBase) -> ContextBase:
    """Hook for context setup that must run after the context is created.

    Override this to attach custom auxiliary data or to perform type inference
    that JAX should not trace (the `ZeroOrderHold` implementation is an
    example). Because it only runs while the context is being initialized, an
    override may mutate the context directly, or even the system itself.

    This is normally invoked only by the ContextFactory.

    Args:
        context (ContextBase): the partially initialized context for this system.

    Returns:
        ContextBase: the context; this base implementation hands back the very
        same object without modifying it.
    """
    unchanged = context
    return unchanged

post_simulation_finalize()

Finalize the system after simulation has completed.

This is only intended for special blocks that need to clean up resources and close files.

Source code in collimator/framework/system_base.py
196
197
198
199
200
def post_simulation_finalize(self) -> None:
    """Hook invoked once the simulation has completed.

    The base implementation does nothing. Only special blocks that must clean
    up resources or close files are expected to override it.
    """
    return None

pprint(output=print, fancy=True)

Pretty-print the system and its hierarchy.

Source code in collimator/framework/system_base.py
187
188
189
def pprint(self, output=print, fancy=True) -> None:
    """Pretty-print the system and its hierarchy.

    Args:
        output: Callable that receives the rendered text (with surrounding
            whitespace stripped). Defaults to the builtin `print`.
        fancy: Forwarded unchanged to `_pprint_helper`.

    Returns:
        None. The text is delivered through `output` rather than returned
        (the previous `-> str` annotation did not match the behavior).
    """
    output(self._pprint_helper(fancy=fancy).strip())

SystemCallback dataclass

A function associated with a system that has specified dependencies.

This can include port update rules, discrete update functions, the right-hand-side of an ODE, etc. Storing these functions as SystemCallbacks allows the system, or a Diagram containing the system, to track dependencies across the system or diagram.

Attributes:

Name Type Description
system SystemBase

The system that owns this callback.

ticket int

The dependency ticket associated with this callback. See DependencyTicket for built-in tickets. If None, a new ticket will be generated.

name str

A short description of this callback function.

prerequisites_of_calc List[DependencyTicket]

Direct prerequisites of the computation, used for dependency tracking. These might be built-in tickets or tickets associated with other SystemCallbacks.

default_value Array

A dummy value of the same shape/dtype as the result, if known. If None, any type checking will rely on propagating upstream information via the callback.

callback_index int

The index of this function in the system's list of associated callbacks.

event Event

Optionally, the callback function may be associated with an event. If so, the associated trackers can be used to sort event execution order in addition to the regular callback execution order. For example, if an OutputPort is of sample-and-hold type, then this will be the event that periodically updates the output value. Default is None.

Source code in collimator/framework/cache.py
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
@dataclasses.dataclass
class SystemCallback:
    """A function associated with a system that has specified dependencies.

    This can include port update rules, discrete update functions, the right-hand-side
    of an ODE, etc. Storing these functions as SystemCallbacks allows the system, or a
    Diagram containing the system, to track dependencies across the system or diagram.

    Attributes:
        callback (Callable[[ContextBase], Array]):
            Init-only argument: the function that, given the root context, computes
            the callback's value. It is stored privately as `_callback`.
        system (SystemBase):
            The system that owns this callback.
        ticket (int):
            The dependency ticket associated with this callback.  See DependencyTicket
            for built-in tickets. If None, a new ticket will be generated.
        name (str):
            A short description of this callback function.
        prerequisites_of_calc (List[DependencyTicket]):
            Direct prerequisites of the computation, used for dependency tracking.
            These might be built-in tickets or tickets associated with other
            SystemCallbacks. If None, an empty list is used.
        default_value (Array):
            A dummy value of the same shape/dtype as the result, if known.  If None,
            any type checking will rely on propagating upstream information via the
            callback.
        callback_index (int):
            The index of this function in the system's list of associated callbacks.
        event (Event):
            Optionally, the callback function may be associated with an event.  If so,
            the associated trackers can be used to sort event execution order in addition
            to the regular callback execution order. For example, if an OutputPort is of
            sample-and-hold type, then this will be the event that periodically updates
            the output value. Default is None.
        cache_index (int):
            If the result is cached (e.g. an output port of "sample-and-hold" type),
            the index of the cache in the system's cache list. Default is None.
    """

    callback: dataclasses.InitVar[Callable[[ContextBase], Array]]
    system: SystemBase
    callback_index: int
    ticket: DependencyTicket = None
    name: str = None
    prerequisites_of_calc: List[DependencyTicket] = None
    default_value: Array = None
    event: Event = None

    # If the result is cached (e.g. an output port of "sample-and-hold" type),
    # this will be the index of the cache in the system's cache list.
    cache_index: int = None

    def __post_init__(self, callback):
        self._callback = callback  # Given root context, return calculated value

        # Allocate a fresh ticket when none was supplied.
        if self.ticket is None:
            self.ticket = next_dependency_ticket()
        assert isinstance(self.ticket, int)

        # Avoid a shared mutable default: each callback gets its own list.
        if self.prerequisites_of_calc is None:
            self.prerequisites_of_calc = []

        logger.debug(
            "Initialized callback %s:%s with prereqs %s",
            self.system.name_path_str,
            self.name,
            self.prerequisites_of_calc,
        )

        # A basic port output data cache for numpy
        self._basic_output_cache = BasicOutputCache(self)

    def __hash__(self) -> int:
        # Identity is the owning system plus this callback's index within it.
        # Because `__hash__` is defined explicitly, the dataclass decorator
        # (eq=True, frozen=False) keeps it rather than setting it to None.
        locator = (self.system, self.callback_index)
        return hash(locator)

    def __repr__(self) -> str:
        return f"{self.name}(ticket = {self.ticket})"

    def eval(self, root_context: ContextBase) -> Array:
        """Evaluate the callback function and return the calculated value.

        While the root context is not yet initialized, the callback is evaluated
        at most once and its result is stored in (and returned from)
        `default_value`.

        Args:
            root_context: The root context used for the evaluation.

        Returns:
            The calculated value from the callback, expected to be an Array.

        Raises:
            CallbackIsNotDifferentiableError: if the callback raises a ValueError
                whose message contains "do not support JVP.". Any other
                ValueError is re-raised unchanged.
        """
        if not root_context.is_initialized:
            if self.default_value is None:
                self.default_value = self._callback(root_context)
            return self.default_value

        # Note: using the BasicOutputCache here does not give any performance
        # gain, due to the overhead of computing the keys and lookups.
        try:
            result = self._callback(root_context)
        except ValueError as e:
            # this error is raised if the callback is not differentiable
            if "do not support JVP." in str(e):
                # Chain explicitly so the original JAX error is reported as the
                # direct cause rather than as an error raised during handling.
                raise CallbackIsNotDifferentiableError(
                    system=self.system,
                    port_name=self.name,
                ) from e
            raise
        return result

    @property
    def tracker(self) -> DependencyTracker:
        """The dependency tracker registered for this callback's ticket."""
        return self.system.dependency_graph[self.ticket]

eval(root_context)

Evaluate the callback function and return the calculated value.

Parameters:

Name Type Description Default
root_context ContextBase

The root context used for the evaluation.

required

Returns:

Type Description
Array

The calculated value from the callback, expected to be an Array.

Source code in collimator/framework/cache.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def eval(self, root_context: ContextBase) -> Array:
    """Evaluate the callback function and return the calculated value.

    While the root context is not yet initialized, the callback is evaluated
    at most once and its result is stored in (and returned from)
    `default_value`.

    Args:
        root_context: The root context used for the evaluation.

    Returns:
        The calculated value from the callback, expected to be an Array.

    Raises:
        CallbackIsNotDifferentiableError: if the callback raises a ValueError
            whose message contains "do not support JVP.". Any other
            ValueError is re-raised unchanged.
    """
    if not root_context.is_initialized:
        if self.default_value is None:
            self.default_value = self._callback(root_context)
        return self.default_value

    # Note: using the BasicOutputCache here does not give any performance
    # gain, due to the overhead of computing the keys and lookups.
    try:
        result = self._callback(root_context)
    except ValueError as e:
        # this error is raised if the callback is not differentiable
        if "do not support JVP." in str(e):
            # Chain explicitly so the original JAX error is reported as the
            # direct cause rather than as an error raised during handling.
            raise CallbackIsNotDifferentiableError(
                system=self.system,
                port_name=self.name,
            ) from e
        raise
    return result

ZeroCrossingEvent dataclass

Bases: Event

An event that triggers when a specified "guard" function crosses zero.

The event is triggered when the guard function crosses zero in the specified direction. In addition to the guard callback, the event also has a "reset map" which is called when the event is triggered. The reset map may update any state component in the system.

The event can also be defined as "terminal", which means that the simulation will terminate when the event is triggered. (TODO: Does the reset map still happen?)

The "direction" of the zero-crossing is one of the following:

- "none": never trigger the event (can be useful for debugging).
- "positive_then_non_positive": trigger when the guard goes from positive to non-positive.
- "negative_then_non_negative": trigger when the guard goes from negative to non-negative.
- "crosses_zero": trigger when the guard crosses zero in either direction.
- "edge_detection": trigger when the guard changes value.

Notes

This class should typically not need to be used directly by users. Instead, declare the guard function and reset map on a LeafSystem using the declare_zero_crossing method. The event will then be auto-generated for simulation.

Source code in collimator/framework/event.py
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class ZeroCrossingEvent(Event):
    """An event that triggers when a specified "guard" function crosses zero.

    The event is triggered when the guard function crosses zero in the specified
    direction. In addition to the guard callback, the event also has a "reset map"
    which is called when the event is triggered. The reset map may update any state
    component in the system.

    The event can also be defined as "terminal", which means that the simulation will
    terminate when the event is triggered. (TODO: Does the reset map still happen?)

    The "direction" of the zero-crossing is one of the following:
        - "none": Never trigger the event (can be useful for debugging)
        - "positive_then_non_positive": Trigger when the guard goes from positive to
            non-positive
        - "negative_then_non_negative": Trigger when the guard goes from negative to
            non-negative
        - "crosses_zero": Trigger when the guard crosses zero in either direction
        - "edge_detection": Trigger when the guard changes value

    The class is registered as a JAX pytree: only `event_data` is flattened as
    pytree children; all other fields travel as static auxiliary data (see
    `tree_flatten`).

    Notes:
        This class should typically not need to be used directly by users. Instead,
        declare the guard function and reset map on a LeafSystem using the
        `declare_zero_crossing` method.  The event will then be auto-generated for
        simulation.
    """

    # Supersede type hints in Event with the specific signature for full-state updates
    callback: Callable[[ContextBase], LeafState] = None
    # Called by `handle` instead of `callback` when the event does not fire.
    passthrough: Callable[[ContextBase], LeafState] = None

    # Guard function whose stored values (w0, w1) decide whether the event fires.
    guard: Callable[[ContextBase], Scalar] = None
    # Init-only alias for `callback`; used only if `callback` is not given
    # (see `__post_init__`). It is not stored as a field.
    reset_map: dataclasses.InitVar[Callable[[ContextBase], LeafState]] = None
    # Key into `_zero_crossing_trigger_functions`; see the class docstring.
    direction: str = "crosses_zero"
    is_terminal: bool = False
    # Holds `active`, `triggered`, `w0` and `w1`, read by the methods below.
    event_data: ZeroCrossingEventData = None

    # If not none, only trigger when in this mode. This logic is handled by the owning
    # leaf system.
    active_mode: int = None

    def __post_init__(self, reset_map):  # pylint: disable=arguments-differ
        """Use `reset_map` as the update callback when no `callback` was given."""
        if self.callback is None:
            self.callback = reset_map

    def _should_trigger(self, w0: Scalar, w1: Scalar) -> bool:
        """Determine if the event should trigger.

        This will use the provided beginning/ending guard value (w0 and w1, resp.),
        as well as the direction of the zero-crossing event. Additionally, the event
        will only trigger if it has been marked as "active", indicating for example
        that the system is in the correct "mode" or "stage" from which the event might
        trigger.
        """
        active = self.event_data.active

        # NOTE(review): presumably a mapping from direction string to a predicate
        # on (w0, w1); an unrecognized `direction` would fail at this lookup.
        trigger_func = _zero_crossing_trigger_functions[self.direction]
        # `&` rather than `and` so the expression also works on traced arrays.
        return active & trigger_func(w0, w1)

    def should_trigger(self) -> bool:
        """Determine if the event should trigger based on the stored guard values."""
        return self._should_trigger(self.event_data.w0, self.event_data.w1)

    def handle(self, context: ContextBase) -> LeafState:
        """Conditionally compute the result of the zero crossing callback

        The update callback is called only when the event data marks the zero
        crossing as both active and triggered; otherwise (inactive or not
        triggered) the passthrough callback is called instead. The return types of
        both callbacks must match, but the specific type will depend on the kind
        of event.
        """
        if self.enable_tracing:  # not driven by simulator.enable_tracing.
            # Traced path: branch with `cond` so both callbacks can be staged.
            return cond(
                self.event_data.active & self.event_data.triggered,
                self.callback,
                self.passthrough,
                context,
            )

        # No tracing: use standard control flow
        if self.event_data.active & self.event_data.triggered:
            return self.callback(context)
        return self.passthrough(context)

    #
    # PyTree registration
    #
    def tree_flatten(self):
        """Split the event into pytree children and static auxiliary data.

        Only `event_data` is a child; everything else is aux data. The order of
        `aux_data` must match the unpacking in `tree_unflatten`. `system_id`,
        `name` and `enable_tracing` are not declared in this class and are
        presumably inherited from `Event`.
        """
        children = (self.event_data,)
        aux_data = (
            self.system_id,
            self.guard,
            self.callback,
            self.name,
            self.direction,
            self.is_terminal,
            self.passthrough,
            self.enable_tracing,
            self.active_mode,
        )
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an event from the output of `tree_flatten`.

        `reset_map` is not passed because `callback` already carries the update
        function.
        """
        (event_data,) = children
        (
            system_id,
            guard,
            callback,
            name,
            direction,
            is_terminal,
            passthrough,
            enable_tracing,
            active_mode,
        ) = aux_data
        return cls(
            system_id=system_id,
            event_data=event_data,
            guard=guard,
            callback=callback,
            name=name,
            direction=direction,
            is_terminal=is_terminal,
            passthrough=passthrough,
            enable_tracing=enable_tracing,
            active_mode=active_mode,
        )

handle(context)

Conditionally compute the result of the zero crossing callback

If the zero crossing is marked "inactive" or has not been triggered (according to its event data), the passthrough callback is called instead of the update callback. Otherwise, the update callback is called. The return types of both callbacks must match, but the specific type will depend on the kind of event.

Source code in collimator/framework/event.py
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
def handle(self, context: ContextBase) -> LeafState:
    """Run either the update callback or the passthrough for this zero crossing.

    The update callback runs only when the event data marks the event as both
    active and triggered; in every other case the passthrough runs. Both
    callbacks must return the same type, which depends on the kind of event.
    """
    fire = self.event_data.active & self.event_data.triggered

    if not self.enable_tracing:
        # Untraced: ordinary Python branching is enough.
        chosen = self.callback if fire else self.passthrough
        return chosen(context)

    # Traced (flag not driven by simulator.enable_tracing): stage both branches.
    return cond(fire, self.callback, self.passthrough, context)

should_trigger()

Determine if the event should trigger based on the stored guard values.

Source code in collimator/framework/event.py
360
361
362
def should_trigger(self) -> bool:
    """Decide whether the event fires, using the guard values stored in its event data."""
    data = self.event_data
    start, end = data.w0, data.w1
    return self._should_trigger(start, end)

next_dependency_ticket()

Create a new unique dependency ticket using the next available value.

Source code in collimator/framework/dependency_graph.py
75
76
77
def next_dependency_ticket():
    """Allocate and return a fresh, unique dependency ticket.

    Delegates to `DependencyTicket.next_available_ticket`.
    """
    allocate = DependencyTicket.next_available_ticket
    return allocate()

parameters(static=None, dynamic=None)

Decorator to apply to a system class to declare static or dynamic parameters.

Source code in collimator/framework/system_decorators.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def parameters(static: list[str] = None, dynamic: list[str] = None):
    """Decorator to apply to a system class (or an ``__init__``-like callable) to
    declare static or dynamic parameters.

    Args:
        static: Names of constructor arguments to register through
            ``declare_static_parameter``. Any iterable of names is accepted.
        dynamic: Names of constructor arguments to register through
            ``declare_dynamic_parameter``; arguments whose value is ``None`` are
            skipped. Any iterable of names is accepted.

    Returns:
        A decorator. Applied to a class, it replaces the class's ``__init__`` in
        place, sets ``__parameters__`` to the declared names and returns the class.
        Applied to any other callable, it returns the wrapped callable.

    Raises:
        TypeError: (from the returned decorator) if the decorated object is
            neither a class nor callable.
    """
    # Materialize as lists so that tuples or one-shot iterables work both for the
    # set lookups and for the ``__parameters__`` concatenation below.
    static = [] if static is None else list(static)
    dynamic = [] if dynamic is None else list(dynamic)

    static_param_names = set(static)
    dynamic_param_names = set(dynamic)

    def decorator(entity: Union[Callable[P, T], type]) -> Callable[P, T]:
        if isinstance(entity, type):
            init_func = entity.__init__
            # Useful for class introspection like parsing custom leaf system in
            # the frontend to configure the UI block.
            entity.__parameters__ = static + dynamic
        elif callable(entity):
            init_func = entity
        else:
            # Previously this fell through and failed with an UnboundLocalError
            # on `init_func`; report the misuse explicitly instead.
            raise TypeError(
                "@parameters can only decorate a class or a callable, "
                f"got {type(entity).__name__}"
            )

        @wraps(init_func)
        def wrapped_init(self, *args, **kwargs):
            # The wrapped __init__ receives plain values: any Parameter argument
            # is replaced by the result of its .get().
            resolved_args = [
                arg.get() if isinstance(arg, Parameter) else arg for arg in args
            ]
            resolved_kwargs = {
                k: kwarg.get() if isinstance(kwarg, Parameter) else kwarg
                for k, kwarg in kwargs.items()
            }

            init_func(self, *resolved_args, **resolved_kwargs)

            # TODO: Prevent parameters from being inherited from parent systems.
            # This is necessary to avoid unknown behaviors when a child parameter
            # is used to define a parent parameter, eg. what we used to do in
            # PID continuous block where gains were used to calculate A, B, C, D
            # matrices.
            # This will force the implementor of the block to implement jitted
            # callbacks in such a way that they only depend on the current system's
            # parameters.
            # We should also allow inheritance of params with a flag or annotation.
            # https://github.com/collimator-ai/collimator/pull/6790
            # self._static_parameters = {}
            # self._dynamic_parameters = {}

            # Note: the original (unresolved) args/kwargs are passed here, not the
            # resolved ones — presumably so Parameter objects reach the declare_*
            # methods intact; verify against `_get_params`.
            static_params = _get_params(static_param_names, init_func, args, kwargs)
            for param_name, value in static_params.items():
                self.declare_static_parameter(param_name, value)

            dyn_params = _get_params(dynamic_param_names, init_func, args, kwargs)
            for param_name, value in dyn_params.items():
                if value is not None:
                    self.declare_dynamic_parameter(param_name, value)

        if isinstance(entity, type):
            entity.__init__ = wrapped_init
            return entity
        return wrapped_init

    return decorator