Skip to content

Framework

collimator.framework

BlockInitializationError

Bases: CollimatorError

A generic error raised when a block fails at init time and the full underlying exception is known to cause issues, e.g. with Ray serialization.

Source code in collimator/framework/error.py (lines 225–230)
class BlockInitializationError(CollimatorError):
    """Generic error raised when a block fails while being initialized.

    Used in place of the original exception when that exception is known to
    cause problems, e.g. with ray serialization.
    """

BlockParameterError

Bases: StaticError

Block parameters are missing or have invalid values.

Source code in collimator/framework/error.py (lines 173–176)
class BlockParameterError(StaticError):
    """Raised when a block parameter is missing or has an invalid value."""

BlockRuntimeError

Bases: CollimatorError

A generic error raised when a block fails at runtime and the full underlying exception is known to cause issues, e.g. with Ray serialization.

Source code in collimator/framework/error.py (lines 239–244)
class BlockRuntimeError(CollimatorError):
    """Generic error raised when a block fails while the simulation is running.

    Used in place of the original exception when that exception is known to
    cause problems, e.g. with ray serialization.
    """

CollimatorError

Bases: Exception

Base class for all custom collimator errors.

Source code in collimator/framework/error.py (lines 18–158)
class CollimatorError(Exception):
    """Base class for all custom collimator errors.

    Besides the message, an error can carry context about where it occurred
    (system id, block name/ui_id paths, port, parameter, algebraic loop).
    `__str__` appends that context to the message.
    """

    # Ideally we'd always have a system to pass but there are at least 2 cases
    # where we may not have one:
    # 1. parsing from json, the block hasn't been built yet
    # 2. other errors not specific to a block
    # In case 1, we should pass all name, path & ui_id info to the error

    def __init__(
        self,
        message=None,
        *,
        system: "SystemBase" = None,  # noqa
        system_id: Hashable = None,
        name_path: list[str] = None,
        ui_id_path: list[str] = None,
        port_index: int = None,
        port_name: str = None,
        port_direction: str = None,  # 'in' or 'out'
        parameter_name: str = None,
        loop: list["DirectedPortLocator"] = None,
    ):
        """Create a new CollimatorError.

        Only `message` is a positional argument, all others are keyword arguments.

        Args:
            message: A custom error message, defaults to the error class name.
            system: The system that the error occurred in, if available.
            system_id: The id of the system that the error occurred in, use if system can't be passed.
            name_path: The name path of the block that the error occurred in, use if system can't be passed.
            ui_id_path: The ui_id (uuid) path of the block that the error occurred in, use if system can't be passed.
            port_index: The index of the port that the error occurred at.
            port_name: The name of the port that the error occurred at.
            port_direction: The direction of the port that the error occurred at.
            parameter_name: The name of the parameter that the error occurred at.
            loop: A list of I/O ports where the error occurred (eg. AlgebraicLoopError).
        """
        super().__init__(message)

        # Compare against None instead of relying on truthiness: an id such as
        # 0 (or a system object that happens to be falsy) is still valid context.
        if system is not None and system_id is not None:
            warnings.warn(
                "Should not specify both system and system_id when raising exceptions"
            )

        if system is not None:
            self.system_id = system.system_id
            self.name_path = name_path or system.name_path
            self.ui_id_path = ui_id_path or system.ui_id_path
        else:
            self.system_id = system_id
            self.name_path = name_path
            self.ui_id_path = ui_id_path

        self.message = message
        self.port_index = port_index
        self.port_name = port_name
        self.port_direction = port_direction
        self.parameter_name = parameter_name

        # Extract serializable info from loop
        # NOTE: we could compact it a bit if the JSON becomes too large...
        self.loop: list[LoopItem] = None
        if loop is not None:
            self.loop = [
                LoopItem(
                    name_path=loc[0].name_path,
                    ui_id_path=loc[0].ui_id_path,
                    port_direction=loc[1],
                    port_index=loc[2],
                )
                for loc in loop
            ]

    def __str__(self):
        message = self.message or self.default_message
        return f"{message}{self._context_info()}"

    def _context_info(self) -> str:
        """Build the location/cause suffix appended to the message by `__str__`."""
        strbuf = []

        if self.name_path:
            # FIXME: this is known to be too verbose when looking at errors from
            # the UI but makes it better when running pytest or from code.
            # For now, be verbose.
            name_path = ".".join(self.name_path)
            strbuf.append(f" in block {name_path}")
        elif self.system_id is not None:  # Unnamed blocks, likely from code
            strbuf.append(f" in system {self.system_id}")

        port_label = self.port_name or self.port_index
        if self.port_direction:
            # Omit the label rather than printing "None" when the port is
            # identified by neither name nor index.
            label = f" {port_label}" if port_label is not None else ""
            strbuf.append(f" at {self.port_direction}put port{label}")
        elif self.port_name:
            strbuf.append(f" at port {self.port_name}")
        elif self.port_index is not None:
            strbuf.append(f" at port {self.port_index}")
        if self.parameter_name:
            strbuf.append(f" with parameter {self.parameter_name}")
        if self.__cause__ is not None:
            strbuf.append(f": {self.__cause__}")

        return "".join(strbuf)

    @property
    def block_name(self):
        """Last element of `name_path`, "root" for an empty path, else None."""
        if self.name_path is None:
            return None
        if len(self.name_path) == 0:
            return "root"
        return self.name_path[-1]

    @property
    def default_message(self):
        """Fallback message: the name of the concrete error class."""
        return type(self).__name__

    def caused_by(self, exc_type: type) -> bool:
        """Check if this error is or was caused by another error type.

        For instance, if a CollimatorError is raised because of a TypeError,
        this method will return True when called with TypeError as exc_type.

        Args:
            exc_type: The type of exception to check for (eg. TypeError)

        Returns:
            bool: True if the error is or was caused by the given exception type.
        """
        if not exc_type:
            return False

        # Walk the explicit `__cause__` chain iteratively; track visited
        # objects so a (manually created) cyclic chain cannot loop forever.
        seen = set()
        exc = self
        while exc is not None and id(exc) not in seen:
            if isinstance(exc, exc_type):
                return True
            seen.add(id(exc))
            exc = getattr(exc, "__cause__", None)
        return False

__init__(message=None, *, system=None, system_id=None, name_path=None, ui_id_path=None, port_index=None, port_name=None, port_direction=None, parameter_name=None, loop=None)

Create a new CollimatorError.

Only message is a positional argument, all others are keyword arguments.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `message` | | A custom error message, defaults to the error class name. | `None` |
| `system` | `SystemBase` | The system that the error occurred in, if available. | `None` |
| `system_id` | `Hashable` | The id of the system that the error occurred in, use if system can't be passed. | `None` |
| `name_path` | `list[str]` | The name path of the block that the error occurred in, use if system can't be passed. | `None` |
| `ui_id_path` | `list[str]` | The ui_id (uuid) path of the block that the error occurred in, use if system can't be passed. | `None` |
| `port_index` | `int` | The index of the port that the error occurred at. | `None` |
| `port_name` | `str` | The name of the port that the error occurred at. | `None` |
| `port_direction` | `str` | The direction of the port that the error occurred at. | `None` |
| `parameter_name` | `str` | The name of the parameter that the error occurred at. | `None` |
| `loop` | `list[DirectedPortLocator]` | A list of I/O ports where the error occurred (e.g. `AlgebraicLoopError`). | `None` |
Source code in collimator/framework/error.py (lines 27–91)
def __init__(
    self,
    message=None,
    *,
    system: "SystemBase" = None,  # noqa
    system_id: Hashable = None,
    name_path: list[str] = None,
    ui_id_path: list[str] = None,
    port_index: int = None,
    port_name: str = None,
    port_direction: str = None,  # 'in' or 'out'
    parameter_name: str = None,
    loop: list["DirectedPortLocator"] = None,
):
    """Create a new CollimatorError.

    Only `message` is a positional argument, all others are keyword arguments.

    Args:
        message: A custom error message, defaults to the error class name.
        system: The system that the error occurred in, if available.
        system_id: The id of the system that the error occurred in, use if system can't be passed.
        name_path: The name path of the block that the error occurred in, use if system can't be passed.
        ui_id_path: The ui_id (uuid) path of the block that the error occurred in, use if system can't be passed.
        port_index: The index of the port that the error occurred at.
        port_name: The name of the port that the error occurred at.
        port_direction: The direction of the port that the error occurred at.
        parameter_name: The name of the parameter that the error occurred at.
        loop: A list of I/O ports where the error occurred (eg. AlgebraicLoopError).
    """
    super().__init__(message)

    # Compare against None instead of relying on truthiness: an id such as
    # 0 (or a system object that happens to be falsy) is still valid context.
    if system is not None and system_id is not None:
        warnings.warn(
            "Should not specify both system and system_id when raising exceptions"
        )

    if system is not None:
        self.system_id = system.system_id
        self.name_path = name_path or system.name_path
        self.ui_id_path = ui_id_path or system.ui_id_path
    else:
        self.system_id = system_id
        self.name_path = name_path
        self.ui_id_path = ui_id_path

    self.message = message
    self.port_index = port_index
    self.port_name = port_name
    self.port_direction = port_direction
    self.parameter_name = parameter_name

    # Extract serializable info from loop
    # NOTE: we could compact it a bit if the JSON becomes too large...
    self.loop: list[LoopItem] = None
    if loop is not None:
        self.loop = [
            LoopItem(
                name_path=loc[0].name_path,
                ui_id_path=loc[0].ui_id_path,
                port_direction=loc[1],
                port_index=loc[2],
            )
            for loc in loop
        ]

caused_by(exc_type)

Check if this error is or was caused by another error type.

For instance, if a CollimatorError is raised because of a TypeError, this method will return True when called with TypeError as exc_type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `exc_type` | `type` | The type of exception to check for (e.g. `TypeError`). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `bool` | True if the error is or was caused by the given exception type. |

Source code in collimator/framework/error.py (lines 136–158)
def caused_by(self, exc_type: type) -> bool:
    """Check if this error is or was caused by another error type.

    For instance, if a CollimatorError is raised because of a TypeError,
    this method will return True when called with TypeError as exc_type.

    Args:
        exc_type: The type of exception to check for (eg. TypeError)

    Returns:
        bool: True if the error is or was caused by the given exception type.
    """
    if not exc_type:
        return False

    # Walk the explicit `__cause__` chain iteratively; track visited
    # objects so a (manually created) cyclic chain cannot loop forever.
    seen = set()
    exc = self
    while exc is not None and id(exc) not in seen:
        if isinstance(exc, exc_type):
            return True
        seen.add(id(exc))
        exc = getattr(exc, "__cause__", None)
    return False

ContextBase dataclass

Context object containing state, parameters, etc for a system.

NOTE: Type hints in ContextBase indicate the union between what would be returned by a LeafContext and a DiagramContext. See type hints of the subclasses for the specific argument and return types.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `owning_system` | `SystemBase` | The owning system of the context. |
| `time` | `Scalar` | The time associated with the context. Will be None unless the context is the root context. |
| `is_initialized` | `bool` | Flag indicating if the context is initialized. This should only be set by the ContextFactory during creation. |

Source code in collimator/framework/context.py (lines 88–210)
@dataclasses.dataclass(frozen=True)
class ContextBase(metaclass=abc.ABCMeta):
    """Context object containing state, parameters, etc for a system.

    NOTE: Type hints in ContextBase indicate the union between what would be returned
    by a LeafContext and a DiagramContext. See type hints of the subclasses for
    the specific argument and return types.

    Attributes:
        owning_system (SystemBase):
            The owning system of the context.
        time (Scalar):
            The time associated with the context. Will be None unless the context
            is the root context.
        is_initialized (bool):
            Flag indicating if the context is initialized. This should only be set
            by the ContextFactory during creation.
        parameters (Mapping[str, Array]):
            Parameter values keyed by parameter name. Defaults to None.
    """

    owning_system: SystemBase
    time: Scalar = None
    is_initialized: bool = False
    parameters: Mapping[str, Array] = None

    # NOTE: abstract properties below use `@property` stacked on
    # `@abc.abstractmethod`; `abc.abstractproperty` is deprecated since 3.3.

    @abc.abstractmethod
    def __getitem__(self, key: Hashable) -> LeafContext:
        """Get the subcontext associated with the given system ID.

        For leaf contexts, this will return `self`, but the method is provided
        so that there is a consistent interface for working with either an
        individual LeafSystem or tree-structured Diagram.

        For nested diagrams, intermediate diagrams do not have associated contexts,
        so indexing will fail.
        """

    @abc.abstractmethod
    def with_subcontext(self, key: Hashable, ctx: LeafContext) -> ContextBase:
        """Create a copy of this context, replacing the specified subcontext."""

    def with_time(self, value: Scalar) -> ContextBase:
        """Create a copy of this context, replacing time with the given value.

        This should only be called on the root context, since it is expected that all
        subcontexts will have a time value of None to avoid any conflicts.
        """
        return dataclasses.replace(self, time=value)

    @property
    @abc.abstractmethod
    def state(self) -> State:
        """The entire state held by this context."""

    @abc.abstractmethod
    def with_state(self, state: State) -> ContextBase:
        """Create a copy of this context, replacing the entire state."""

    @property
    @abc.abstractmethod
    def continuous_state(self) -> StateComponent:
        """The continuous state component of this context."""

    @abc.abstractmethod
    def with_continuous_state(self, value: StateComponent) -> ContextBase:
        """Create a copy of this context, replacing the continuous state."""

    @property
    @abc.abstractmethod
    def num_continuous_states(self) -> int:
        """Number of continuous states (counting rule defined by subclasses)."""

    @property
    @abc.abstractmethod
    def has_continuous_state(self) -> bool:
        """Whether this context carries any continuous state."""

    @property
    @abc.abstractmethod
    def discrete_state(self) -> StateComponent:
        """The discrete state component of this context."""

    @abc.abstractmethod
    def with_discrete_state(self, value: StateComponent) -> ContextBase:
        """Create a copy of this context, replacing the discrete state."""

    @property
    @abc.abstractmethod
    def num_discrete_states(self) -> int:
        """Number of discrete states (counting rule defined by subclasses)."""

    @property
    @abc.abstractmethod
    def has_discrete_state(self) -> bool:
        """Whether this context carries any discrete state."""

    @property
    @abc.abstractmethod
    def mode(self) -> Mode:
        """The mode held by this context."""

    @property
    @abc.abstractmethod
    def has_mode(self) -> bool:
        """Whether this context carries a mode."""

    @abc.abstractmethod
    def with_mode(self, value: Mode) -> ContextBase:
        """Create a copy of this context, replacing the mode."""

    def mark_initialized(self) -> ContextBase:
        """Return a copy of this context with `is_initialized` set to True."""
        return dataclasses.replace(self, is_initialized=True)

    @abc.abstractmethod
    def with_parameter(self, name: str, value: Parameter | ArrayLike) -> ContextBase:
        """Create a copy of this context, replacing the specified parameter."""

    @abc.abstractmethod
    def with_parameters(
        self, new_parameters: Mapping[str, Parameter | ArrayLike]
    ) -> ContextBase:
        """Create a copy of this context, replacing only the specified parameters."""

    def _replace_param(self, name: str, value: ArrayLike):
        """Set parameter `name` to `value` and store the result in `self.parameters`.

        NOTE(review): this works in place despite the context being a frozen
        dataclass. `param.set(value)` is called on the object taken from
        `owning_system.default_parameters` (presumably mutating it), and
        `self.parameters` is assigned into directly. Confirm this is only
        called on a freshly copied context.
        """
        param = self.owning_system.default_parameters[name]
        param.set(value)
        self.parameters[name] = param.get()

__getitem__(key) abstractmethod

Get the subcontext associated with the given system ID.

For leaf contexts, this will return self, but the method is provided so that there is a consistent interface for working with either an individual LeafSystem or tree-structured Diagram.

For nested diagrams, intermediate diagrams do not have associated contexts, so indexing will fail.

Source code in collimator/framework/context.py (lines 112–123)
@abc.abstractmethod
def __getitem__(self, key: Hashable) -> LeafContext:
    """Look up the subcontext belonging to the system whose ID is `key`.

    A leaf context just hands back `self`; the method exists so that a single
    LeafSystem and a tree-structured Diagram can be indexed the same way.
    Intermediate diagrams in a nested hierarchy own no context, so indexing
    with their ID fails.
    """

with_continuous_state(value) abstractmethod

Create a copy of this context, replacing the continuous state.

Source code in collimator/framework/context.py (lines 151–154)
@abc.abstractmethod
def with_continuous_state(self, value: StateComponent) -> ContextBase:
    """Return a copy of this context whose continuous state is `value`."""

with_discrete_state(value) abstractmethod

Create a copy of this context, replacing the discrete state.

Source code in collimator/framework/context.py (lines 168–171)
@abc.abstractmethod
def with_discrete_state(self, value: StateComponent) -> ContextBase:
    """Return a copy of this context whose discrete state is `value`."""

with_mode(value) abstractmethod

Create a copy of this context, replacing the mode.

Source code in collimator/framework/context.py (lines 189–192)
@abc.abstractmethod
def with_mode(self, value: Mode) -> ContextBase:
    """Return a copy of this context whose mode is `value`."""

with_parameter(name, value) abstractmethod

Create a copy of this context, replacing the specified parameter.

Source code in collimator/framework/context.py (lines 197–200)
@abc.abstractmethod
def with_parameter(self, name: str, value: Parameter | ArrayLike) -> ContextBase:
    """Return a copy of this context in which parameter `name` is set to `value`."""

with_parameters(new_parameters) abstractmethod

Create a copy of this context, replacing only the specified parameters.

Source code in collimator/framework/context.py (lines 202–205)
@abc.abstractmethod
def with_parameters(
    self, new_parameters: Mapping[str, Parameter | ArrayLike]
) -> ContextBase:
    """Return a copy of this context with the given parameters replaced.

    Parameters not named in `new_parameters` are left as they are.
    """

with_state(state) abstractmethod

Create a copy of this context, replacing the entire state.

Source code in collimator/framework/context.py (lines 142–145)
@abc.abstractmethod
def with_state(self, state: State) -> ContextBase:
    """Return a copy of this context with its whole state swapped for `state`."""

with_subcontext(key, ctx) abstractmethod

Create a copy of this context, replacing the specified subcontext.

Source code in collimator/framework/context.py (lines 125–128)
@abc.abstractmethod
def with_subcontext(self, key: Hashable, ctx: LeafContext) -> ContextBase:
    """Return a copy of this context where the subcontext at `key` is `ctx`."""

with_time(value)

Create a copy of this context, replacing time with the given value.

This should only be called on the root context, since it is expected that all subcontexts will have a time value of None to avoid any conflicts.

Source code in collimator/framework/context.py (lines 130–136)
def with_time(self, value: Scalar) -> ContextBase:
    """Return a copy of this context whose `time` field equals `value`.

    Only call this on the root context: every subcontext is expected to keep
    `time` set to None so that the two cannot disagree.
    """
    updated = dataclasses.replace(self, time=value)
    return updated

DependencyTicket

Singleton class for managing unique dependency tickets.

Source code in collimator/framework/dependency_graph.py (lines 79–99)
class DependencyTicket:
    """Class-level registry of integer dependency tickets.

    The attributes below are fixed ticket numbers for the built-in dependency
    sources. `next_available_ticket` hands out further unique integers, and all
    of its bookkeeping lives on the class itself.
    """

    nothing = 0  # Not dependent on anything.
    time = 1  # Time.
    xc = 2  # Every continuous state variable.
    xd = 3  # Every discrete state variable.
    mode = 4  # Every mode.
    x = 5  # Every state variable: x = {xc, xd, mode}.
    p = 6  # Every parameter.
    all_sources_except_input_ports = 7  # Everything except the input ports.
    u = 8  # Every input port u.
    all_sources = 9  # Everything listed above.
    xcdot = 10  # Time derivative of the continuous state.

    # Counter advanced by next_available_ticket() before each value is returned.
    _next_available = 11

    @classmethod
    def next_available_ticket(cls):
        """Reserve a new ticket number and return it.

        The counter is bumped before reading it back, so with the initial value
        above the first call returns 12, the next 13, and so on.
        """
        cls._next_available = cls._next_available + 1
        ticket = cls._next_available
        return ticket

Diagram dataclass

Bases: SystemBase

Composite block-diagram representation of a dynamical system.

A Diagram is a collection of Systems connected together to form a larger hybrid dynamical system. Diagrams can be nested to any depth, creating a tree-structured block diagram.

NOTE: The Diagram class is not intended to be constructed directly. Instead, use the DiagramBuilder to construct a Diagram, which will pass the appropriate information to this constructor.

Source code in collimator/framework/diagram.py (lines 50–544)
@dataclasses.dataclass
class Diagram(SystemBase):
    """Composite block-diagram representation of a dynamical system.

    A Diagram is a collection of Systems connected together to form a larger hybrid
    dynamical system. Diagrams can be nested to any depth, creating a tree-structured
    block diagram.

    NOTE: The Diagram class is not intended to be constructed directly.  Instead,
    use the `DiagramBuilder` to construct a Diagram, which will pass the appropriate
    information to this constructor.
    """

    # Redefine here to make pylint happy
    system_id: Hashable = dataclasses.field(default_factory=next_system_id, init=False)
    name: str = None  # Human-readable name for this system (optional)
    ui_id: str = None  # UUID of the block when loaded from JSON (optional)

    # None of these attributes are intended to be modified or accessed directly after
    # construction.  Instead, use the interface defined by `SystemBase`.

    # Direct children of this Diagram
    nodes: List[SystemBase] = dataclasses.field(default_factory=list)

    # Mapping from input ports to output ports of child subsystems
    connection_map: Mapping[InputPortLocator, OutputPortLocator] = dataclasses.field(
        default_factory=dict,
    )

    # Optional identifier for "reference diagrams"
    ref_id: str = None

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.name}, {len(self.nodes)} nodes)"

    def _pprint(self, prefix="") -> str:
        """Return the single tree-view line for this Diagram (children excluded)."""
        return f"{prefix}|-- {self.name}\n"

    def _pprint_helper(self, prefix="") -> str:
        """Return the tree view of this Diagram followed by all of its descendants.

        Each nesting level indents its children by four additional spaces.
        """
        child_prefix = f"{prefix}    "
        parts = [self._pprint(prefix=prefix)]
        parts.extend(node._pprint_helper(prefix=child_prefix) for node in self.nodes)
        return "".join(parts)

    def __hash__(self) -> int:
        # Hash on the system ID rather than the (mutable) dataclass fields.
        return hash(self.system_id)

    def __getitem__(self, name: str) -> SystemBase:
        """Return the direct child whose `name` matches.

        Access by name is for convenient user interaction only.  Programmatic
        access should use the nodes directly, e.g. `self.nodes[idx]`.

        Raises:
            KeyError: if no direct child has this name.
        """
        lookup = {node.name: node for node in self.nodes}
        return lookup[name]

    def __iter__(self) -> Iterator[SystemBase]:
        """Iterate over the direct children of this Diagram."""
        return iter(self.nodes)

    def __post_init__(self):
        super().__post_init__()

        # Set parent for all the immediate child systems
        for node in self.nodes:
            node.parent = self

        # The map of subsystem inputs/outputs to inputs/outputs of this Diagram.
        self._input_port_map: Mapping[InputPortLocator, int] = {}
        self._output_port_map: Mapping[OutputPortLocator, int] = {}

        # Also need the inverse output map, for determining feedthrough paths.
        self._inv_output_port_map: Mapping[int, OutputPortLocator] = {}

        # Leaves of the system tree (not necessarily the same as the direct
        # children of this Diagram, which may themselves be Diagrams)
        self.leaf_systems: List[LeafSystem] = []
        for child in self.nodes:
            if isinstance(child, Diagram):
                self.leaf_systems.extend(child.leaf_systems)
                # No longer need the child leaf systems, since methods using this
                # should only be called from the top level.
                child.leaf_systems = None
            else:
                self.leaf_systems.append(child)

    def post_simulation_finalize(self) -> None:
        """Perform any post-simulation cleanup for this system."""
        for system in self.nodes:
            system.post_simulation_finalize()

    # Inherits docstrings from SystemBase
    @property
    def has_feedthrough_side_effects(self) -> bool:
        # See explanation in `SystemBase.has_feedthrough_side_effects`.
        return any(child.has_feedthrough_side_effects for child in self.nodes)

    # Inherits docstrings from SystemBase
    @property
    def has_ode_side_effects(self) -> bool:
        # Return true if either of the following are true:
        # 1. At least one subsystem has ODE side effects
        # 2. At least one subsystem has feedthrough side effects and the output
        #    ports of the diagram are used as ODE inputs.

        # Case 1 must be checked first: a subsystem can have ODE side effects
        # without having feedthrough side effects, and that alone is sufficient.
        if any(child.has_ode_side_effects for child in self.nodes):
            return True

        # Case 2 is impossible if no subsystem has feedthrough side effects.
        if not self.has_feedthrough_side_effects:
            return False

        # Otherwise, test the dependency graph: does any output of a subsystem
        # with feedthrough side effects feed the continuous-state derivatives?
        for child in self.nodes:
            if not child.has_feedthrough_side_effects:
                continue
            for port in child.output_ports:
                if port.tracker.is_prerequisite_of([DependencyTicket.xcdot]):
                    return True
        return False

    @property
    def has_continuous_state(self) -> bool:
        return any(child.has_continuous_state for child in self.nodes)

    @property
    def has_discrete_state(self) -> bool:
        return any(child.has_discrete_state for child in self.nodes)

    @property
    def has_zero_crossing_events(self) -> bool:
        return any(child.has_zero_crossing_events for child in self.nodes)

    @property
    def num_systems(self) -> int:
        # Number of subsystems _at this level_
        return len(self.nodes)

    def check_types(
        self,
        context: DiagramContext,
        error_collector: ErrorCollector = None,
    ) -> None:
        """Perform any system-specific static analysis."""
        for system in self.nodes:
            system.check_types(
                context,
                error_collector=error_collector,
            )

    #
    # Simulation interface
    #

    # Inherits docstrings from SystemBase
    def eval_time_derivatives(self, root_context: DiagramContext) -> List[Array]:
        # One derivative entry per subcontext that owns continuous state, in the
        # order the root context lists them.
        leaf_systems = [
            subctx.owning_system for subctx in root_context.continuous_subcontexts
        ]
        return [leaf.eval_time_derivatives(root_context) for leaf in leaf_systems]

    #
    # Event handling
    #
    @property
    def state_update_events(self) -> FlatEventCollection:
        """Collect the state-update events of every leaf system.

        Only valid on the top-level Diagram, since `leaf_systems` is cleared on
        nested Diagrams during construction.
        """
        assert self.parent is None, (
            "Can only get periodic events from top-level Diagram, not "
            f"{self.system_id} with parent {self.parent.system_id}"
        )
        events = sum(
            [leaf.state_update_events for leaf in self.leaf_systems],
            start=FlatEventCollection(),
        )
        return events

    @property
    def zero_crossing_events(self) -> DiagramEventCollection:
        """Collect zero-crossing events of every leaf, keyed by leaf system ID.

        Only valid on the top-level Diagram.
        """
        assert self.parent is None, (
            "Can only get zero-crossing events from top-level Diagram, not "
            f"{self.system_id} with parent {self.parent.system_id}"
        )
        return DiagramEventCollection(
            OrderedDict(
                {leaf.system_id: leaf.zero_crossing_events for leaf in self.leaf_systems}
            )
        )

    # Inherits docstrings from SystemBase
    def determine_active_guards(
        self, root_context: DiagramContext
    ) -> DiagramEventCollection:
        assert self.parent is None, (
            "Can only get zero-crossing events from top-level Diagram, not "
            f"{self.system_id} with parent {self.parent.system_id}"
        )
        return DiagramEventCollection(
            OrderedDict(
                {
                    leaf.system_id: leaf.determine_active_guards(root_context)
                    for leaf in self.leaf_systems
                }
            )
        )

    # Inherits docstrings from SystemBase
    def eval_zero_crossing_updates(
        self,
        root_context: DiagramContext,
        events: DiagramEventCollection,
    ) -> dict[Hashable, LeafState]:
        substates = OrderedDict()
        for system_id, subctx in root_context.subcontexts.items():
            owner = subctx.owning_system
            substates[system_id] = owner.eval_zero_crossing_updates(
                root_context, events
            )

        return substates

    #
    # I/O ports
    #
    @property
    def _flat_callbacks(self) -> List[SystemCallback]:
        """Return a flat list of all SystemCallbacks in the Diagram."""
        return [cb for child in self.nodes for cb in child._flat_callbacks]

    @property
    def exported_input_ports(self):
        """Map from exported child input locators to Diagram input-port indices."""
        return self._input_port_map

    @property
    def exported_output_ports(self):
        """Map from exported child output locators to Diagram output-port indices."""
        return self._output_port_map

    def eval_subsystem_input_port(
        self, context: DiagramContext, port_locator: InputPortLocator
    ) -> Array:
        """Evaluate the input port for a child of this system given the root context.

        Args:
            context (ContextBase): root context for this system
            port_locator (InputPortLocator): tuple of (system, port_index) identifying
                the input port to evaluate

        Returns:
            Array: Value returned from evaluating the subsystem port.

        Raises:
            InputNotConnectedError: if the input port is not connected
        """

        is_exported = port_locator in self._input_port_map
        if is_exported:
            # The upstream source is an input to this whole Diagram; evaluate that
            # input port and use the result as the value for this one.
            port_index = self._input_port_map[port_locator]  # Diagram-level index
            return self.input_ports[port_index].eval(context)  # Return upstream value

        is_connected = port_locator in self.connection_map
        if is_connected:
            # The upstream source is an output port of one of this Diagram's child
            # subsystems; evaluate the upstream output.
            upstream_locator = self.connection_map[port_locator]

            # This will return the value of the upstream port
            return self.eval_subsystem_output_port(context, upstream_locator)

        block, port_index = port_locator
        raise InputNotConnectedError(
            system=block,
            port_index=port_index,
            port_direction="in",
            message=f"Input port {block.name}[{port_index}] is not connected",
        )

    def eval_subsystem_output_port(
        self, context: DiagramContext, port_locator: OutputPortLocator
    ) -> Array:
        """Evaluate the output port for a child of this system given the root context.

        Once the context is initialized the port is simply evaluated.  Before
        initialization, the port's default value is returned if it has one;
        otherwise the port is evaluated to obtain a "template" value.

        Args:
            context (ContextBase): root context for this system
            port_locator (OutputPortLocator): tuple of (system, port_index) identifying
                the output port to evaluate

        Returns:
            Array: Value returned from evaluating the subsystem port.

        Raises:
            UpstreamEvalError: if the context is not initialized, the port has no
                default value, and evaluating it returns None.
        """
        system, port_index = port_locator
        port = system.output_ports[port_index]

        # During simulation all we should need to do is evaluate the port.
        if context.is_initialized:
            return port.eval(context)

        # If the context is not initialized, we have to determine the signal data type.
        # In the easy case, the port has a default value, so we can just use that.
        if port.default_value is not None:
            logger.debug(
                "Using default output value of %s for %s",
                port.default_value,
                port_locator[0].name,
            )
            return port.default_value

        logger.debug(
            "Evaluating output port %s for system %s. Context initialized: %s",
            port_locator,
            port_locator[0].name,
            context.is_initialized,
        )

        # If there is no default value, try to evaluate the port to pull a "template"
        # value with an appropriate data type from upstream.  This will return None if
        # the port is not yet connected (e.g. if its upstream is an exported input of
        # a Diagram), so we can defer evaluation.
        val = port.eval(context)
        logger.debug(
            "  ---> %s returns %s", (port_locator[0].name, port_locator[1]), val
        )

        # If there is still no value, the port is not connected to anything.
        # Post-initialization this would be an error, but pre-initialization
        # it may be the case that the upstream is an exported input port of
        # the Diagram, so we can defer evaluation. Expect the block that is
        # doing this to handle the UpstreamEvalError appropriately.
        if val is None:
            system_name = system.name_path_str
            logger.debug(
                "Upstream evaluation of %s.out[%s] returned None. Deferring evaluation.",
                system_name,
                port_index,
            )
            raise UpstreamEvalError(port_locator=(system, "out", port_index))
        return val

    #
    # System-level declarations (should be done via DiagramBuilder)
    #
    def export_input(self, locator: InputPortLocator, port_name: str) -> int:
        """Export a subsystem input port as a diagram-level input.

        This should typically only be called during construction by DiagramBuilder.
        The standard workflow will be to call export_input on the _builder_ object,
        which will automatically call this method on the Diagram once created.

        Args:
            locator (InputPortLocator): tuple of (system, port_index) identifying
                the input port to export
            port_name (str): name of the new exported input port

        Returns:
            int: index of the exported input port in the diagram input_ports list
        """
        diagram_port_index = self.declare_input_port(name=port_name)
        self._input_port_map[locator] = diagram_port_index

        # Sometimes API calls will export ports manually (e.g. in the PID autotuning
        # workflow), so we need to make sure these dependencies are properly tracked.
        self.update_dependency_graph()

        return diagram_port_index

    def export_output(self, locator: OutputPortLocator, port_name: str) -> int:
        """Export a subsystem output port as a diagram-level output.

        This should typically only be called during construction by DiagramBuilder.
        The standard workflow will be to call export_output on the _builder_ object,
        which will automatically call this method on the Diagram once created.

        Args:
            locator (OutputPortLocator): tuple of (system, port_index) identifying
                the output port to export
            port_name (str): name of the new exported output port

        Returns:
            int: index of the exported output port in the diagram output_ports list
        """
        subsystem, subsystem_port_index = locator
        source_port = subsystem.output_ports[subsystem_port_index]
        # The Diagram-level port simply forwards the subsystem port's value and
        # declares that port's ticket as its only prerequisite.
        diagram_port_index = self.declare_output_port(
            source_port.eval,
            name=port_name,
            prerequisites_of_calc=[source_port.ticket],
        )
        self._output_port_map[locator] = diagram_port_index
        self._inv_output_port_map[diagram_port_index] = locator

        # Sometimes API calls will export ports manually (e.g. in the PID autotuning
        # workflow), so we need to make sure these dependencies are properly tracked.
        self.update_dependency_graph()

        return diagram_port_index

    #
    # Initialization
    #
    @property
    def context_factory(self) -> DiagramContextFactory:
        return DiagramContextFactory(self)

    @property
    def dependency_graph_factory(self) -> DiagramDependencyGraphFactory:
        return DiagramDependencyGraphFactory(self)

    def initialize_static_data(self, context: DiagramContext) -> DiagramContext:
        """Perform any system-specific static analysis."""
        for system in self.nodes:
            context = system.initialize_static_data(context)
        return context

    def _has_feedthrough(self, input_port_index: int, output_port_index: int) -> bool:
        """Check if there is a direct-feedthrough path from the input port to the output port.

        Internal function used by `get_feedthrough`.  Should not typically need to
        be called directly.

        The search walks upstream from the child output port exported as
        `output_port_index`, following child feedthrough pairs and the
        `connection_map`, until it reaches a child input port that is exported as
        `input_port_index`.
        """
        # Child input locators exported as the requested Diagram-level input.
        input_ids = {
            locator
            for locator, index in self._input_port_map.items()
            if index == input_port_index
        }

        # Output locators known to have a direct-feedthrough path to the output
        # port and still waiting to be expanded.  `visited` guarantees termination
        # when children form a feedthrough cycle (algebraic loop) and avoids
        # re-expanding ports reachable along several paths.
        active_set: Set[OutputPortLocator] = {
            self._inv_output_port_map[output_port_index]
        }
        visited: Set[OutputPortLocator] = set()

        while active_set:
            locator = active_set.pop()
            if locator in visited:
                continue
            visited.add(locator)

            child, child_output = locator
            for u, v in child.get_feedthrough():
                if v != child_output:
                    continue
                curr_input_id = (child, u)
                if curr_input_id in input_ids:
                    # Found a direct-feedthrough path to the input_port
                    return True
                if curr_input_id in self.connection_map:
                    # Intermediate input port has a direct-feedthrough path to
                    # output_port; continue from its upstream output port.
                    active_set.add(self.connection_map[curr_input_id])

        # If there are no intermediate output ports with a direct-feedthrough path
        # to the output port, there is no direct feedthrough from the input port
        return False

    # Inherits docstring from SystemBase.get_feedthrough
    def get_feedthrough(self) -> List[Tuple[int, int]]:
        # Cached after the first call.
        if self.feedthrough_pairs is not None:
            return self.feedthrough_pairs

        self.feedthrough_pairs = [
            (u, v)
            for u in range(self.num_input_ports)
            for v in range(self.num_output_ports)
            if self._has_feedthrough(u, v)
        ]
        return self.feedthrough_pairs

    def find_system_with_path(self, path: str | list[str]) -> SystemBase:
        """Find a descendant system by name path.

        Args:
            path: Either a dot-separated string (e.g. ``"a.b.c"``) or a list of
                names.  On the top-level Diagram (no parent) the path is resolved
                relative to this Diagram's children.  On a nested Diagram, a
                single-element path equal to this Diagram's own name returns the
                Diagram itself; otherwise the path is resolved relative to its
                children.

        Returns:
            The matching system, or None if no system matches the path.
        """
        if isinstance(path, str):
            path = path.split(".")

        def _find_in_children():
            # Only the first child whose name matches is considered.
            for child in self.nodes:
                if child.name == path[0]:
                    if len(path) == 1:
                        return child
                    if isinstance(child, Diagram):
                        return child.find_system_with_path(path[1:])
                    return None
            return None

        if self.parent is None:
            return _find_in_children()

        if self.name == path[0] and len(path) == 1:
            return self

        return _find_in_children()

    def declare_parameter(self, name: str, parameter: Parameter) -> None:
        """Declare a parameter for this system.

        Parameters:
            name (str): The name of the parameter.
            parameter (Parameter): The parameter object.  Its ``name`` attribute
                is overwritten with ``name``.
        """
        # Force the parameter to have the correct name, all diagram parameters
        # should be named.
        parameter.name = name
        super().declare_parameter(name, parameter)

check_types(context, error_collector=None)

Perform any system-specific static analysis.

Source code in collimator/framework/diagram.py
185
186
187
188
189
190
191
192
193
194
195
def check_types(
    self,
    context: DiagramContext,
    error_collector: ErrorCollector = None,
) -> None:
    """Run static type checks on every direct child of this Diagram.

    Each child receives the same context and error collector, visited in the
    order of `self.nodes`.
    """
    children = self.nodes
    for child in children:
        child.check_types(context, error_collector=error_collector)

declare_parameter(name, parameter)

Declare a parameter for this system.

Parameters:

Name Type Description Default
name str

The name of the parameter.

required
parameter Parameter

The parameter object.

required
Source code in collimator/framework/diagram.py
534
535
536
537
538
539
540
541
542
543
544
def declare_parameter(self, name: str, parameter: Parameter) -> None:
    """Declare a parameter for this system.

    The parameter's ``name`` attribute is overwritten with ``name`` before it is
    registered through the parent-class implementation.

    Parameters:
        name (str): The name of the parameter.
        parameter (Parameter): The parameter object. Mutated in place: its
            ``name`` attribute is set to ``name``.
    """
    # Force the parameter to have the correct name, all diagram parameters
    # should be named.
    parameter.name = name
    # Zero-argument super() relies on this function being defined inside the
    # Diagram class body, as it is in collimator/framework/diagram.py.
    super().declare_parameter(name, parameter)

eval_subsystem_input_port(context, port_locator)

Evaluate the input port for a child of this system given the root context.

Parameters:

Name Type Description Default
context ContextBase

root context for this system

required
port_locator InputPortLocator

tuple of (system, port_index) identifying the input port to evaluate

required

Returns:

Name Type Description
Array Array

Value returned from evaluating the subsystem port.

Raises:

Type Description
InputNotConnectedError

if the input port is not connected

Source code in collimator/framework/diagram.py
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
def eval_subsystem_input_port(
    self, context: DiagramContext, port_locator: InputPortLocator
) -> Array:
    """Resolve and evaluate the value feeding a child's input port.

    The value comes from one of two places: a Diagram-level input that the child
    port was exported as, or the child output port wired to it through
    `connection_map`.  Exported inputs take precedence.

    Args:
        context (ContextBase): root context for this system
        port_locator (InputPortLocator): tuple of (system, port_index) identifying
            the input port to evaluate

    Returns:
        Array: Value returned from evaluating the subsystem port.

    Raises:
        InputNotConnectedError: if the input port is not connected
    """
    not_found = object()

    # Exported as a Diagram input: forward whatever feeds that Diagram port.
    diagram_index = self._input_port_map.get(port_locator, not_found)
    if diagram_index is not not_found:
        return self.input_ports[diagram_index].eval(context)

    # Wired to a sibling's output: evaluate that upstream output port.
    source_locator = self.connection_map.get(port_locator, not_found)
    if source_locator is not not_found:
        return self.eval_subsystem_output_port(context, source_locator)

    # Neither exported nor wired.
    owner, local_index = port_locator
    raise InputNotConnectedError(
        system=owner,
        port_index=local_index,
        port_direction="in",
        message=f"Input port {owner.name}[{local_index}] is not connected",
    )

eval_subsystem_output_port(context, port_locator)

Evaluate the output port for a child of this system given the root context.

Parameters:

Name Type Description Default
context ContextBase

root context for this system

required
port_locator OutputPortLocator

tuple of (system, port_index) identifying the output port to evaluate

required

Returns:

Name Type Description
Array Array

Value returned from evaluating the subsystem port.

Source code in collimator/framework/diagram.py
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
def eval_subsystem_output_port(
    self, context: DiagramContext, port_locator: OutputPortLocator
) -> Array:
    """Evaluate the output port for a child of this system given the root context.

    Once the context is initialized the port is simply evaluated.  Before
    initialization, the port's default value is returned if it has one;
    otherwise the port is evaluated to obtain a "template" value.

    Args:
        context (ContextBase): root context for this system
        port_locator (OutputPortLocator): tuple of (system, port_index) identifying
            the output port to evaluate

    Returns:
        Array: Value returned from evaluating the subsystem port.

    Raises:
        UpstreamEvalError: if the context is not initialized, the port has no
            default value, and evaluating it returns None.
    """
    system, port_index = port_locator
    port = system.output_ports[port_index]

    # During simulation all we should need to do is evaluate the port.
    if context.is_initialized:
        return port.eval(context)

    # If the context is not initialized, we have to determine the signal data type.
    # In the easy case, the port has a default value, so we can just use that.
    if port.default_value is not None:
        logger.debug(
            "Using default output value of %s for %s",
            port.default_value,
            port_locator[0].name,
        )
        return port.default_value

    logger.debug(
        "Evaluating output port %s for system %s. Context initialized: %s",
        port_locator,
        port_locator[0].name,
        context.is_initialized,
    )

    # If there is no default value, try to evaluate the port to pull a "template"
    # value with an appropriate data type from upstream.  This will return None if
    # the port is not yet connected (e.g. if its upstream is an exported input of
    # a Diagram), so we can defer evaluation.
    val = port.eval(context)
    logger.debug(
        "  ---> %s returns %s", (port_locator[0].name, port_locator[1]), val
    )

    # If there is still no value, the port is not connected to anything.
    # Post-initialization this would be an error, but pre-initialization
    # it may be the case that the upstream is an exported input port of
    # the Diagram, so we can defer evaluation. Expect the block that is
    # doing this to handle the UpstreamEvalError appropriately.
    if val is None:
        system_name = system.name_path_str
        logger.debug(
            "Upstream evaluation of %s.out[%s] returned None. Deferring evaluation.",
            system_name,
            port_index,
        )
        raise UpstreamEvalError(port_locator=(system, "out", port_index))
    return val

export_input(locator, port_name)

Export a subsystem input port as a diagram-level input.

This should typically only be called during construction by DiagramBuilder. The standard workflow will be to call export_input on the builder object, which will automatically call this method on the Diagram once created.

Parameters:

Name Type Description Default
locator InputPortLocator

tuple of (system, port_index) identifying the input port to export

required
port_name str

name of the new exported input port

required

Returns:

Name Type Description
int int

index of the exported input port in the diagram input_ports list

Source code in collimator/framework/diagram.py
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
def export_input(self, locator: InputPortLocator, port_name: str) -> int:
    """Expose a child's input port as a new input of this Diagram.

    Normally invoked by DiagramBuilder: users call export_input on the _builder_
    object, and the builder forwards the request here once the Diagram exists.

    Args:
        locator (InputPortLocator): tuple of (system, port_index) identifying
            the child input port to export
        port_name (str): name given to the new Diagram-level input port

    Returns:
        int: position of the new port in this Diagram's input_ports list
    """
    new_index = self.declare_input_port(name=port_name)
    self._input_port_map[locator] = new_index

    # Ports may also be exported by hand after construction (the PID
    # autotuning workflow does this), so refresh the dependency graph each time.
    self.update_dependency_graph()
    return new_index

export_output(locator, port_name)

Export a subsystem output port as a diagram-level output.

This should typically only be called during construction by DiagramBuilder. The standard workflow will be to call export_output on the builder object, which will automatically call this method on the Diagram once created.

Parameters:

Name Type Description Default
locator OutputPortLocator

tuple of (system, port_index) identifying the output port to export

required
port_name str

name of the new exported output port

required

Returns:

Name Type Description
int int

index of the exported output port in the diagram output_ports list

Source code in collimator/framework/diagram.py
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
def export_output(self, locator: OutputPortLocator, port_name: str) -> int:
    """Export a subsystem output port as a diagram-level output.

    This should typically only be called during construction by DiagramBuilder.
    The standard workflow will be to call export_output on the _builder_ object,
    which will automatically call this method on the Diagram once created.

    Args:
        locator (OutputPortLocator): tuple of (system, port_index) identifying
            the output port to export
        port_name (str): name of the new exported output port

    Returns:
        int: index of the exported output port in the diagram output_ports list
    """
    subsystem, subsystem_port_index = locator
    source_port = subsystem.output_ports[subsystem_port_index]
    # The Diagram-level port forwards the subsystem port's value and declares
    # that port's ticket as its only prerequisite.
    diagram_port_index = self.declare_output_port(
        source_port.eval,
        name=port_name,
        prerequisites_of_calc=[source_port.ticket],
    )
    # Keep both directions: forward for lookups, inverse for feedthrough search.
    self._output_port_map[locator] = diagram_port_index
    self._inv_output_port_map[diagram_port_index] = locator

    # Sometimes API calls will export ports manually (e.g. in the PID autotuning
    # workflow), so we need to make sure these dependencies are properly tracked.
    self.update_dependency_graph()

    return diagram_port_index

initialize_static_data(context)

Perform any system-specific static analysis.

Source code in collimator/framework/diagram.py
454
455
456
457
458
def initialize_static_data(self, context: DiagramContext) -> DiagramContext:
    """Thread the context through each child's static-data initialization.

    Every child receives the context returned by the previous child, in the
    order of `self.nodes`; the final context is returned.
    """
    current = context
    for child in self.nodes:
        current = child.initialize_static_data(current)
    return current

post_simulation_finalize()

Perform any post-simulation cleanup for this system.

Source code in collimator/framework/diagram.py
132
133
134
135
def post_simulation_finalize(self) -> None:
    """Ask every direct child to run its post-simulation cleanup, in node order."""
    children = self.nodes
    for child in children:
        child.post_simulation_finalize()

DiagramBuilder

Class for constructing block diagram systems.

The DiagramBuilder class is responsible for building a diagram by adding systems, connecting ports, and exporting inputs and outputs. It keeps track of the registered systems, input and output ports, and the connection map between input and output ports of the child systems.

Source code in collimator/framework/diagram_builder.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
class DiagramBuilder:
    """Class for constructing block diagram systems.

    The `DiagramBuilder` class is responsible for building a diagram by adding systems, connecting ports,
    and exporting inputs and outputs. It keeps track of the registered systems, input and output ports,
    and the connection map between input and output ports of the child systems.

    Typical usage: `add` the child systems, `connect` their ports, optionally
    `export_input` / `export_output` child ports as diagram-level ports, and
    finally call `build` once. After a successful `build` the builder rejects
    any further use.
    """

    def __init__(self):
        # Child input ports that are exported as diagram-level inputs, together
        # with the names they are exported under (parallel lists, same order).
        self._input_port_ids: List[InputPortLocator] = []
        self._input_port_names: List[str] = []
        # Child output ports that are exported as diagram-level outputs, together
        # with their exported names (parallel lists, same order).
        self._output_port_ids: List[OutputPortLocator] = []
        self._output_port_names: List[str] = []

        # Connection map between input and output ports of the child systems,
        # keyed by the destination (input) port locator.
        self._connection_map: Mapping[InputPortLocator, OutputPortLocator] = {}

        # List of registered systems
        self._registered_systems: List[SystemBase] = []

        # Name lookup for exported inputs: exported name -> diagram-level index
        self._diagram_input_indices: Mapping[str, int] = {}

        # All input ports of child systems. `build` verifies that each of them is
        # either internally connected or exported as a diagram-level input.
        self._all_input_ports: List[InputPortLocator] = []

        # Each DiagramBuilder can only be used to build a single diagram.  This is to
        # avoid creating multiple diagrams that reference the same LeafSystem. Doing so
        # may or may not actually lead to problems, since the LeafSystems themselves
        # should act like a collection of pure functions, but best practice is to have
        # each leaf system be fully unique.
        self._already_built = False

    def add(self, *systems: SystemBase) -> List[SystemBase] | SystemBase:
        """Add one or more systems to the diagram.

        Args:
            *systems SystemBase:
                System(s) to add to the diagram.

        Returns:
            List[SystemBase] | SystemBase:
                The added system(s). Will return a single system if there is only
                a single system in the argument list (otherwise the tuple of
                systems as passed in).

        Raises:
            BuilderError: If the diagram has already been built.
            BuilderError: If the system is already registered.
            BuilderError: If the system name is not unique.
        """
        for system in systems:
            self._check_not_already_built()
            self._check_system_not_registered(system)
            self._check_system_name_is_unique(system)
            self._registered_systems.append(system)

            # Make sure the child system has a dependency graph so that we can
            # subscribe trackers to its ports and callbacks.  If the system already
            # has a dependency graph, this will do nothing.
            system.create_dependency_graph()

            # Add the system's input ports to the list of all input ports
            # So that we can make sure they're all connected before building.
            self._all_input_ports.extend([port.locator for port in system.input_ports])

            logger.debug(f"Added system {system.name} to DiagramBuilder")
            logger.debug(
                f"    Registered systems: {[s.name for s in self._registered_systems]}"
            )

        return systems[0] if len(systems) == 1 else systems

    def connect(self, src: OutputPort, dest: InputPort):
        """Connect an output port to an input port.

        The input port and output port must both belong to systems that have
        already been added to the diagram.  The input port must not already be
        connected to another output port.

        Args:
            src (OutputPort): The output port to connect.
            dest (InputPort): The input port to connect.

        Raises:
            BuilderError: If the diagram has already been built.
            BuilderError: If the source system is not registered.
            BuilderError: If the destination system is not registered.
            BuilderError: If the input port is already connected.
        """
        self._check_not_already_built()
        self._check_system_is_registered(src.system)
        self._check_system_is_registered(dest.system)
        self._check_input_not_connected(dest.locator)

        self._connection_map[dest.locator] = src.locator

        logger.debug(
            f"Connected port {src.name} of system {src.system.name} to port {dest.name} of system {dest.system.name}"
        )
        logger.debug(f"Connection map so far: {self._connection_map}")

    def export_input(self, port: InputPort, name: str = None) -> int:
        """Export an input port of a child system as a diagram-level input.

        The input port must belong to a system that has already been added to the
        diagram. The input port must not already be connected to another output port.

        Args:
            port (InputPort): The input port to export.
            name (str, optional):
                The name to assign to the exported input port. If not provided,
                the name "<system name>_<port name>" is used.

        Returns:
            int: The index (in the to-be-built diagram) of the exported input port.

        Raises:
            BuilderError: If the diagram has already been built.
            BuilderError: If the system is not registered.
            BuilderError: If the input port is already connected.
            BuilderError: If the input port name (given or generated) is not unique.
        """
        self._check_not_already_built()
        self._check_system_is_registered(port.system)
        self._check_input_not_connected(port.locator)

        if name is None:
            # System names are unique within this builder, but an explicit name
            # chosen for an earlier export may still coincide with this generated
            # one, so generated names go through the uniqueness check below too.
            name = f"{port.system.name}_{port.name}"

        if name in self._diagram_input_indices:
            raise BuilderError(
                f"Input port name {name} is not unique",
                system=port.system,
                port_index=port.index,
                port_direction="in",
            )

        # Index at the diagram (not subsystem) level
        port_index = len(self._input_port_ids)
        self._input_port_ids.append(port.locator)
        self._input_port_names.append(name)

        self._diagram_input_indices[name] = port_index

        return port_index

    def export_output(self, port: OutputPort, name: str = None) -> int:
        """Export an output port of a child system as a diagram-level output.

        The output port must belong to a system that has already been added to the
        diagram.

        Args:
            port (OutputPort): The output port to export.
            name (str, optional):
                The name to assign to the exported output port. If not provided,
                the name "<system name>_<port name>" is used.

        Returns:
            int: The index (in the to-be-built diagram) of the exported output port.

        Raises:
            BuilderError: If the diagram has already been built.
            BuilderError: If the system is not registered.
            BuilderError: If the output port name (given or generated) is not unique.
        """
        self._check_not_already_built()
        self._check_system_is_registered(port.system)

        if name is None:
            # A generated name can still collide with an explicit name used for an
            # earlier export (or with the same port exported twice), so it is
            # validated like any other name below.
            name = f"{port.system.name}_{port.name}"

        if name in self._output_port_names:
            raise BuilderError(
                f"Output port name {name} is not unique",
                system=port.system,
                port_index=port.index,
                port_direction="out",
            )

        # Index at the diagram (not subsystem) level
        port_index = len(self._output_port_ids)
        self._output_port_ids.append(port.locator)
        self._output_port_names.append(name)

        return port_index

    def _check_not_already_built(self):
        # Raise if `build` has already completed on this builder.
        if self._already_built:
            raise BuilderError(
                "DiagramBuilder: build has already been called to "
                "create a diagram; this DiagramBuilder may no longer be used."
            )

    def _check_system_name_is_unique(self, system: SystemBase):
        # System names must be unique among the systems registered here.
        if system.name in map(lambda s: s.name, self._registered_systems):
            raise SystemNameNotUniqueError(system)

    def _system_is_registered(self, system: SystemBase) -> bool:
        # Registration is tracked by system_id rather than object identity.
        if system.system_id is None:  # system.__init__ is not done yet
            return False
        return system.system_id in map(lambda s: s.system_id, self._registered_systems)

    def _check_system_not_registered(self, system: SystemBase):
        if self._system_is_registered(system):
            raise BuilderError(
                f"System {system.name} is already registered",
                system=system,
            )

    def _check_system_is_registered(self, system: SystemBase):
        if not self._system_is_registered(system):
            raise BuilderError(
                f"System {system.name} is not registered",
                system=system,
            )

    def _check_input_not_connected(self, input_port_locator: InputPortLocator):
        # An input is "connected" if it is exported or driven by an internal connection.
        if not (
            (input_port_locator not in self._input_port_ids)
            and (input_port_locator not in self._connection_map)
        ):
            system, port_index = input_port_locator
            raise BuilderError(
                f"Input port {port_index} for {system} is already connected",
                system=system,
                port_index=port_index,
                port_direction="in",
            )

    def _check_input_is_connected(self, input_port_locator: InputPortLocator):
        # Counterpart of `_check_input_not_connected`: exported or internally driven.
        if not (
            (input_port_locator in self._input_port_ids)
            or (input_port_locator in self._connection_map)
        ):
            raise DisconnectedInputError(input_port_locator)

    def _check_no_algebraic_loops(self, name: str):
        """Check for algebraic loops in the diagram.

        This is a more or less direct port of the Drake method
        DiagramBuilder::ThrowIfAlgebraicLoopExists. Some comments are verbatim
        explanations of the algorithm implemented there.

        Raises:
            AlgebraicLoopError: If the port graph (internal connections plus
                direct-feedthrough edges) contains a cycle.
        """

        # The nodes in the graph are the input/output ports defined as part of
        # the diagram's internal connections.  Ports that are not internally
        # connected cannot participate in a cycle at this level, so we don't include them
        # in the nodes set.
        nodes: Set[PortBase] = set()

        # For each `value` in `edges[key]`, the `key` directly influences `value`.
        edges: Mapping[PortBase, Set[PortBase]] = {}

        # Add the diagram's internal connections to the digraph nodes and edges
        for input_port_locator, output_port_locator in self._connection_map.items():
            # Directly using the port locator does not result in a unique identifier
            # since (sys, 0) represents both input port 0 and output port 0.  Instead,
            # use the port directly as a key, since it is a unique hashable object.
            input_system, input_index = input_port_locator
            input_port = input_system.input_ports[input_index]
            logger.debug(f"Adding locator {input_port} to nodes")
            nodes.add(input_port)

            output_system, output_index = output_port_locator
            output_port = output_system.output_ports[output_index]
            logger.debug(f"Adding locator {output_port} to nodes")
            nodes.add(output_port)

            logger.debug(f"Adding edge[{output_port}] = {input_port}")
            edges.setdefault(output_port, set()).add(input_port)

        # Add more edges based on each System's direct feedthrough.
        # input -> output port iff there is direct feedthrough from input -> output
        # If a feedthrough edge refers to a port not in `nodes`, omit it because ports
        # that are not connected inside the diagram cannot participate in a cycle at
        # the level of this diagram (higher-level diagrams will test for cycles at
        # their level).
        for system in self._registered_systems:
            logger.debug(f"Checking feedthrough for system {system.name}")
            for input_index, output_index in system.get_feedthrough():
                input_port = system.input_ports[input_index]
                output_port = system.output_ports[output_index]
                logger.debug(f"Feedthrough from {input_port} to {output_port}")
                if input_port in nodes and output_port in nodes:
                    edges.setdefault(input_port, set()).add(output_port)

        def _graph_has_cycle(
            node: PortBase,
            visited: Set[DirectedPortLocator],
            stack: List[DirectedPortLocator],
        ) -> bool:
            # Helper to do the algebraic loop test by depth-first search on the graph
            # to find cycles. Modifies `visited` and `stack` in place. Both hold
            # directed locators (not ports), so membership tests must use
            # `directed_locator` as well.

            logger.debug(f"Checking node {node}")

            assert node.directed_locator not in visited
            visited.add(node.directed_locator)

            if node in edges:
                # `stack` is a subset of `visited`, so an unvisited node can never
                # already be on it.  (Comparing `node` itself against the locators
                # in `stack` would always pass and check nothing.)
                assert node.directed_locator not in stack
                stack.append(node.directed_locator)
                for target in edges[node]:
                    if target.directed_locator not in visited and _graph_has_cycle(
                        target, visited, stack
                    ):
                        logger.debug(f"Found cycle at {target}")
                        return True
                    elif target.directed_locator in stack:
                        # Back edge to a node on the current DFS path: a cycle.
                        logger.debug(f"Found target {target} in stack {stack}")
                        return True
                stack.pop()

            # If we get this far there is no cycle
            return False

        # Evaluate the graph for cycles
        visited: Set[DirectedPortLocator] = set()
        stack: List[DirectedPortLocator] = []
        for node in nodes:
            if node.directed_locator in visited:
                continue
            if _graph_has_cycle(node, visited, stack):
                # On failure `stack` still holds the DFS path leading into the loop.
                raise AlgebraicLoopError(name, stack)

    def _check_contents_are_complete(self):
        # Make sure all the systems referenced in the builder attributes are registered

        # Check that systems and registered_systems have the same elements
        for system in self._registered_systems:
            self._check_system_is_registered(system)

        # Check that connection_map only refers to registered systems
        for (
            input_port_locator,
            output_port_locator,
        ) in self._connection_map.items():
            self._check_system_is_registered(input_port_locator[0])
            self._check_system_is_registered(output_port_locator[0])

        # Check that input_port_ids and output_port_ids only refer to registered systems
        for port_locator in [*self._input_port_ids, *self._output_port_ids]:
            self._check_system_is_registered(port_locator[0])

    def _check_ports_are_valid(self):
        # Every connection endpoint must name an existing port index on its system.
        for dst, src in self._connection_map.items():
            dst_sys, dst_idx = dst
            if (dst_idx < 0) or (dst_idx >= dst_sys.num_input_ports):
                raise BuilderError(
                    f"Input port index {dst_idx} is out of range for system {dst_sys.name}",
                    system=dst_sys,
                )
            src_sys, src_idx = src
            if (src_idx < 0) or (src_idx >= src_sys.num_output_ports):
                raise BuilderError(
                    f"Output port index {src_idx} is out of range for system {src_sys.name}",
                    system=src_sys,
                )

    def build(
        self,
        name: str = "root",
        ui_id: str = None,
        parameters: dict[str, Parameter] = None
    ) -> Diagram:
        """Build a Diagram from the systems, connections and exports registered so far.

        After a successful build this DiagramBuilder can no longer be used.

        Args:
            name (str, optional): The name of the diagram. Defaults to "root".
            ui_id (str, optional):
                Identifier forwarded to the `Diagram` constructor as `ui_id`.
                Defaults to None.
            parameters (dict[str, Parameter], optional):
                Parameters to declare on the new diagram, keyed by parameter name.
                Defaults to None (no parameters are declared).

        Returns:
            Diagram: The newly constructed diagram.

        Raises:
            EmptyDiagramError: If no systems are registered in the diagram.
            BuilderError: If the diagram has already been built, a connection or
                export refers to an unregistered system, or a connection uses an
                out-of-range port index.
            DisconnectedInputError: If an input port of a child system is neither
                connected nor exported.
            AlgebraicLoopError: If an algebraic loop is detected in the diagram.
        """
        self._check_not_already_built()
        self._check_contents_are_complete()
        self._check_ports_are_valid()

        # Every child input port must be driven either by an internal connection or
        # by a diagram-level (exported) input. `_all_input_ports` is filled in `add`;
        # `_input_port_ids` only holds exported ports, which count as connected by
        # definition, so iterating it could never detect a dangling input.
        for input_port_locator in self._all_input_ports:
            self._check_input_is_connected(input_port_locator)

        logger.debug(f"DiagramBuilder: checking for algebraic loops in {name}")
        self._check_no_algebraic_loops(name)
        logger.debug(f"DiagramBuilder: no algebraic loops detected in {name}")

        if len(self._registered_systems) == 0:
            raise EmptyDiagramError(name)

        diagram = Diagram(
            nodes=self._registered_systems,
            name=name,
            connection_map=self._connection_map,
            ui_id=ui_id,
        )

        if parameters:
            # Use a separate loop variable so the diagram `name` is not rebound.
            for param_name, parameter in parameters.items():
                diagram.declare_parameter(param_name, parameter)

        # Export diagram-level inputs
        for locator, port_name in zip(self._input_port_ids, self._input_port_names):
            diagram.export_input(locator, port_name)

        # Export diagram-level outputs
        assert len(self._output_port_ids) == len(self._output_port_names)
        for locator, port_name in zip(self._output_port_ids, self._output_port_names):
            diagram.export_output(locator, port_name)

        # Create the dependency graph for the diagram
        diagram.create_dependency_graph()

        self._already_built = True  # Prevent further use of this builder
        return diagram

add(*systems)

Add one or more systems to the diagram.

Parameters:

Name Type Description Default
*systems SystemBase

System(s) to add to the diagram.

()

Returns:

Type Description
List[SystemBase] | SystemBase

List[SystemBase] | SystemBase: The added system(s). Will return a single system if there is only a single system in the argument list.

Raises:

Type Description
BuilderError

If the diagram has already been built.

BuilderError

If the system is already registered.

BuilderError

If the system name is not unique.

Source code in collimator/framework/diagram_builder.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def add(self, *systems: SystemBase) -> List[SystemBase] | SystemBase:
    """Register one or more child systems with this builder.

    Each system is validated and registered in turn, so systems earlier in
    the argument list stay registered if a later one fails validation.

    Args:
        *systems SystemBase:
            System(s) to add to the diagram.

    Returns:
        List[SystemBase] | SystemBase:
            The single system when exactly one was passed; otherwise the
            systems exactly as passed in.

    Raises:
        BuilderError: If the diagram has already been built.
        BuilderError: If the system is already registered.
        BuilderError: If the system name is not unique.
    """
    for child in systems:
        self._check_not_already_built()
        self._check_system_not_registered(child)
        self._check_system_name_is_unique(child)
        self._registered_systems.append(child)

        # Trackers are later subscribed to the child's ports and callbacks, so it
        # needs a dependency graph; this call does nothing if one already exists.
        child.create_dependency_graph()

        # Remember every input port so connectivity can be checked before building.
        self._all_input_ports.extend(in_port.locator for in_port in child.input_ports)

        logger.debug(f"Added system {child.name} to DiagramBuilder")
        registered_names = [s.name for s in self._registered_systems]
        logger.debug(f"    Registered systems: {registered_names}")

    if len(systems) == 1:
        return systems[0]
    return systems

build(name='root', ui_id=None, parameters=None)

Builds a Diagram system with the specified name, UI ID, and parameters.

Parameters:

Name Type Description Default
name str

The name of the diagram. Defaults to "root".

'root'
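ui_id str

Identifier forwarded to the Diagram constructor. Defaults to None.

None
parameters dict[str, Parameter]

Parameters to declare on the new diagram, keyed by parameter name. Defaults to None (no parameters are declared).

None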

Returns:

Name Type Description
Diagram Diagram

The newly constructed diagram.

Raises:

Type Description
EmptyDiagramError

If no systems are registered in the diagram.

BuilderError

If the diagram has already been built.

AlgebraicLoopError

If an algebraic loop is detected in the diagram.

DisconnectedInputError

If an input port is not connected.

Source code in collimator/framework/diagram_builder.py
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
def build(
    self,
    name: str = "root",
    ui_id: str = None,
    parameters: dict[str, Parameter] = None
) -> Diagram:
    """Build a Diagram from the systems, connections and exports registered so far.

    After a successful build this DiagramBuilder can no longer be used.

    Args:
        name (str, optional): The name of the diagram. Defaults to "root".
        ui_id (str, optional):
            Identifier forwarded to the `Diagram` constructor as `ui_id`.
            Defaults to None.
        parameters (dict[str, Parameter], optional):
            Parameters to declare on the new diagram, keyed by parameter name.
            Defaults to None (no parameters are declared).

    Returns:
        Diagram: The newly constructed diagram.

    Raises:
        EmptyDiagramError: If no systems are registered in the diagram.
        BuilderError: If the diagram has already been built, a connection or
            export refers to an unregistered system, or a connection uses an
            out-of-range port index.
        DisconnectedInputError: If an input port of a child system is neither
            connected nor exported.
        AlgebraicLoopError: If an algebraic loop is detected in the diagram.
    """
    self._check_not_already_built()
    self._check_contents_are_complete()
    self._check_ports_are_valid()

    # Every child input port must be driven either by an internal connection or
    # by a diagram-level (exported) input. `_all_input_ports` is filled in `add`;
    # `_input_port_ids` only holds exported ports, which count as connected by
    # definition, so iterating it could never detect a dangling input.
    for input_port_locator in self._all_input_ports:
        self._check_input_is_connected(input_port_locator)

    logger.debug(f"DiagramBuilder: checking for algebraic loops in {name}")
    self._check_no_algebraic_loops(name)
    logger.debug(f"DiagramBuilder: no algebraic loops detected in {name}")

    if len(self._registered_systems) == 0:
        raise EmptyDiagramError(name)

    diagram = Diagram(
        nodes=self._registered_systems,
        name=name,
        connection_map=self._connection_map,
        ui_id=ui_id,
    )

    if parameters:
        # Use a separate loop variable so the diagram `name` is not rebound.
        for param_name, parameter in parameters.items():
            diagram.declare_parameter(param_name, parameter)

    # Export diagram-level inputs
    for locator, port_name in zip(self._input_port_ids, self._input_port_names):
        diagram.export_input(locator, port_name)

    # Export diagram-level outputs
    assert len(self._output_port_ids) == len(self._output_port_names)
    for locator, port_name in zip(self._output_port_ids, self._output_port_names):
        diagram.export_output(locator, port_name)

    # Create the dependency graph for the diagram
    diagram.create_dependency_graph()

    self._already_built = True  # Prevent further use of this builder
    return diagram

connect(src, dest)

Connect an output port to an input port.

The input port and output port must both belong to systems that have already been added to the diagram. The input port must not already be connected to another output port.

Parameters:

Name Type Description Default
src OutputPort

The output port to connect.

required
dest InputPort

The input port to connect.

required

Raises:

Type Description
BuilderError

If the diagram has already been built.

BuilderError

If the source system is not registered.

BuilderError

If the destination system is not registered.

BuilderError

If the input port is already connected.

Source code in collimator/framework/diagram_builder.py
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def connect(self, src: OutputPort, dest: InputPort):
    """Wire an output port of one child system into an input port of another.

    Both ports must belong to systems already registered via `add`, and the
    destination input must not already be driven (by a connection or an export).

    Args:
        src (OutputPort): The output port that drives the connection.
        dest (InputPort): The input port that receives the value.

    Raises:
        BuilderError: If the diagram has already been built.
        BuilderError: If the source system is not registered.
        BuilderError: If the destination system is not registered.
        BuilderError: If the input port is already connected.
    """
    self._check_not_already_built()
    for owner in (src.system, dest.system):
        self._check_system_is_registered(owner)
    dest_locator = dest.locator
    self._check_input_not_connected(dest_locator)

    # The map is keyed by destination, so each input can have at most one source.
    self._connection_map[dest_locator] = src.locator

    logger.debug(
        f"Connected port {src.name} of system {src.system.name} to port {dest.name} of system {dest.system.name}"
    )
    logger.debug(f"Connection map so far: {self._connection_map}")

export_input(port, name=None)

Export an input port of a child system as a diagram-level input.

The input port must belong to a system that has already been added to the diagram. The input port must not already be connected to another output port.

Parameters:

Name Type Description Default
port InputPort

The input port to export.

required
name str

The name to assign to the exported input port. If not provided, a unique name will be generated.

None

Returns:

Name Type Description
int int

The index (in the to-be-built diagram) of the exported input port.

Raises:

Type Description
BuilderError

If the diagram has already been built.

BuilderError

If the system is not registered.

BuilderError

If the input port is already connected.

BuilderError

If the input port name is not unique.

Source code in collimator/framework/diagram_builder.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
def export_input(self, port: InputPort, name: str = None) -> int:
    """Export an input port of a child system as a diagram-level input.

    The input port must belong to a system that has already been added to the
    diagram. The input port must not already be connected to another output port.

    Args:
        port (InputPort): The input port to export.
        name (str, optional):
            The name to assign to the exported input port. If not provided,
            the name "<system name>_<port name>" is used.

    Returns:
        int: The index (in the to-be-built diagram) of the exported input port.

    Raises:
        BuilderError: If the diagram has already been built.
        BuilderError: If the system is not registered.
        BuilderError: If the input port is already connected.
        BuilderError: If the input port name (given or generated) is not unique.
    """
    self._check_not_already_built()
    self._check_system_is_registered(port.system)
    self._check_input_not_connected(port.locator)

    if name is None:
        # System names are unique within this builder, but an explicit name
        # chosen for an earlier export may still coincide with this generated
        # one, so generated names go through the uniqueness check below too.
        name = f"{port.system.name}_{port.name}"

    if name in self._diagram_input_indices:
        raise BuilderError(
            f"Input port name {name} is not unique",
            system=port.system,
            port_index=port.index,
            port_direction="in",
        )

    # Index at the diagram (not subsystem) level
    port_index = len(self._input_port_ids)
    self._input_port_ids.append(port.locator)
    self._input_port_names.append(name)

    # Name -> diagram-level input index lookup
    self._diagram_input_indices[name] = port_index

    return port_index

export_output(port, name=None)

Export an output port of a child system as a diagram-level output.

The output port must belong to a system that has already been added to the diagram.

Parameters:

Name Type Description Default
port OutputPort

The output port to export.

required
name str

The name to assign to the exported output port. If not provided, a unique name will be generated.

None

Returns:

Name Type Description
int int

The index (in the to-be-built diagram) of the exported output port.

Raises:

Type Description
BuilderError

If the diagram has already been built.

BuilderError

If the system is not registered.

BuilderError

If the output port name is not unique.

Source code in collimator/framework/diagram_builder.py
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def export_output(self, port: OutputPort, name: str = None) -> int:
    """Export an output port of a child system as a diagram-level output.

    The output port must belong to a system that has already been added to the
    diagram.

    Args:
        port (OutputPort): The output port to export.
        name (str, optional):
            The name to assign to the exported output port. If not provided,
            the name "<system name>_<port name>" is used.

    Returns:
        int: The index (in the to-be-built diagram) of the exported output port.

    Raises:
        BuilderError: If the diagram has already been built.
        BuilderError: If the system is not registered.
        BuilderError: If the output port name (given or generated) is not unique.
    """
    self._check_not_already_built()
    self._check_system_is_registered(port.system)

    if name is None:
        # A generated name can still collide with an explicit name used for an
        # earlier export (or with the same port exported twice), so it is
        # validated like any other name below.
        name = f"{port.system.name}_{port.name}"

    if name in self._output_port_names:
        raise BuilderError(
            f"Output port name {name} is not unique",
            system=port.system,
            port_index=port.index,
            port_direction="out",
        )

    # Index at the diagram (not subsystem) level
    port_index = len(self._output_port_ids)
    self._output_port_ids.append(port.locator)
    self._output_port_names.append(name)

    return port_index

DiagramContext dataclass

Bases: ContextBase

Source code in collimator/framework/context.py
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
@dataclasses.dataclass(frozen=True)
class DiagramContext(ContextBase):
    """Context of the root diagram, holding one subcontext per leaf system.

    Intermediate diagrams have no context of their own; only the root diagram and
    the leaf systems do. Aggregate accessors (continuous/discrete state, mode)
    return lists ordered like ``subcontexts``, restricted to the leaves that carry
    that kind of state.
    """

    # Leaf contexts keyed by leaf system ID, in a fixed order that the aggregate
    # state lists below rely on.
    subcontexts: OrderedDict[Hashable, LeafContext] = dataclasses.field(
        default_factory=OrderedDict
    )

    def _check_key(self, key: Hashable) -> None:
        """Assert that ``key`` is this diagram's ID or the ID of a leaf subcontext."""
        assert key == self.owning_system.system_id or key in self.subcontexts, (
            f"System ID {key} not found in DiagramContext {self}.\nIf this ID "
            "references an intermediate diagram, note that intermediate diagrams do "
            "not have associated contexts. Only the root diagram and leaf systems have "
            "contexts."
        )

    def __getitem__(self, key: Hashable) -> ContextBase:
        """Return the subcontext for ``key``, or this context for the diagram's own ID."""
        self._check_key(key)
        if key == self.owning_system.system_id:
            return self
        return self.subcontexts[key]

    def find_context_with_path(self, path: list[str]) -> ContextBase:
        """Look up the context of the system at ``path``.

        Raises:
            ValueError: If no system exists at ``path``.
        """
        system = self.owning_system.find_system_with_path(path)
        if system is None:
            raise ValueError(
                f"No system with path {path} found in {self.owning_system}"
            )
        return self[system.system_id]

    def with_subcontext(self, key: Hashable, ctx: LeafContext) -> DiagramContext:
        """Return a copy with the subcontext stored under ``key`` replaced by ``ctx``."""
        self._check_key(key)
        subcontexts = self.subcontexts.copy()
        subcontexts[key] = ctx
        return dataclasses.replace(self, subcontexts=subcontexts)

    #
    # Simulation interface
    #
    @property
    def state(self) -> Mapping[Hashable, LeafState]:
        """Leaf states keyed by system ID, in subcontext order."""
        return OrderedDict(
            {system_id: subctx.state for system_id, subctx in self.subcontexts.items()}
        )

    @property
    def continuous_subcontexts(self) -> List[LeafContext]:
        """Subcontexts whose leaf has continuous state, in subcontext order."""
        return [
            subctx
            for subctx in self.subcontexts.values()
            if subctx.has_continuous_state
        ]

    @property
    def continuous_state(self) -> List[Array]:
        """Continuous states of ``continuous_subcontexts``, in the same order."""
        return [subctx.continuous_state for subctx in self.continuous_subcontexts]

    def with_continuous_state(self, sub_xcs: List[Array]) -> DiagramContext:
        """Return a copy with continuous states replaced, matched positionally to
        ``continuous_subcontexts``."""
        # Shallow copy the subcontexts - only modify the ones that have continuous states
        new_subcontexts = self.subcontexts.copy()
        for subctx, sub_xc in zip(self.continuous_subcontexts, sub_xcs):
            new_subcontexts[subctx.system_id] = subctx.with_continuous_state(sub_xc)
        return dataclasses.replace(self, subcontexts=new_subcontexts)

    @property
    def num_continuous_states(self) -> int:
        """Total of ``num_continuous_states`` over all subcontexts."""
        return sum(subctx.num_continuous_states for subctx in self.subcontexts.values())

    @property
    def has_continuous_state(self) -> bool:
        return self.num_continuous_states > 0

    @property
    def discrete_subcontexts(self) -> List[LeafContext]:
        """Subcontexts whose leaf has discrete state, in subcontext order."""
        return [
            subctx for subctx in self.subcontexts.values() if subctx.has_discrete_state
        ]

    @property
    def discrete_state(self) -> List[List[Array]]:
        """Discrete states of ``discrete_subcontexts``, in the same order."""
        return [subctx.discrete_state for subctx in self.discrete_subcontexts]

    def with_discrete_state(self, sub_xds: List[List[Array]]) -> DiagramContext:
        """Return a copy with discrete states replaced, matched positionally to
        ``discrete_subcontexts``."""
        # Shallow copy the subcontexts - only modify the ones that have discrete states
        new_subcontexts = self.subcontexts.copy()
        for subctx, sub_xd in zip(self.discrete_subcontexts, sub_xds):
            new_subcontexts[subctx.system_id] = subctx.with_discrete_state(sub_xd)
        return dataclasses.replace(self, subcontexts=new_subcontexts)

    @property
    def num_discrete_states(self) -> int:
        """Total of ``num_discrete_states`` over all subcontexts."""
        return sum(subctx.num_discrete_states for subctx in self.subcontexts.values())

    @property
    def has_discrete_state(self) -> bool:
        return self.num_discrete_states > 0

    @property
    def mode_subcontexts(self) -> List[LeafContext]:
        """Subcontexts whose leaf has a mode, in subcontext order."""
        return [subctx for subctx in self.subcontexts.values() if subctx.has_mode]

    @property
    def mode(self) -> List[int]:
        """Modes of ``mode_subcontexts``, in the same order."""
        return [subctx.mode for subctx in self.mode_subcontexts]

    def with_mode(self, sub_modes: List[int]) -> DiagramContext:
        """Return a copy with modes replaced, matched positionally to
        ``mode_subcontexts``."""
        new_subcontexts = self.subcontexts.copy()
        for subctx, sub_mode in zip(self.mode_subcontexts, sub_modes):
            new_subcontexts[subctx.system_id] = subctx.with_mode(sub_mode)
        return dataclasses.replace(self, subcontexts=new_subcontexts)

    @property
    def has_mode(self) -> bool:
        return any(subctx.has_mode for subctx in self.subcontexts.values())

    def with_state(self, sub_states: Mapping[Hashable, LeafState]) -> DiagramContext:
        """Return a copy with the leaf states in ``sub_states`` replaced.

        Subcontexts whose IDs do not appear in ``sub_states`` are kept unchanged,
        and the original subcontext order is preserved.

        Raises:
            KeyError: If ``sub_states`` contains an ID with no subcontext.
        """
        # Start from the existing subcontexts so a partial mapping does not drop
        # the leaves it does not mention.
        new_subcontexts = self.subcontexts.copy()
        for system_id, sub_state in sub_states.items():
            new_subcontexts[system_id] = dataclasses.replace(
                self.subcontexts[system_id], state=sub_state
            )
        return dataclasses.replace(self, subcontexts=new_subcontexts)

    def _compute_params(self):
        """Refresh each leaf's parameter values from its system's default parameters.

        NOTE(review): values are written into each subcontext's existing
        ``parameters`` mapping in place; the returned dict is a shallow copy whose
        values are the same subcontext objects. Presumably intentional - confirm.
        """
        new_subcontexts = self.subcontexts.copy()
        for subctx in self.subcontexts.values():
            parameters = subctx.parameters
            for name, param in subctx.owning_system.default_parameters.items():
                parameters[name] = param.get()
        return new_subcontexts

    def with_parameter(self, name: str, value: ArrayLike) -> ContextBase:
        """Create a copy of this context, replacing the specified parameter."""
        self._replace_param(name, value)
        new_subcontexts = self._compute_params()
        return dataclasses.replace(
            self, parameters=self.parameters, subcontexts=new_subcontexts)

    def with_parameters(self, new_parameters: Mapping[str, ArrayLike]) -> ContextBase:
        """Create a copy of this context, replacing only the specified parameters."""
        for name, value in new_parameters.items():
            self._replace_param(name, value)
        new_subcontexts = self._compute_params()
        return dataclasses.replace(
            self, parameters=self.parameters, subcontexts=new_subcontexts)

with_parameter(name, value)

Create a copy of this context, replacing the specified parameter.

Source code in collimator/framework/context.py
443
444
445
446
447
448
def with_parameter(self, name: str, value: ArrayLike) -> ContextBase:
    """Return a copy of this context with parameter ``name`` set to ``value``.

    The replacement goes through ``_replace_param``; the leaf subcontexts'
    parameter values are then refreshed via ``_compute_params``.
    """
    self._replace_param(name, value)
    refreshed_subcontexts = self._compute_params()
    current_parameters = self.parameters
    return dataclasses.replace(
        self,
        parameters=current_parameters,
        subcontexts=refreshed_subcontexts,
    )

with_parameters(new_parameters)

Create a copy of this context, replacing only the specified parameters.

Source code in collimator/framework/context.py
450
451
452
453
454
455
456
def with_parameters(self, new_parameters: Mapping[str, ArrayLike]) -> ContextBase:
    """Return a copy of this context with every parameter in ``new_parameters``
    replaced; parameters not listed keep their current values.

    Each replacement goes through ``_replace_param``, after which the leaf
    subcontexts' parameter values are refreshed once via ``_compute_params``.
    """
    for param_name, param_value in new_parameters.items():
        self._replace_param(param_name, param_value)
    refreshed_subcontexts = self._compute_params()
    current_parameters = self.parameters
    return dataclasses.replace(
        self,
        parameters=current_parameters,
        subcontexts=refreshed_subcontexts,
    )

DiscreteUpdateEvent dataclass

Bases: Event

Event representing a discrete update in a hybrid system.

Source code in collimator/framework/event.py
285
286
287
288
289
290
291
292
293
294
295
296
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class DiscreteUpdateEvent(Event):
    """Event that performs a discrete update in a hybrid system.

    Behaves exactly like ``Event``; the subclass exists only to narrow the type
    hints so that both callbacks take a context and produce an array.
    """

    # Narrowed signatures for the callback fields inherited from Event.
    callback: Callable[[ContextBase], Array] = None
    passthrough: Callable[[ContextBase], Array] = None

    def handle(self, context: ContextBase) -> Array:
        # Pure delegation to Event.handle; overridden only for the return-type hint.
        result = super().handle(context)
        return result

DtypeMismatchError

Bases: StaticError

Block parameters or input/outputs have mismatched dtypes.

Source code in collimator/framework/error.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
class DtypeMismatchError(StaticError):
    """Block parameters or input/outputs have mismatched dtypes.

    Args:
        expected_dtype: The dtype that was expected, or None if unknown.
        actual_dtype: The dtype that was actually found, or None if unknown.
        **kwargs: Forwarded unchanged to ``StaticError``.
    """

    def __init__(self, expected_dtype=None, actual_dtype=None, **kwargs):
        super().__init__(**kwargs)
        self.expected_dtype = expected_dtype
        self.actual_dtype = actual_dtype

    def __str__(self):
        # Compare against None instead of relying on truthiness. NumPy dtype objects
        # define __len__ (the number of structured fields), so a plain dtype such
        # as float32 can evaluate as False. That would wrongly hide the detailed message.
        if self.expected_dtype is not None or self.actual_dtype is not None:
            return (
                f"Data type mismatch: "
                f"expected {self.expected_dtype}, got {self.actual_dtype}"
                + self._context_info()
            )
        return f"Dtype mismatch{self._context_info()}"

ErrorCollector

Tool used to collect errors related to the user's model specification. Errors related to the user's model specification are identified during model static analysis, e.g. context creation, type checking, etc.

An instance of this tool can be created, and then passed down a tree of function calls to collect errors found anywhere in the tree. Locally in the tree it can be determined whether it is OK to continue or not. This tool enables collecting errors up until the point when continuation is no longer possible.

Note: this latter behavior, where sometimes there is early exit desired, and all other "pipeline" operations are "nullified", might better be implemented using pymonad:Either class.

Source code in collimator/framework/error.py
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
class ErrorCollector:
    """
    Tool used to collect errors related to the user's model specification.
    Errors related to user model specification are identified during
    model static analysis, e.g. context creation, type checking, etc.

    An instance of this tool can be created, and then passed down a
    tree of function calls to collect errors found anywhere in
    the tree. Locally in the tree it can be determined whether it is
    ok to continue or not. This tool enables collecting errors up until
    the point when continuation is no longer possible.

    Note: this latter behavior, where sometimes there is early exit desired,
    and all other "pipeline" operations are "nullified", might better be
    implemented using pymonad:Either class.
    """

    def __init__(self):
        # True for collectors created by ``context(None)``: errors are dropped.
        self._disable_collection = False
        # Collector that errors are forwarded to; None for a root collector.
        self._parent: Optional["ErrorCollector"] = None
        # Errors stored by this collector (only a root ever stores any).
        self.errors: list[BaseException] = []

    def add_error(self, error: BaseException):
        """Add an error to the collection.

        The error is forwarded to the parent collector when there is one;
        otherwise it is appended to ``errors`` unless collection is disabled.
        """

        if self._parent is not None:
            self._parent.add_error(error)
            return

        if not self._disable_collection:
            self.errors.append(error)

    def _root(self) -> "ErrorCollector":
        """Return the top-most collector reached by following ``_parent`` links."""
        node = self
        while node._parent is not None:
            node = node._parent
        return node

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Record an exception raised inside the ``with`` block.

        Returns True (suppress the exception) only when it was handed to a parent
        chain whose root actually stores errors. Otherwise the exception propagates:
        - a root collector records it (if enabled) and re-raises it;
        - a child of a collection-disabled root must not swallow an error that
          nobody keeps.
        """
        if exc_type is None:
            return True

        self.add_error(exc_value)
        if self._parent is None:
            return False
        return not self._root()._disable_collection

    @classmethod
    def context(cls, parent: "ErrorCollector" = None) -> "ErrorCollector":
        """A context manager convenience to use when tracing errors.

        Use as:
        ```
        with ErrorCollector.context(parent) as ec:
            ...
        ```

        If the parent context is None, then exceptions will pass through without
        being collected. Else, exceptions will be collected in the parent context.
        """

        if parent is None:
            ctx = cls()
            ctx._disable_collection = True
            return ctx

        ctx = cls()
        ctx._parent = parent
        return ctx

add_error(error)

Add an error to the collection.

Source code in collimator/framework/error.py
313
314
315
316
317
318
319
320
321
def add_error(self, error: BaseException):
    """Record ``error``.

    The error is forwarded to the parent collector when one exists. Otherwise it
    is appended to ``errors``, unless collection is disabled, in which case it is
    dropped.
    """
    parent = self._parent
    if parent is not None:
        parent.add_error(error)
    elif not self._disable_collection:
        self.errors.append(error)

context(parent=None) classmethod

A context manager convenience to use when tracing errors.

Use as:

with ErrorCollector.context(parent) as ec:
    ...

If the parent context is None, then exceptions will pass through without being collected. Else, exceptions will be collected in the parent context.

Source code in collimator/framework/error.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
@classmethod
def context(cls, parent: "ErrorCollector" = None) -> "ErrorCollector":
    """A context manager convenience to use when tracing errors.

    Use as:
    ```
    with ErrorCollector.context(parent) as ec:
        ...
    ```

    If the parent context is None, then exceptions will pass through without
    being collected. Else, exceptions will be collected in the parent context.

    Args:
        parent: Collector that receives errors raised inside the block, or None.

    Returns:
        A new collector. It has collection disabled when ``parent`` is None, and
        otherwise forwards its errors to ``parent``.
    """
    ctx = cls()
    if parent is None:
        ctx._disable_collection = True
    else:
        ctx._parent = parent
    return ctx

EventCollection

A collection of events owned by a system.

Users should not need to interact with these objects directly. They are intended to be used internally by the simulation framework for handling events in hybrid system simulation.

These contain callback functions that update the context in various ways when the event is triggered. There will be different "collections" for each trigger type in simulation (e.g. periodic vs zero-crossing). Within the collections, events are broken out by function (e.g. discrete vs unrestricted updates).

There are separate implementations for leaf and diagram systems, where the DiagramEventCollection preserves the tree structure of the underlying Diagram. However, the interface in both cases is the same and is identical to the interface defined by EventCollection.

Source code in collimator/framework/event.py
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
class EventCollection(metaclass=abc.ABCMeta):
    """Abstract collection of the events owned by a system.

    These objects are internal to the simulation framework, which uses them to
    handle events in hybrid-system simulation; users normally never touch them.

    Each event carries callbacks that update the context when the event fires.
    Separate collections exist per trigger type (e.g. periodic vs. zero-crossing),
    and within a collection events are grouped by function (e.g. discrete vs.
    unrestricted updates).

    Leaf and diagram systems have their own implementations; the diagram-level one
    mirrors the tree structure of the underlying Diagram. Both expose exactly the
    interface defined here.
    """

    # ---- interface implemented by subclasses ---------------------------------

    @abc.abstractmethod
    def __getitem__(self, key: Hashable) -> EventCollection:
        pass

    @property
    @abc.abstractmethod
    def events(self) -> List[Event]:
        pass

    @property
    @abc.abstractmethod
    def num_events(self) -> int:
        pass

    @abc.abstractmethod
    def activate(self, activation_fn) -> EventCollection:
        pass

    @property
    @abc.abstractmethod
    def terminal_events(self) -> EventCollection:
        pass

    # ---- size / iteration ----------------------------------------------------

    @property
    def has_events(self) -> bool:
        return self.num_events > 0

    def __iter__(self):
        return iter(self.events)

    def __len__(self):
        return self.num_events

    # ---- activation ----------------------------------------------------------

    def mark_all_active(self) -> EventCollection:
        """Activate every event in the collection."""
        return self.activate(lambda _event: True)

    def mark_all_inactive(self) -> EventCollection:
        """Deactivate every event in the collection."""
        return self.activate(lambda _event: False)

    @property
    def num_active(self) -> int:
        """Count of events whose ``active`` flag is set (summed over the pytree)."""
        active_flags = tree_util.tree_map(
            lambda event_data: event_data.active,
            self,
            is_leaf=is_event_data,
        )
        return sum(tree_util.tree_leaves(active_flags))

    @property
    def has_active(self) -> bool:
        return self.num_active > 0

    @property
    def has_triggered(self) -> bool:
        """Whether any event is both active and triggered."""
        fired_flags = tree_util.tree_map(
            lambda event_data: event_data.active & event_data.triggered,
            self,
            is_leaf=is_event_data,
        )
        return sum(tree_util.tree_leaves(fired_flags)) > 0

    # ---- terminal events -----------------------------------------------------

    @property
    def has_terminal_events(self):
        return self.terminal_events.has_events

    @property
    def has_active_terminal(self) -> bool:
        return self.terminal_events.has_triggered

    # ---- display -------------------------------------------------------------

    def pprint(self, output=print):
        text = self._pprint_helper()
        output(text.strip())

    def _pprint_helper(self, prefix="") -> str:
        rows = [f"{prefix}|-- "]
        if len(self.events) > 0:
            rows.append(f"{prefix}    Events:")
            rows.extend(f"{prefix}    |  {event}" for event in self.events)
        return "".join(row + "\n" for row in rows)

    def __repr__(self) -> str:
        body = f"discrete_update: {self.events} " if self.has_events else ""
        return f"{type(self).__name__}({body})"

IntegerTime

Class for managing conversion between decimal and integer time.

Source code in collimator/framework/event.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class IntegerTime:
    """Class for managing conversion between decimal and integer time.

    Integer time is decimal time multiplied by ``inv_time_scale`` (i.e. divided by
    ``time_scale``) and stored in ``dtype``. All configuration lives in class
    attributes, so ``set_scale`` / ``set_default_scale`` change the conversion for
    every user of this class.
    """

    # TODO: Can we use this directly as an int?  Would need to implement __add__,
    # __sub__, etc.  Also, comparisons and floor divide.  Would make the code in
    # Simulator cleaner, but dealing with JAX tracers in `where` and the like
    # might be difficult.  See commit 043c8f757 for a previous attempt.

    #
    # Class variables
    #
    time_scale = DEFAULT_TIME_SCALE  # int -> float conversion factor
    inv_time_scale = 1 / time_scale  # float -> int conversion factor

    # Type of the integer time representation. Defaults to x64 unless explicitly disabled.
    dtype: DTypeLike = cnp.intx

    # Largest time value representable by IntegerTime.dtype
    max_int_time = cnp.iinfo(dtype).max

    # Floating point representation of max_int_time
    # NOTE(review): despite the comment above, the value is cast to the *integer*
    # ``dtype`` (``dtype=dtype``), so any fractional part is dropped. This may be
    # deliberate: a slightly smaller clamp keeps ``from_decimal`` below the overflow
    # point. Confirm before switching it to a float dtype.
    max_float_time = cnp.asarray(max_int_time * time_scale, dtype=dtype)

    #
    # Class methods
    #
    @classmethod
    def set_scale(cls, time_scale: float):
        """Set the class-wide time scale and recompute the scale-derived values.

        ``max_int_time`` depends only on ``dtype``, so it is left unchanged.
        """
        cls.time_scale = time_scale
        cls.inv_time_scale = 1 / time_scale
        cls.max_float_time = cnp.asarray(cls.max_int_time * time_scale, dtype=cls.dtype)

    @classmethod
    def set_default_scale(cls):
        """Restore the time scale to ``DEFAULT_TIME_SCALE``."""
        cls.set_scale(DEFAULT_TIME_SCALE)

    @classmethod
    def from_decimal(cls, time: float) -> int:
        """Convert a floating-point time to an integer time."""
        # First limit to the max value to avoid overflow with inf or very large values.
        time = cnp.minimum(time, cls.max_float_time)
        # The cast to the integer dtype discards the fractional part of the scaled
        # value (presumably truncation toward zero; confirm for the cnp backend).
        return cnp.asarray(time * cls.inv_time_scale, dtype=cls.dtype)

    @classmethod
    def as_decimal(cls, time: int) -> float:
        """Convert an integer time to a floating-point time."""
        return time * cls.time_scale

as_decimal(time) classmethod

Convert an integer time to a floating-point time.

Source code in collimator/framework/event.py
89
90
91
92
@classmethod
def as_decimal(cls, time: int) -> float:
    """Map an integer time back to decimal time by applying ``cls.time_scale``."""
    scale = cls.time_scale
    return time * scale

from_decimal(time) classmethod

Convert a floating-point time to an integer time.

Source code in collimator/framework/event.py
82
83
84
85
86
87
@classmethod
def from_decimal(cls, time: float) -> int:
    """Map a decimal time to integer time, clamping to ``cls.max_float_time`` first
    so that inf or huge inputs cannot overflow the integer dtype."""
    clamped = cnp.minimum(time, cls.max_float_time)
    scaled = clamped * cls.inv_time_scale
    return cnp.asarray(scaled, dtype=cls.dtype)

LeafContext dataclass

Bases: ContextBase

Source code in collimator/framework/context.py
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
@dataclasses.dataclass(frozen=True)
class LeafContext(ContextBase):
    """Context of a single leaf system, wrapping its ``LeafState``.

    Mirrors the indexing and replacement interface of ``DiagramContext`` so that
    leaf and diagram contexts can be used interchangeably. The state accessors
    delegate to ``state``; the state-related ``with_*`` methods return copies built
    with ``dataclasses.replace``.
    """

    state: LeafState = None

    @property
    def system_id(self) -> Hashable:
        """ID of the owning leaf system."""
        return self.owning_system.system_id

    def __getitem__(self, key: Hashable) -> LeafContext:
        """Dummy indexing for compatibility with DiagramContexts, returning self."""
        assert key == self.system_id, f"Attempting to get subcontext {key} from {self}"
        return self

    def with_subcontext(self, key: Hashable, ctx: LeafContext) -> LeafContext:
        """Dummy replacement for compatibility with DiagramContexts, returning ctx."""
        assert (
            key == self.system_id
        ), f"System ID {key} does not match leaf ID {self.system_id}"
        assert (
            key == ctx.system_id
        ), f"System ID {key} does not match leaf ID {ctx.system_id}"
        return ctx

    def __repr__(self) -> str:
        return f"{type(self).__name__}(sys={self.system_id})"

    def with_state(self, state: LeafState) -> LeafContext:
        """Return a copy with the whole leaf state replaced."""
        return dataclasses.replace(self, state=state)

    @property
    def continuous_state(self) -> LeafStateComponent:
        return self.state.continuous_state

    def with_continuous_state(self, value: LeafStateComponent) -> LeafContext:
        """Return a copy with the continuous state replaced."""
        return dataclasses.replace(self, state=self.state.with_continuous_state(value))

    @property
    def num_continuous_states(self) -> int:
        return self.state.num_continuous_states

    @property
    def has_continuous_state(self) -> bool:
        return self.state.has_continuous_state

    @property
    def discrete_state(self) -> LeafStateComponent:
        return self.state.discrete_state

    def with_discrete_state(self, value: LeafStateComponent) -> LeafContext:
        """Return a copy with the discrete state replaced."""
        return dataclasses.replace(self, state=self.state.with_discrete_state(value))

    @property
    def num_discrete_states(self) -> int:
        return self.state.num_discrete_states

    @property
    def has_discrete_state(self) -> bool:
        return self.state.has_discrete_state

    @property
    def mode(self) -> int:
        return self.state.mode

    @property
    def has_mode(self) -> bool:
        return self.state.has_mode

    def with_mode(self, value: int) -> LeafContext:
        """Return a copy with the mode replaced."""
        return dataclasses.replace(self, state=self.state.with_mode(value))

    @property
    def cache(self) -> tuple[Array]:
        return self.state.cache

    @property
    def num_cached_values(self) -> int:
        # LeafState.num_cached_values is a plain method, not a property, so it must
        # be called; returning the attribute itself yielded a bound method.
        return self.state.num_cached_values()

    @property
    def has_cache(self) -> bool:
        # LeafState.has_cache is a plain method; without the call this property
        # returned a bound method, which is always truthy.
        return self.state.has_cache()

    def with_cached_value(self, index: int, value: Array) -> LeafContext:
        """Return a copy with cache entry ``index`` replaced by ``value``."""
        return dataclasses.replace(
            self, state=self.state.with_cached_value(index, value)
        )

    def with_parameter(self, name: str, value: ArrayLike) -> ContextBase:
        """Create a copy of this context, replacing the specified parameter."""
        self._replace_param(name, value)
        return dataclasses.replace(self, parameters=self.parameters)

    def with_parameters(self, new_parameters: Mapping[str, ArrayLike]) -> ContextBase:
        """Create a copy of this context, replacing only the specified parameters."""
        for name, value in new_parameters.items():
            self._replace_param(name, value)
        return dataclasses.replace(self, parameters=self.parameters)

__getitem__(key)

Dummy indexing for compatibility with DiagramContexts, returning self.

Source code in collimator/framework/context.py
221
222
223
224
def __getitem__(self, key: Hashable) -> LeafContext:
    """Indexing shim so a leaf context can stand in for a DiagramContext.

    The only valid key is this leaf's own system ID, and the result is the leaf
    context itself.
    """
    # A leaf has no subcontexts, so any other key is a caller error.
    assert key == self.system_id, f"Attempting to get subcontext {key} from {self}"
    return self

with_parameter(name, value)

Create a copy of this context, replacing the specified parameter.

Source code in collimator/framework/context.py
300
301
302
303
def with_parameter(self, name: str, value: ArrayLike) -> ContextBase:
    """Return a copy of this leaf context with parameter ``name`` set to ``value``
    (the replacement itself is done by ``_replace_param``)."""
    self._replace_param(name, value)
    current_parameters = self.parameters
    return dataclasses.replace(self, parameters=current_parameters)

with_parameters(new_parameters)

Create a copy of this context, replacing only the specified parameters.

Source code in collimator/framework/context.py
305
306
307
308
309
def with_parameters(self, new_parameters: Mapping[str, ArrayLike]) -> ContextBase:
    """Return a copy of this leaf context with each parameter in ``new_parameters``
    replaced via ``_replace_param``; unlisted parameters keep their values."""
    for param_name, param_value in new_parameters.items():
        self._replace_param(param_name, param_value)
    current_parameters = self.parameters
    return dataclasses.replace(self, parameters=current_parameters)

with_subcontext(key, ctx)

Dummy replacement for compatibility with DiagramContexts, returning ctx.

Source code in collimator/framework/context.py
226
227
228
229
230
231
232
233
234
def with_subcontext(self, key: Hashable, ctx: LeafContext) -> LeafContext:
    """Replacement shim mirroring DiagramContext.with_subcontext.

    For a leaf, "replacing the subcontext" means returning ``ctx`` itself. Both
    this leaf and ``ctx`` must have ``key`` as their system ID.
    """
    assert key == self.system_id, f"System ID {key} does not match leaf ID {self.system_id}"
    assert key == ctx.system_id, f"System ID {key} does not match leaf ID {ctx.system_id}"
    return ctx

LeafState dataclass

Container for state information for a leaf system.

Attributes:

Name Type Description
name str

Name of the leaf system that owns this state.

continuous_state LeafStateComponent

Continuous state of the system, i.e. the component of state that evolves in continuous time. If the system has no continuous state, this will be None.

discrete_state LeafStateComponent

Discrete state of the system, i.e. one or more components of state that do not change continuously with time (not necessarily discrete-valued). If the system has no discrete state, this will be None.

mode int

An integer value indicating the current "mode", "stage", or discrete-valued state component of the system. Used for finite state machines or multi-stage hybrid systems. If the system has no mode, this will be None.

cache tuple[LeafStateComponent]

The current values of sample-and-hold outputs from the system. In a pure discrete system these would not be state components (just results of feedthrough computations), but in a hybrid or multirate system they act as discrete state from the perspective of continuous or asynchronous discrete components of the system. Hence, they are stored in the state, but are maintained separately from the normal internal state of the system.

Notes

(1) This class is immutable. To modify a LeafState, use the with_* methods.

(2) The type annotations for state components are LeafStateComponent, which is a union of array, tuple, and named tuple. The most common case is arrays, but this allows for more flexibility in defining state components, e.g. a second-order system can define a named tuple of generalized coordinates and velocities rather than concatenating into a single array.

Source code in collimator/framework/state.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
@dataclasses.dataclass(frozen=True)
class LeafState:
    """Container for state information for a leaf system.

    Attributes:
        name (str):
            Name of the leaf system that owns this state.
        continuous_state (LeafStateComponent):
            Continuous state of the system, i.e. the component of state that evolves in
            continuous time. If the system has no continuous state, this will be None.
        discrete_state (LeafStateComponent):
            Discrete state of the system, i.e. one or more components of state that do
            not change continuously with time (not necessarily discrete-_valued_). If
            the system has no discrete state, this will be None.
        mode (int):
            An integer value indicating the current "mode", "stage", or discrete-valued
            state component of the system.  Used for finite state machines or
            multi-stage hybrid systems.  If the system has no mode, this will be None.
        cache (tuple[LeafStateComponent]):
            The current values of sample-and-hold outputs from the system.  In a pure
            discrete system these would not be state components (just results of
            feedthrough computations), but in a hybrid or multirate system they act as
            discrete state from the perspective of continuous or asynchronous discrete
            components of the system.  Hence, they are stored in the state, but are
            maintained separately from the normal internal state of the system.

    Notes:
        (1) This class is immutable.  To modify a LeafState, use the `with_*` methods.

        (2) The type annotations for state components are LeafStateComponent, which is
        a union of array, tuple, and named tuple. The most common case is arrays, but
        this allows for more flexibility in defining state components, e.g. a
        second-order system can define a named tuple of generalized coordinates and
        velocities rather than concatenating into a single array.
    """

    name: str = None
    continuous_state: LeafStateComponent = None
    discrete_state: LeafStateComponent = None
    mode: int = None
    cache: tuple[Array] = None

    def __repr__(self) -> str:
        states = []
        if self.continuous_state is not None:
            states.append(f"xc={self.continuous_state}")
        if self.discrete_state is not None:
            states.append(f"xd={self.discrete_state}")
        if self.mode is not None:
            states.append(f"s={self.mode}")
        return f"{type(self).__name__}({', '.join(states)})"

    def with_continuous_state(self, value: LeafStateComponent) -> LeafState:
        """Create a copy of this LeafState with the continuous state replaced."""
        if value is not None and self.continuous_state is not None:
            value = tree_util.tree_map(self._reshape_like, value, self.continuous_state)

        return dataclasses.replace(self, continuous_state=value)

    def _component_size(self, component: LeafStateComponent) -> int:
        """Size of a state component.

        Returns 0 for None. A tuple (including a named tuple) counts its entries,
        not the total array size of those entries. Otherwise the array's ``size``
        is returned.
        """
        if component is None:
            return 0
        if isinstance(component, tuple):
            return len(component)
        return component.size

    def _reshape_like(self, new_value: Array, current_value: Array) -> Array:
        """Helper function for tree-mapped type conversions.

        Ensures that the new components are array-like and have the same shape as
        the existing state to preserve PyTree structure.
        """
        return reshape(new_value, current_value.shape)

    @property
    def num_continuous_states(self) -> int:
        return self._component_size(self.continuous_state)

    @property
    def has_continuous_state(self) -> bool:
        return self.num_continuous_states > 0

    def with_discrete_state(self, value: LeafStateComponent) -> LeafState:
        """Create a copy of this LeafState with the discrete state replaced."""
        if value is not None and self.discrete_state is not None:
            value = tree_util.tree_map(self._reshape_like, value, self.discrete_state)

        return dataclasses.replace(self, discrete_state=value)

    @property
    def num_discrete_states(self) -> int:
        return self._component_size(self.discrete_state)

    @property
    def has_discrete_state(self) -> bool:
        return self.num_discrete_states > 0

    def with_mode(self, value: int) -> LeafState:
        """Create a copy of this LeafState with the mode replaced."""
        return dataclasses.replace(self, mode=value)

    @property
    def has_mode(self) -> bool:
        return self.mode is not None

    def with_cached_value(self, index: int, value: Array) -> LeafState:
        """Create a copy of this LeafState with the specified cache value replaced."""
        cache = list(self.cache)
        cache[index] = value
        return dataclasses.replace(self, cache=tuple(cache))

    # NOTE: unlike the has_*/num_* state accessors above, these two are plain
    # methods rather than properties. They are kept as methods for compatibility
    # with existing callers.
    def has_cache(self) -> bool:
        """Whether this state carries a cache tuple."""
        return self.cache is not None

    def num_cached_values(self) -> int:
        """Number of cached values; 0 when the state has no cache."""
        if self.cache is None:
            return 0
        return len(self.cache)

with_cached_value(index, value)

Create a copy of this LeafState with the specified cache value replaced.

Source code in collimator/framework/state.py
156
157
158
159
160
def with_cached_value(self, index: int, value: Array) -> LeafState:
    """Return a copy of this LeafState with `cache[index]` replaced by `value`."""
    slots = list(self.cache)
    slots[index] = value
    return dataclasses.replace(self, cache=tuple(slots))

with_continuous_state(value)

Create a copy of this LeafState with the continuous state replaced.

Source code in collimator/framework/state.py
102
103
104
105
106
107
def with_continuous_state(self, value: LeafStateComponent) -> LeafState:
    """Return a copy of this LeafState whose continuous state is `value`.

    When both `value` and the existing continuous state are present, each leaf
    of `value` is reshaped (via `_reshape_like`) to match the corresponding
    existing leaf, preserving the PyTree structure. Otherwise `value` is
    stored unchanged.
    """
    current = self.continuous_state
    if value is None or current is None:
        return dataclasses.replace(self, continuous_state=value)
    matched = tree_util.tree_map(self._reshape_like, value, current)
    return dataclasses.replace(self, continuous_state=matched)

with_discrete_state(value)

Create a copy of this LeafState with the discrete state replaced.

Source code in collimator/framework/state.py
133
134
135
136
137
138
def with_discrete_state(self, value: LeafStateComponent) -> LeafState:
    """Return a copy of this LeafState whose discrete state is `value`.

    Leaves of `value` are reshaped to the existing leaf shapes (via
    `_reshape_like`) whenever both the new and current discrete states exist.
    """
    replacement = value
    if value is not None and self.discrete_state is not None:
        replacement = tree_util.tree_map(
            self._reshape_like, value, self.discrete_state
        )
    return dataclasses.replace(self, discrete_state=replacement)

with_mode(value)

Create a copy of this LeafState with the mode replaced.

Source code in collimator/framework/state.py
148
149
150
def with_mode(self, value: int) -> LeafState:
    """Return a copy of this LeafState with `mode` set to `value`."""
    changes = {"mode": value}
    return dataclasses.replace(self, **changes)

LeafSystem dataclass

Bases: SystemBase

Basic building block for dynamical systems.

A LeafSystem is a minimal component of a system model in collimator, containing no subsystems. Inputs, outputs, state, parameters, updates, etc. can be added to the block using the various declare_* methods. The built-in blocks in collimator.library are all subclasses of LeafSystem, as are any custom blocks defined by the user.

Source code in collimator/framework/leaf_system.py
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
class LeafSystem(SystemBase):
    """Basic building block for dynamical systems.

    A LeafSystem is a minimal component of a system model in collimator, containing no
    subsystems.  Inputs, outputs, state, parameters, updates, etc. can be added to the
    block using the various `declare_*` methods.  The built-in blocks in
    collimator.library are all subclasses of LeafSystem, as are any custom blocks defined
    by the user."""

    # SystemBase is a dataclass, so we need to call __post_init__ explicitly
    def __post_init__(self):
        # The base-class initialization runs first; the log line below reads
        # `name` and `system_id`.
        super().__post_init__()
        logger.debug(f"Initializing {self.name} [{self.system_id}]")

        # ---- Continuous state ------------------------------------------------
        # Template (shape/dtype) for the continuous state, or None when the block
        # has none. It seeds the context, so it doubles as the initial value
        # unless overridden. Usually an array, but any PyTree (list, dict,
        # namedtuple, ...) is allowed as long as the ODE returns a PyTree of the
        # same structure.
        self._default_continuous_state: LeafStateComponent = None

        # Time-derivative callback for the continuous state; assigned by
        # `declare_continuous_state`.
        self.ode_callback: SystemCallback = None

        # ---- Discrete state --------------------------------------------------
        # Template for the discrete state, or None when the block has none. Like
        # the continuous state it also serves as the initial value and may be any
        # PyTree, provided the update functions return the same structure.
        self._default_discrete_state: LeafStateComponent = None

        # ---- Mode ------------------------------------------------------------
        # Integer "mode"/"stage" for state-machine style blocks, or None when the
        # block has no mode. Transitions between modes are declared with
        # `declare_zero_crossing` (guard, reset map, optional `start_mode` and
        # `end_mode`).
        self._default_mode: int = None

        # ---- Cache -----------------------------------------------------------
        # Template values for cached / sample-and-hold results. An entry may be
        # None, in which case a suitable value is inferred from upstream during
        # static analysis.
        self._default_cache: List[LeafStateComponent] = []

        # Debugging aid mapping start_mode -> [(int, event), ...]; the int is
        # presumably the end mode. Not used by any system logic.
        self.transition_map: dict[int, List[Tuple[int, ZeroCrossingEvent]]] = {}

        # ---- Events ----------------------------------------------------------
        # Fixed-rate update events, each with its own period and offset so they
        # fire independently (see `declare_periodic_update`).
        self._state_update_events: List[DiscreteUpdateEvent] = []

        # Zero-crossing events, each with its own guard and optional reset map,
        # start mode and end mode (see `declare_zero_crossing`).
        self._zero_crossing_events: List[ZeroCrossingEvent] = []

    @property
    def has_feedthrough_side_effects(self) -> bool:
        # Leaf blocks report no feedthrough side effects by default; see
        # `SystemBase.has_feedthrough_side_effects` for the rationale. A
        # subclass that computes a feedthrough output through `io_callback`
        # should override this to return True.
        side_effects = False
        return side_effects

    @property
    def has_ode_side_effects(self) -> bool:
        # Leaf blocks report no ODE side effects; the special-case logic for
        # this lives in Diagram systems instead.
        side_effects = False
        return side_effects

    @property
    def has_continuous_state(self) -> bool:
        # A continuous state exists once a default template has been set
        # (see `declare_continuous_state`).
        template_missing = self._default_continuous_state is None
        return not template_missing

    @property
    def has_discrete_state(self) -> bool:
        # A discrete state exists once a default template has been set
        # (see `declare_discrete_state`).
        template_missing = self._default_discrete_state is None
        return not template_missing

    @property
    def has_zero_crossing_events(self) -> bool:
        # `_zero_crossing_events` is a list, so its truthiness means "non-empty".
        return bool(self._zero_crossing_events)

    #
    # Event handling
    #
    def wrap_callback(
        self, callback: Callable, collect_inputs: bool = True
    ) -> Callable:
        """Adapt a block-level function into a context-level callback.

        The given `callback` must have the block-level signature
        `callback(time, state, *inputs, **params) -> result`. The returned
        function has the signature `callback(context) -> result`, which is what
        the event-handling machinery calls.

        The declaration helpers (e.g. `declare_periodic_update`) use this so that
        block authors can write update functions in terms of time, state, inputs
        and parameters rather than the context. It may also be called directly,
        for example to build the callback passed to `declare_output_port`.

        The context and state are strictly immutable, so `callback` must not try
        to modify them. This holds even where immutability cannot be _enforced_
        (for example, a state component that is a Python list). In-place
        modification may cause unexpected behavior or JAX tracer errors.

        Args:
            callback (Callable):
                Pure function with the block-level signature described above.
            collect_inputs (bool):
                If True (the default), input ports are evaluated and passed as
                positional arguments. Set it to False when the result depends
                only on time/state/parameters; this reduces the number of
                expressions that have to be JIT compiled.

        Returns:
            Callable:
                A function with signature `callback(context) -> result`.
        """

        def _wrapped_callback(context: ContextBase) -> LeafStateComponent:
            block_inputs = self.collect_inputs(context) if collect_inputs else ()
            local_context: LeafContext = context[self.system_id]
            state, params = local_context.state, local_context.parameters
            return callback(context.time, state, *block_inputs, **params)

        return _wrapped_callback

    def _passthrough(self, context: ContextBase) -> LeafState:
        """No-op event callback: return this block's current state unchanged."""
        leaf_context = context[self.system_id]
        return leaf_context.state

    @property
    def state_update_events(self) -> FlatEventCollection:
        # Snapshot the declared periodic update events as an immutable tuple.
        events = tuple(self._state_update_events)
        return FlatEventCollection(events)

    @property
    def zero_crossing_events(self) -> LeafEventCollection:
        # Every event is reported as active here. Mode-dependent activation
        # (based on the current "mode"/"stage") is handled separately in
        # `determine_active_guards`.
        collection = LeafEventCollection(tuple(self._zero_crossing_events))
        return collection.mark_all_active()

    # Inherits docstring from SystemBase
    def eval_zero_crossing_updates(
        self,
        context: ContextBase,
        events: LeafEventCollection,
    ) -> LeafState:
        sys_id = self.system_id
        local_events = events[sys_id]
        state = context[sys_id].state

        logger.debug(f"Eval update events for {self.name}")
        logger.debug(f"local events: {local_events}")

        for event in local_events:
            # The update is applied conditionally on event_data.active.
            state = event.handle(context)
            # Fold the new state into a fresh context so that the next event in
            # this block sees the result of the previous one.
            context = context.with_subcontext(
                sys_id, context[sys_id].with_state(state)
            )

        # The locally updated context is deliberately dropped. Only the "plus"
        # state of this block is returned, so the other blocks' updates are
        # still computed from the "minus" state.
        return state

    # Inherits docstring from SystemBase
    def determine_active_guards(self, root_context: ContextBase) -> LeafEventCollection:
        current_mode = root_context[self.system_id].mode

        def _gate(event: ZeroCrossingEvent) -> ZeroCrossingEvent:
            # Events that are not tied to a mode transition keep their
            # (active) status unchanged.
            if event.active_mode is None:
                return event
            # Mode-bound events are active only while the system is in the
            # event's `active_mode`; `cond` selects between the two markings.
            in_active_mode = current_mode == event.active_mode
            return cond(
                in_active_mode,
                lambda e: e.mark_active(),
                lambda e: e.mark_inactive(),
                event,
            )

        gated = [_gate(candidate) for candidate in self.zero_crossing_events]
        zero_crossing_events = LeafEventCollection(tuple(gated))

        logger.debug(f"Zero-crossing events for {self.name}: {zero_crossing_events}")
        return zero_crossing_events

    @property
    def _flat_callbacks(self) -> List[SystemCallback]:
        """Return every SystemCallback declared on this leaf system.

        A leaf system has no subsystems, so this is simply `self.callbacks`.
        That list is not limited to sample-and-hold output ports: it also holds,
        for example, the ODE callback added by `declare_continuous_state` and
        the callbacks added by `declare_cache`.
        """
        return self.callbacks

    def declare_cache(
        self,
        callback: Callable,
        period: float | Parameter = None,
        offset: float | Parameter = 0.0,
        name: str = None,
        prerequisites_of_calc: List[DependencyTicket] = None,
        default_value: Array | Parameter = None,
        requires_inputs: bool = True,
    ) -> int:
        """Declare a stored computation for the system.

        This method accepts a callback function with the block-level signature
            `callback(time, state, *inputs, **parameters) -> value`
        and wraps it to have the signature
            `callback(context) -> value`

        This callback can optionally be used to define a periodic update event that
        refreshes the cached value.  Other calculations (e.g. sample-and-hold output
        ports) can then depend on the cached value.

        Args:
            callback (Callable):
                The callback function defining the cached computation.
            period (float | Parameter, optional):
                If not None, the callback function will be used to define a periodic
                update event that refreshes the value. A `Parameter` is resolved with
                `Parameter.get()`. Defaults to None.
            offset (float | Parameter, optional):
                The offset of the periodic update event. Defaults to 0.0.  Will be ignored
                unless `period` is not None. A `Parameter` is resolved with
                `Parameter.get()`.
            name (str, optional):
                The name of the cached value. Defaults to None, in which case the
                name `cache_<cache_index>` is used.
            default_value (Array | Parameter, optional):
                The default value of the result, if known. A `Parameter` is resolved
                with `Parameter.get()`, as in `declare_output_port`. Defaults to None.
            requires_inputs (bool, optional):
                If True, the callback will eval input ports to gather input values.
                This will add a bit to compile time, so setting to False where possible
                is recommended. Defaults to True.
            prerequisites_of_calc (List[DependencyTicket], optional):
                The dependency tickets for the computation. Defaults to None, in which
                case the default is to assume dependency on either (inputs) if
                `requires_inputs` is True, or (nothing) otherwise.

        Returns:
            int: The index of the callback in `system.callbacks`.  The cache index can
                be recovered from `system.callbacks[callback_index].cache_index`.
        """

        # Resolve Parameters to concrete values, consistent with
        # `declare_output_port` and `declare_discrete_state`.
        if isinstance(period, Parameter):
            period = period.get()
        if isinstance(offset, Parameter):
            offset = offset.get()
        if isinstance(default_value, Parameter):
            default_value = default_value.get()

        # The index in the list of system callbacks
        callback_index = len(self.callbacks)

        # This is the index that this cached value will have in state.cache
        cache_index = len(self._default_cache)
        self._default_cache.append(default_value)

        # To help avoid unnecessary flagging of algebraic loops, trim the inputs as a
        # default prereq if the update callback doesn't use them
        if prerequisites_of_calc is None:
            if requires_inputs:
                prerequisites_of_calc = [DependencyTicket.u]
            else:
                prerequisites_of_calc = [DependencyTicket.nothing]

        def _update_callback(
            time: Scalar, state: LeafState, *inputs, **parameters
        ) -> LeafState:
            # Store the freshly computed value in slot `cache_index` of the state.
            output = callback(time, state, *inputs, **parameters)
            return state.with_cached_value(cache_index, output)

        _update_callback = self.wrap_callback(
            _update_callback, collect_inputs=requires_inputs
        )

        if period is None:
            event = None

        else:
            # The cache has a periodic event updating its value defined by the callback
            event = DiscreteUpdateEvent(
                system_id=self.system_id,
                event_data=PeriodicEventData(
                    period=period, offset=offset, active=False
                ),
                name=f"{self.name}:cache_update_{cache_index}_",
                callback=_update_callback,
                passthrough=self._passthrough,
            )

        if name is None:
            name = f"cache_{cache_index}"

        sys_callback = SystemCallback(
            callback=_update_callback,
            system=self,
            callback_index=callback_index,
            name=name,
            prerequisites_of_calc=prerequisites_of_calc,
            event=event,
            default_value=default_value,
            cache_index=cache_index,
        )
        self.callbacks.append(sys_callback)

        return callback_index

    def declare_continuous_state(
        self,
        shape: ShapeLike = None,
        default_value: Array | Parameter = None,
        dtype: DTypeLike = None,
        ode: Callable = None,
        as_array: bool = True,
        requires_inputs: bool = True,
        prerequisites_of_calc: List[DependencyTicket] = None,
    ):
        """Declare a continuous state component for the system.

        The `ode` callback computes the time derivative of the continuous state based on the
        current time, state, and any additional inputs. If `ode` is not provided, a default
        zero vector of the same size as the continuous state is used. If provided, the `ode`
        callback should have the signature `ode(time, state, *inputs, **params) -> xcdot`.

        Side effects: sets `self._default_continuous_state` and `self.ode_callback`,
        appends the ODE callback to `self.callbacks`, and replaces
        `self.eval_time_derivatives` with the callback's `eval`.

        Args:
            shape (ShapeLike, optional):
                The shape of the continuous state vector. Defaults to None.
            default_value (Array | Parameter, optional):
                The initial value of the continuous state vector. A `Parameter` is
                resolved with `Parameter.get()`, as in `declare_discrete_state`.
                Defaults to None.
            dtype (DTypeLike, optional):
                The data type of the continuous state vector. Defaults to None.
            ode (Callable, optional):
                The callback for computing the time derivative of the continuous state.
                Should have the signature:
                    `ode(time, state, *inputs, **parameters) -> xcdot`.
                Defaults to None.
            as_array (bool, optional):
                If True, treat the default_value as an array-like (cast if necessary).
                Otherwise, it will be stored as the default state without modification.
            requires_inputs (bool, optional):
                If True, indicates that the ODE computation requires inputs.
            prerequisites_of_calc (List[DependencyTicket], optional):
                The dependency tickets for the ODE computation. Defaults to None, in
                which case the assumption is a dependency on either (time, continuous
                state) if `requires_inputs` is False, otherwise (time, continuous state,
                inputs).

        Raises:
            AssertionError:
                If neither shape nor default_value is provided.
            AssertionError:
                If `ode` is None and `as_array` is False: the default zero ODE only
                works for array-valued continuous states.

        Notes:
            (1) Only one of `shape` and `default_value` should be provided. If `default_value`
            is provided, it will be used as the initial value of the continuous state. If
            `shape` is provided, the initial value will be a zero vector of the given shape
            and specified dtype.
        """

        # Resolve Parameters to concrete values, consistent with
        # `declare_discrete_state` and `declare_output_port`.
        if isinstance(default_value, Parameter):
            default_value = default_value.get()

        if prerequisites_of_calc is None:
            prerequisites_of_calc = [DependencyTicket.time, DependencyTicket.xc]
            if requires_inputs:
                prerequisites_of_calc.append(DependencyTicket.u)

        if as_array:
            default_value = utils.make_array(
                default_value, dtype=dtype, shape=shape
            )

        logger.debug(f"In block {self.name} [{self.system_id}]: {default_value=}")

        # Tree-map the default value to ensure that it is an array-like with the
        # correct shape and dtype. This is necessary because the default value
        # may be a list, tuple, or other PyTree-structured object.
        default_value = tree_util.tree_map(cnp.asarray, default_value)

        self._default_continuous_state = default_value

        if ode is None:
            # If no ODE is specified, return a zero vector of the same size as the
            # continuous state. This will break if the continuous state is
            # a named tuple, in which case a custom ODE must be provided.
            assert as_array, "Must provide custom ODE for non-array continuous state"

            def ode(time, state, *inputs, **parameters):
                return cnp.zeros_like(default_value)

        # Wrap the ode function to accept a context and return the time derivatives.
        ode = self.wrap_callback(ode)

        # Declare the time derivative function as a system callback so that its
        # dependencies can be tracked in the system dependency graph
        self.ode_callback = SystemCallback(
            callback=ode,
            system=self,
            callback_index=len(self.callbacks),
            name=f"{self.name}_ode",
            prerequisites_of_calc=prerequisites_of_calc,
        )

        self.callbacks.append(self.ode_callback)

        # Override the default `eval_time_derivatives` to use the wrapped ODE function
        self.eval_time_derivatives = self.ode_callback.eval

    def declare_discrete_state(
        self,
        shape: ShapeLike = None,
        default_value: Array | Parameter = None,
        dtype: DTypeLike = None,
        as_array: bool = True,
    ):
        """Declare the discrete state component for the system.

        The discrete state is a component of the system's state that can be updated
        at specific events, such as zero-crossings or periodic updates. Only a single
        default discrete state is stored: calling this method again replaces the
        previously declared default (`self._default_discrete_state` is overwritten).
        To track several quantities, use a PyTree-structured value (e.g. a list or
        dict of arrays) with `as_array=False`.

        The declared discrete state is initialized with either the provided default
        value or zeros of the correct shape and dtype.

        Args:
            shape (ShapeLike, optional):
                The shape of the discrete state. Defaults to None.
            default_value (Array | Parameter, optional):
                The initial value of the discrete state. A `Parameter` is resolved
                with `Parameter.get()`. Defaults to None.
            dtype (DTypeLike, optional):
                The data type of the discrete state. Defaults to None.
            as_array (bool, optional):
                If True, treat the default_value as an array-like (cast if necessary).
                Otherwise, it will be stored as the default state without modification.

        Raises:
            AssertionError:
                If as_array is True and neither shape nor default_value is provided.

        Notes:
            (1) Only one of `shape` and `default_value` should be provided. If
            `default_value` is provided, it will be used as the initial value of the
            discrete state. If `shape` is provided, the initial value will be a
            zero vector of the given shape and specified dtype.

            (2) Use `declare_periodic_update` to declare an update event that
            modifies the discrete state at a recurring interval.
        """

        # Parameters are stored by value in the default state.
        if isinstance(default_value, Parameter):
            default_value = default_value.get()

        if as_array:
            default_value = utils.make_array(
                default_value, dtype=dtype, shape=shape
            )

        # Tree-map the default value to ensure that it is an array-like with the
        # correct shape and dtype. This is necessary because the default value
        # may be a list, tuple, or other PyTree-structured object.
        default_value = tree_util.tree_map(cnp.asarray, default_value)

        self._default_discrete_state = default_value

    def declare_configuration_parameters(self, **params):
        """Register non-numeric "configuration" parameters on the system.

        Configuration parameters (flags, option strings and the like) are
        declared as parameters rather than plain attributes mainly so that they
        are serialized together with the block. Each keyword argument maps a
        parameter name to a string, bool, array, or Parameter. Python lists are
        converted to numpy arrays before being wrapped in a `Parameter`.

        Typical usage:

        ```python
        class MyBlock(LeafSystem):
            def __init__(self, param1=True, param2=1.0):
                super().__init__()
                self.declare_configuration_parameters(param1=param1, param2=param2)
        ```
        """
        for param_name, raw in params.items():
            stored = np.array(raw) if isinstance(raw, list) else raw
            self._instance_parameters[param_name] = Parameter(value=stored)

    #
    # I/O declaration
    #
    def declare_output_port(
        self,
        callback: Callable,
        period: float | Parameter = None,
        offset: float | Parameter = 0.0,
        name: str = None,
        prerequisites_of_calc: List[DependencyTicket] = None,
        default_value: Array | Parameter = None,
        requires_inputs: bool = True,
    ) -> int:
        """Declare an output port in the LeafSystem.

        The block-level callback signature
            `callback(time, state, *inputs, **parameters) -> value`
        is adapted to the context-level signature expected by
        `SystemBase.declare_output_port`:
            `callback(context) -> value`

        Args:
            callback (Callable):
                Computes the value of the output port.
            period (float | Parameter, optional):
                When given, the port becomes a "sample-and-hold" port. `callback`
                is run by a periodic discrete update event that stores its result
                in the state cache, and the port returns that stored value.
                Typically this should match the update period of some associated
                update event in the system. Defaults to None, in which case the
                port evaluates `callback` directly.
            offset (float | Parameter, optional):
                Offset of the periodic update event. Ignored unless `period` is
                not None. Defaults to 0.0.
            name (str, optional):
                The name of the output port. Defaults to None.
            prerequisites_of_calc (List[DependencyTicket], optional):
                Dependency tickets for the port computation. If None, they default
                to `[DependencyTicket.u]` when `requires_inputs` is True and to
                `[DependencyTicket.nothing]` otherwise.
            default_value (Array | Parameter, optional):
                The default value of the output port, if known. Converted with
                `cnp.array` when not None. Defaults to None.
            requires_inputs (bool, optional):
                If True, the wrapped callback evaluates the input ports to gather
                input values. This adds a bit of compile time, so pass False when
                the callback does not use its inputs. Defaults to True.

        Returns:
            int: The index of the declared output port.
        """
        # Resolve Parameter-valued arguments to their current values
        # (evaluated in the order default_value, period, offset).
        default_value, period, offset = (
            arg.get() if isinstance(arg, Parameter) else arg
            for arg in (default_value, period, offset)
        )

        if default_value is not None:
            default_value = cnp.array(default_value)

        # Depend on the inputs only when the callback actually reads them.  This
        # avoids unnecessarily flagging algebraic loops.
        if prerequisites_of_calc is None:
            prerequisites_of_calc = [
                DependencyTicket.u if requires_inputs else DependencyTicket.nothing
            ]

        if period is None:
            # Ordinary port: the wrapped callback is evaluated on each request.
            direct_callback = self.wrap_callback(
                callback, collect_inputs=requires_inputs
            )
            return super().declare_output_port(
                direct_callback,
                name=name,
                prerequisites_of_calc=prerequisites_of_calc,
                default_value=default_value,
                event=None,
                cache_index=None,
            )

        # Sample-and-hold port: a periodic update event runs `callback` and writes
        # the result into a dedicated slot of state.cache.  The port itself only
        # reads that slot back.
        slot = len(self._default_cache)

        def _read_held_value(context: ContextBase) -> Array:
            return context[self.system_id].state.cache[slot]

        def _refresh_held_value(
            time: Scalar, state: LeafState, *inputs, **parameters
        ) -> LeafState:
            sampled = callback(time, state, *inputs, **parameters)
            return state.with_cached_value(slot, sampled)

        refresh = self.wrap_callback(
            _refresh_held_value, collect_inputs=requires_inputs
        )

        hold_event = DiscreteUpdateEvent(
            system_id=self.system_id,
            event_data=PeriodicEventData(period=period, offset=offset, active=False),
            name=f"{self.name}:output_{slot}",
            callback=refresh,
            passthrough=self._passthrough,
        )

        self._default_cache.append(default_value)

        # Here `prerequisites_of_calc` describes the update event rather than the
        # cache read.  It still lets dependencies of that event be determined
        # through the output port.
        return super().declare_output_port(
            _read_held_value,
            name=name,
            prerequisites_of_calc=prerequisites_of_calc,
            default_value=default_value,
            event=hold_event,
            cache_index=slot,
        )

    def declare_continuous_state_output(
        self,
        name: str = None,
    ) -> int:
        """Expose the full continuous state of the system on a new output port.

        The port depends only on the continuous state (not on inputs), and its
        default value is the system's default continuous state.

        Args:
            name (str, optional):
                The name of the output port. Defaults to None (autogenerate name).

        Returns:
            int: The index of the new output port.
        """

        def _read_continuous_state(
            time: Scalar, state: LeafState, *inputs, **parameters
        ):
            return state.continuous_state

        port_options = dict(
            name=name,
            prerequisites_of_calc=[DependencyTicket.xc],
            default_value=self._default_continuous_state,
            requires_inputs=False,
        )
        return self.declare_output_port(_read_continuous_state, **port_options)

    def declare_mode_output(self, name: str = None) -> int:
        """Expose the system's discrete "mode" (or "stage") on a new output port.

        The port returns the `mode` component of the system state.  It depends
        only on the mode (not on inputs), and its default value is the system's
        default mode.

        Args:
            name (str, optional):
                The name of the output port. Defaults to None.

        Returns:
            int:
                The index of the declared mode output port.
        """

        def _read_mode(time: Scalar, state: LeafState, *inputs, **parameters):
            return state.mode

        return self.declare_output_port(
            _read_mode,
            requires_inputs=False,
            default_value=self._default_mode,
            prerequisites_of_calc=[DependencyTicket.mode],
            name=name,
        )

    #
    # Event declaration
    #
    def declare_periodic_update(
        self,
        callback: Callable,
        period: Scalar | Parameter,
        offset: Scalar | Parameter,
        enable_tracing: bool = None,
    ):
        """Declare a periodic discrete update event.

        The event fires at regular intervals given by `period`, starting at
        `offset`.  `callback` must have the signature
        `callback(time, state, *inputs, **params) -> xd_plus`.  It computes the
        "plus" (post-update) value of the discrete state from the "minus" values
        of all state components and inputs.  The result replaces the discrete
        state of this system.

        Args:
            callback (Callable):
                The callback function defining the update.
            period (Scalar | Parameter):
                The period at which the update event occurs.
            offset (Scalar | Parameter):
                The offset at which the first occurrence of the event is triggered.
            enable_tracing (bool, optional):
                Whether to enable tracing for this event. If None (default),
                tracing is enabled.
        """
        period, offset = (
            value.get() if isinstance(value, Parameter) else value
            for value in (period, offset)
        )

        compute_xd_plus = self.wrap_callback(callback)

        def _apply_update(context: ContextBase) -> LeafState:
            # Swap the freshly computed discrete state into this system's state.
            xd_plus = compute_xd_plus(context)
            return context[self.system_id].state.with_discrete_state(xd_plus)

        self._state_update_events.append(
            DiscreteUpdateEvent(
                system_id=self.system_id,
                name=f"{self.name}:periodic_update",
                event_data=PeriodicEventData(
                    period=period, offset=offset, active=False
                ),
                callback=_apply_update,
                passthrough=self._passthrough,
                enable_tracing=True if enable_tracing is None else enable_tracing,
            )
        )

    def declare_default_mode(self, mode: int):
        """Set the default (initial) discrete mode of the system.

        The value becomes the `mode` component of the state produced by
        `create_state`.  It must be set before declaring zero-crossing events that
        use `start_mode`/`end_mode` transitions.

        Args:
            mode (int): The default mode.
        """
        self._default_mode = mode

    def declare_zero_crossing(
        self,
        guard: Callable,
        reset_map: Callable = None,
        start_mode: int = None,
        end_mode: int = None,
        direction: str = "crosses_zero",
        terminal: bool = False,
        name: str = None,
        enable_tracing: bool = None,
    ):
        """Declare an event triggered by a zero-crossing of a guard function.

        Optionally, the system can also transition between discrete modes.
        If `start_mode` and `end_mode` are specified, the system will transition
        from `start_mode` to `end_mode` when the event is triggered according to
        `guard`.  This event will be active conditionally on
        `state.mode == start_mode`.  When triggered, it applies the reset map and
        then updates the mode to `end_mode`.

        If `start_mode` and `end_mode` are not specified, the event will always be
        active and will not result in a mode transition.

        The guard function should have the signature:
            `guard(time, state, *inputs, **parameters) -> float`

        and the reset map should have the signature of an unrestricted update:
            `reset_map(time, state, *inputs, **parameters) -> state`

        Args:
            guard (Callable):
                The guard function which triggers updates on zero crossing.
            reset_map (Callable, optional):
                The reset map which is applied when the event is triggered. If None
                (default), no reset is applied.
            start_mode (int, optional):
                The mode or stage of the system in which the guard will be
                actively monitored. If None (default), the event will always be
                active.
            end_mode (int, optional):
                The mode or stage of the system to which the system will transition
                when the event is triggered. `start_mode` and `end_mode` must be
                given together: if either is specified, both must be integers.
                `end_mode` may be the same as `start_mode`.
            direction (str, optional):
                The direction of the zero crossing. Options are "crosses_zero"
                (default), "positive_then_non_positive", "negative_then_non_negative",
                and "edge_detection".  All except edge detection operate on continuous
                signals; edge detection operates on boolean signals and looks for a
                jump from False to True or vice versa.
            terminal (bool, optional):
                If True, the event will halt simulation if and when the zero-crossing
                occurs. If this event is triggered the reset map will still be applied
                as usual prior to termination. Defaults to False.
            name (str, optional):
                The name of the event. Defaults to None.
            enable_tracing (bool, optional):
                Whether to enable tracing for this event. If None (default),
                tracing is enabled.

        Raises:
            AssertionError:
                If a mode is given but no default mode has been declared, or if
                only one of `start_mode`/`end_mode` is an integer.

        Notes:
            By default the system state does not have a "mode" component, so in
            order to declare "state transitions" with non-null start and end modes,
            the user must first call `declare_default_mode` to set the default mode
            to be some integer (initial condition for the system).
        """

        logger.debug(
            "Declaring transition for %s with guard %s and reset map %s",
            self.name,
            guard,
            reset_map,
        )

        if enable_tracing is None:
            enable_tracing = True

        if start_mode is not None or end_mode is not None:
            assert (
                self._default_mode is not None
            ), "System has no mode: call `declare_default_mode` before transitions."
            assert isinstance(start_mode, int) and isinstance(end_mode, int), (
                "`start_mode` and `end_mode` must both be integers when either is "
                f"given (got start_mode={start_mode!r}, end_mode={end_mode!r})."
            )

        # Wrap the reset map with a mode update if necessary
        def _reset_and_update_mode(
            time: Scalar, state: LeafState, *inputs, **parameters
        ) -> LeafState:
            if reset_map is not None:
                state = reset_map(time, state, *inputs, **parameters)

            # Only events declared with start/end modes perform a transition; the
            # mode-change message is logged only in that case.
            if start_mode is not None:
                logger.debug("Updating mode from %s to %s", state.mode, end_mode)
                state = state.with_mode(end_mode)

            return state

        _wrapped_guard = self.wrap_callback(guard)
        _wrapped_reset = _wrap_reset_map(
            self, _reset_and_update_mode, _wrapped_guard, terminal
        )

        event = ZeroCrossingEvent(
            system_id=self.system_id,
            guard=_wrapped_guard,
            reset_map=_wrapped_reset,
            passthrough=self._passthrough,
            direction=direction,
            is_terminal=terminal,
            name=name,
            event_data=ZeroCrossingEventData(active=True, triggered=False),
            enable_tracing=enable_tracing,
            active_mode=start_mode,
        )

        event_index = len(self._zero_crossing_events)
        self._zero_crossing_events.append(event)

        # Record the transition in the transition map (for debugging or analysis)
        if start_mode is not None:
            self.transition_map.setdefault(start_mode, []).append(
                (event_index, event)
            )

    #
    # Initialization
    #
    @property
    def context_factory(self) -> LeafContextFactory:
        # A new factory bound to this leaf system is built on every access.
        factory = LeafContextFactory(self)
        return factory

    @property
    def dependency_graph_factory(self) -> LeafDependencyGraphFactory:
        # A new dependency-graph factory bound to this leaf system is built on
        # every access.
        factory = LeafDependencyGraphFactory(self)
        return factory

    def create_state(self) -> LeafState:
        # Hook used during context creation to build this system's default state.
        # Users normally don't call it directly; `system.create_context()` takes
        # care of initialization.  The cache list is frozen into a tuple so the
        # state does not alias the mutable `_default_cache` list.
        default_components = {
            "continuous_state": self._default_continuous_state,
            "discrete_state": self._default_discrete_state,
            "mode": self._default_mode,
            "cache": tuple(self._default_cache),
        }
        return LeafState(name=self.name, **default_components)


    def initialize_static_data(self, context: ContextBase):
        # Try to infer any missing default values for "sample-and-hold" output ports
        # and any other cached computations.
        #
        # For each callback that owns a cache slot whose default is still None, the
        # associated update function is evaluated once on `context`.  The value it
        # writes into the cache becomes the slot's default, and the returned
        # context is updated with it.  If evaluation raises UpstreamEvalError, the
        # slot is left as None and initialization continues with the next callback.
        #
        # Returns the (possibly updated) context.
        cached_callbacks: List[SystemCallback] = [
            cb for cb in self.callbacks if cb.cache_index is not None
        ]

        for callback in cached_callbacks:
            i = callback.cache_index
            if self._default_cache[i] is not None:
                continue  # Default already known; nothing to infer.

            if isinstance(callback, OutputPort):
                # Evaluate the _event_ callback (not the port's cache-read
                # function).  It returns the updated state with the sampled value
                # stored in the cache slot, provided the port is connected.
                _eval = callback.event.callback
            else:
                # Other cached callbacks (e.g. from `declare_cache`) are expected to
                # return the updated state as well; NOTE(review): confirm that
                # SystemCallback.eval returns the callback's LeafState result.
                _eval = callback.eval

            try:
                state: LeafState = _eval(context)
            except UpstreamEvalError:
                logger.debug(
                    "%s.initialize_static_data: UpstreamEvalError. "
                    "Continuing without default value initialization.",
                    self.name,
                )
                continue

            y = state.cache[i]
            self._default_cache[i] = y
            local_context = context[self.system_id].with_cached_value(i, y)
            context = context.with_subcontext(self.system_id, local_context)

        return context

    def _create_dependency_cache(self) -> dict[int, CallbackTracer]:
        # Build a map from each callback's index to a fresh CallbackTracer carrying
        # that callback's dependency ticket.  These tracers are only used to trace
        # dependencies; they hold no computed values.
        return {
            cb.callback_index: CallbackTracer(ticket=cb.ticket)
            for cb in self.callbacks
        }

    # Inherits docstring from SystemBase.get_feedthrough
    def get_feedthrough(self) -> List[Tuple[int, int]]:
        # Determine which (input port index, output port index) pairs are
        # "feedthrough".  A pair is feedthrough when invalidating the input marks
        # the output's cached computation out of date in the local dependency
        # graph.  Pairs that cannot be resolved this way are conservatively
        # reported as feedthrough.
        #
        # The result is stored in `self.feedthrough_pairs` and returned directly
        # on later calls.  The exception is the early exit for systems with no
        # input or no output ports: there an empty list is returned without
        # being stored.
        #
        # NOTE: This implementation is basically a direct port of the Drake algorithm

        # If we already did this or it was set manually, return the stored value
        if self.feedthrough_pairs is not None:
            return self.feedthrough_pairs

        # Confirmed feedthrough pairs (input, output)
        feedthrough: List[Tuple[int, int]] = []

        # First collect all possible feedthrough pairs
        unknown: Set[Tuple[int, int]] = set()
        for iport in self.input_ports:
            for oport in self.output_ports:
                unknown.add((iport.index, oport.index))

        # No input or no output ports: nothing can feed through.
        if len(unknown) == 0:
            return feedthrough

        # Create a local context and "cache".  The cache here just contains CallbackTracer
        # objects that can be used to trace dependencies through the system, but
        # otherwise don't store any actual values.  This is different from any "cached"
        # computations that might be stored in the state for reuse by multiple ports or
        # downstream calculations within the system.
        #
        # This cache will only contain local sources - this is fine since we're just
        # testing local input -> output paths for this system.
        cache = self._create_dependency_cache()

        # Iterate over a snapshot so resolved pairs can be removed from `unknown`.
        original_unknown = unknown.copy()
        for pair in original_unknown:
            u, v = pair
            # Presumably a port's `.index` equals its position in
            # input_ports/output_ports -- TODO confirm.
            output_port = self.output_ports[v]
            input_port = self.input_ports[u]

            # If output prerequisites are unspecified, this tells us nothing
            # (the pair stays in `unknown` and is treated as feedthrough below).
            if DependencyTicket.all_sources in output_port.prerequisites_of_calc:
                continue

            # Determine feedthrough dependency via cache invalidation
            cache = _mark_up_to_date(cache, output_port.callback_index)

            # Notify subscribers of a value change in the input, invalidating all
            # downstream cache values
            input_tracker = self.dependency_graph[input_port.ticket]
            cache = input_tracker.notify_subscribers(cache, self.dependency_graph)

            # If the output cache is now out of date, this is a feedthrough path
            if cache[output_port.callback_index].is_out_of_date:
                feedthrough.append(pair)

            # Regardless of the result of the caching, the pair is no longer unknown
            unknown.remove(pair)

            # Reset the output cache to out-of-date in case other inputs also
            # feed through to this output.
            cache = _mark_out_of_date(cache, output_port.callback_index)

        # Logs only the pairs confirmed by tracing, before unresolved pairs are added.
        logger.debug(f"{self.name} feedthrough pairs: {feedthrough}")

        # Conservatively assume everything still unknown is feedthrough
        for pair in unknown:
            feedthrough.append(pair)

        self.feedthrough_pairs = feedthrough
        return self.feedthrough_pairs

declare_cache(callback, period=None, offset=0.0, name=None, prerequisites_of_calc=None, default_value=None, requires_inputs=True)

Declare a stored computation for the system.

This method accepts a callback function with the block-level signature callback(time, state, *inputs, **parameters) -> value and wraps it to have the signature callback(context) -> value

This callback can optionally be used to define a periodic update event that refreshes the cached value. Other calculations (e.g. sample-and-hold output ports) can then depend on the cached value.

Parameters:

Name Type Description Default
callback Callable

The callback function defining the cached computation.

required
period float

If not None, the callback function will be used to define a periodic update event that refreshes the value. Defaults to None.

None
offset float

The offset of the periodic update event. Defaults to 0.0. Will be ignored unless period is not None.

0.0
name str

The name of the cached value. Defaults to None.

None
default_value Array

The default value of the result, if known. Defaults to None.

None
requires_inputs bool

If True, the callback will eval input ports to gather input values. This will add a bit to compile time, so setting to False where possible is recommended. Defaults to True.

True
prerequisites_of_calc List[DependencyTicket]

The dependency tickets for the computation. Defaults to None, in which case the default is to assume dependency on either (inputs) if requires_inputs is True, or (nothing) otherwise.

None

Returns:

Name Type Description
int int

The index of the callback in system.callbacks. The cache index can be recovered from system.callbacks[callback_index].cache_index.

Source code in collimator/framework/leaf_system.py
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
def declare_cache(
    self,
    callback: Callable,
    period: float | Parameter = None,
    offset: float | Parameter = 0.0,
    name: str = None,
    prerequisites_of_calc: List[DependencyTicket] = None,
    default_value: Array | Parameter = None,
    requires_inputs: bool = True,
) -> int:
    """Declare a stored computation for the system.

    This method accepts a callback function with the block-level signature
        `callback(time, state, *inputs, **parameters) -> value`
    and wraps it to have the signature
        `callback(context) -> value`

    This callback can optionally be used to define a periodic update event that
    refreshes the cached value.  Other calculations (e.g. sample-and-hold output
    ports) can then depend on the cached value.

    Args:
        callback (Callable):
            The callback function defining the cached computation.
        period (float | Parameter, optional):
            If not None, the callback function will be used to define a periodic
            update event that refreshes the value. Defaults to None.
        offset (float | Parameter, optional):
            The offset of the periodic update event. Defaults to 0.0.  Will be ignored
            unless `period` is not None.
        name (str, optional):
            The name of the cached value. Defaults to None, in which case the
            name `cache_<cache_index>` is used.
        default_value (Array | Parameter, optional):
            The default value of the result, if known. A Parameter is resolved to
            its current value. Defaults to None.
        requires_inputs (bool, optional):
            If True, the callback will eval input ports to gather input values.
            This will add a bit to compile time, so setting to False where possible
            is recommended. Defaults to True.
        prerequisites_of_calc (List[DependencyTicket], optional):
            The dependency tickets for the computation. Defaults to None, in which
            case the default is to assume dependency on either (inputs) if
            `requires_inputs` is True, or (nothing) otherwise.

    Returns:
        int: The index of the callback in `system.callbacks`.  The cache index can
            be recovered from `system.callbacks[callback_index].cache_index`.
    """

    # Resolve Parameter-valued arguments to concrete values, the same way
    # `declare_output_port` does, so a Parameter never ends up in the cache.
    if isinstance(default_value, Parameter):
        default_value = default_value.get()
    if isinstance(period, Parameter):
        period = period.get()
    if isinstance(offset, Parameter):
        offset = offset.get()

    # The index in the list of system callbacks
    callback_index = len(self.callbacks)

    # This is the index that this cached value will have in state.cache
    cache_index = len(self._default_cache)
    self._default_cache.append(default_value)

    # To help avoid unnecessary flagging of algebraic loops, trim the inputs as a
    # default prereq if the update callback doesn't use them
    if prerequisites_of_calc is None:
        if requires_inputs:
            prerequisites_of_calc = [DependencyTicket.u]
        else:
            prerequisites_of_calc = [DependencyTicket.nothing]

    def _update_callback(
        time: Scalar, state: LeafState, *inputs, **parameters
    ) -> LeafState:
        # Compute the value and return a new state with it stored in the cache slot.
        output = callback(time, state, *inputs, **parameters)
        return state.with_cached_value(cache_index, output)

    _update_callback = self.wrap_callback(
        _update_callback, collect_inputs=requires_inputs
    )

    if period is None:
        event = None

    else:
        # The cache has a periodic event updating its value defined by the callback.
        # NOTE(review): the trailing "_" in this event name differs from the
        # "<name>:output_<i>" pattern used for sample-and-hold ports; kept as-is
        # in case anything matches on event names.
        event = DiscreteUpdateEvent(
            system_id=self.system_id,
            event_data=PeriodicEventData(
                period=period, offset=offset, active=False
            ),
            name=f"{self.name}:cache_update_{cache_index}_",
            callback=_update_callback,
            passthrough=self._passthrough,
        )

    if name is None:
        name = f"cache_{cache_index}"

    sys_callback = SystemCallback(
        callback=_update_callback,
        system=self,
        callback_index=callback_index,
        name=name,
        prerequisites_of_calc=prerequisites_of_calc,
        event=event,
        default_value=default_value,
        cache_index=cache_index,
    )
    self.callbacks.append(sys_callback)

    return callback_index

declare_configuration_parameters(**params)

Declare a set of "configuration" parameters for the system.

These parameters are non-numeric parameters used for block configuration. Their declaration as parameters rather than object attributes is mainly for the purpose of serialization - blocks that take boolean or string parameters can register them as configuration parameters and they will be properly serialized.

The args should be a dict of name-value pairs, where the values are either strings, bool, arrays, or Parameters.

Typical usage:

class MyBlock(LeafSystem):
    def __init__(self, param1=True, param2=1.0):
        super().__init__()
        self.declare_configuration_parameters(param1=param1, param2=param2)
Source code in collimator/framework/leaf_system.py
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
def declare_configuration_parameters(self, **params):
    """Register non-numeric "configuration" parameters for the system.

    These parameters configure the block (e.g. boolean or string options).
    Declaring them as parameters rather than plain attributes is mainly for
    serialization: blocks that take boolean or string parameters can register
    them here and they will be serialized properly.

    Each keyword argument is stored as a `Parameter` under its own name.
    Values may be strings, bools, arrays, or Parameters; plain Python lists are
    converted to numpy arrays first.

    Typical usage:

    ```python
    class MyBlock(LeafSystem):
        def __init__(self, param1=True, param2=1.0):
            super().__init__()
            self.declare_configuration_parameters(param1=param1, param2=param2)
    ```
    """
    for param_name, raw_value in params.items():
        value = np.array(raw_value) if isinstance(raw_value, list) else raw_value
        self._instance_parameters[param_name] = Parameter(value=value)

declare_continuous_state(shape=None, default_value=None, dtype=None, ode=None, as_array=True, requires_inputs=True, prerequisites_of_calc=None)

Declare a continuous state component for the system.

The ode callback computes the time derivative of the continuous state based on the current time, state, and any additional inputs. If ode is not provided, a default zero vector of the same size as the continuous state is used. If provided, the ode callback should have the signature ode(time, state, *inputs, **params) -> xcdot.

Parameters:

Name Type Description Default
shape ShapeLike

The shape of the continuous state vector. Defaults to None.

None
default_value Array

The initial value of the continuous state vector. Defaults to None.

None
dtype DTypeLike

The data type of the continuous state vector. Defaults to None.

None
ode Callable

The callback for computing the time derivative of the continuous state. Should have the signature: ode(time, state, *inputs, **parameters) -> xcdot. Defaults to None.

None
as_array bool

If True, treat the default_value as an array-like (cast if necessary). Otherwise, it will be stored as the default state without modification.

True
requires_inputs bool

If True, indicates that the ODE computation requires inputs.

True
prerequisites_of_calc List[DependencyTicket]

The dependency tickets for the ODE computation. Defaults to None, in which case the assumption is a dependency on either (time, continuous state) if requires_inputs is False, or (time, continuous state, inputs) otherwise.

None

Raises:

Type Description
AssertionError

If neither shape nor default_value is provided.

Notes

(1) Only one of shape and default_value should be provided. If default_value is provided, it will be used as the initial value of the continuous state. If shape is provided, the initial value will be a zero vector of the given shape and specified dtype.

Source code in collimator/framework/leaf_system.py
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
def declare_continuous_state(
    self,
    shape: ShapeLike = None,
    default_value: Array = None,
    dtype: DTypeLike = None,
    ode: Callable = None,
    as_array: bool = True,
    requires_inputs: bool = True,
    prerequisites_of_calc: List[DependencyTicket] = None,
):
    """Declare a continuous state component for the system.

    The `ode` callback computes the time derivative of the continuous state from
    the current time, state, and any inputs, with the signature
    `ode(time, state, *inputs, **params) -> xcdot`.  If `ode` is omitted, the
    derivative is a zero vector shaped like the continuous state.

    The wrapped ODE is registered as a system callback, so its dependencies are
    tracked in the dependency graph.  It also becomes this system's
    `eval_time_derivatives`.

    Args:
        shape (ShapeLike, optional):
            The shape of the continuous state vector. Defaults to None.
        default_value (Array, optional):
            The initial value of the continuous state vector. Defaults to None.
        dtype (DTypeLike, optional):
            The data type of the continuous state vector. Defaults to None.
        ode (Callable, optional):
            The callback for computing the time derivative of the continuous state.
            Should have the signature:
                `ode(time, state, *inputs, **parameters) -> xcdot`.
            Defaults to None.
        as_array (bool, optional):
            If True, treat the default_value as an array-like (cast if necessary).
            Otherwise, it will be stored as the default state without modification.
            A custom `ode` is required when this is False.
        requires_inputs (bool, optional):
            If True, indicates that the ODE computation requires inputs.
        prerequisites_of_calc (List[DependencyTicket], optional):
            The dependency tickets for the ODE computation. Defaults to None, in
            which case the assumption is a dependency on either (time, continuous
            state) if `requires_inputs` is False, or (time, continuous state,
            inputs) otherwise.

    Raises:
        AssertionError:
            If neither shape nor default_value is provided, or if `ode` is None
            while `as_array` is False.

    Notes:
        (1) Only one of `shape` and `default_value` should be provided. If `default_value`
        is provided, it will be used as the initial value of the continuous state. If
        `shape` is provided, the initial value will be a zero vector of the given shape
        and specified dtype.
    """
    if prerequisites_of_calc is None:
        base_tickets = [DependencyTicket.time, DependencyTicket.xc]
        prerequisites_of_calc = (
            base_tickets + [DependencyTicket.u] if requires_inputs else base_tickets
        )

    if as_array:
        default_value = utils.make_array(default_value, dtype=dtype, shape=shape)

    logger.debug(f"In block {self.name} [{self.system_id}]: {default_value=}")

    # The default value may be a list, tuple or other PyTree; convert each of its
    # leaves to an array.
    x0 = tree_util.tree_map(cnp.asarray, default_value)
    self._default_continuous_state = x0

    if ode is None:
        # Without a user ODE the state is constant (zero derivative).  zeros_like
        # cannot handle structured states such as named tuples, so those require
        # a custom ODE.
        assert as_array, "Must provide custom ODE for non-array continuous state"

        def ode(time, state, *inputs, **parameters):
            return cnp.zeros_like(x0)

    ode_callback = SystemCallback(
        callback=self.wrap_callback(ode),
        system=self,
        callback_index=len(self.callbacks),
        name=f"{self.name}_ode",
        prerequisites_of_calc=prerequisites_of_calc,
    )
    self.ode_callback = ode_callback
    self.callbacks.append(ode_callback)

    # Route time-derivative evaluation through the registered ODE callback.
    self.eval_time_derivatives = ode_callback.eval

declare_continuous_state_output(name=None)

Declare a continuous state output port in the system.

This method creates a new block-level output port which returns the full continuous state of the system.

Parameters:

Name Type Description Default
name str

The name of the output port. Defaults to None (autogenerate name).

None

Returns:

Name Type Description
int int

The index of the new output port.

Source code in collimator/framework/leaf_system.py
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
def declare_continuous_state_output(
    self,
    name: str = None,
) -> int:
    """Create an output port that exposes the system's entire continuous state.

    The port is block-level and simply forwards `state.continuous_state`, so it
    depends only on the continuous state and never on the input ports.

    Args:
        name (str, optional):
            Name for the new port. If None (the default), the name is
            autogenerated.

    Returns:
        int: Index of the newly declared output port.
    """

    def _emit_continuous_state(time: Scalar, state: LeafState, *inputs, **parameters):
        # Only the continuous state is read; time, inputs and parameters are unused.
        return state.continuous_state

    port_index = self.declare_output_port(
        _emit_continuous_state,
        name=name,
        prerequisites_of_calc=[DependencyTicket.xc],
        default_value=self._default_continuous_state,
        requires_inputs=False,
    )
    return port_index

declare_discrete_state(shape=None, default_value=None, dtype=None, as_array=True)

Declare a new discrete state component for the system.

The discrete state is a component of the system's state that can be updated at specific events, such as zero-crossings or periodic updates. The declared value becomes the system's default discrete state; calling this method again replaces any previously declared default. To hold several components, pass a structured (PyTree) value such as a tuple of arrays as default_value (typically with as_array=False).

The declared discrete state is initialized with either the provided default value or zeros of the correct shape and dtype.

Parameters:

Name Type Description Default
shape ShapeLike

The shape of the discrete state. Defaults to None.

None
default_value Array

The initial value of the discrete state. Defaults to None.

None
dtype DTypeLike

The data type of the discrete state. Defaults to None.

None
as_array bool

If True, treat the default_value as an array-like (cast if necessary). Otherwise, it will be stored as the default state without modification.

True

Raises:

Type Description
AssertionError

If as_array is True and neither shape nor default_value is provided.

Notes

(1) Only one of shape and default_value should be provided. If default_value is provided, it will be used as the initial value of the discrete state. If shape is provided, the initial value will be a zero vector of the given shape and specified dtype.

(2) Use declare_periodic_update to declare an update event that modifies the discrete state at a recurring interval.

Source code in collimator/framework/leaf_system.py
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
def declare_discrete_state(
    self,
    shape: ShapeLike = None,
    default_value: Array | Parameter = None,
    dtype: DTypeLike = None,
    as_array: bool = True,
):
    """Declare the default discrete state for the system.

    The discrete state is a component of the system's state that can be updated
    at specific events, such as zero-crossings or periodic updates. The value
    declared here is stored as the system's default discrete state; calling
    this method again replaces any previously declared default. Structured
    (PyTree) discrete states, e.g. a tuple of arrays, can be declared by passing
    such a value as `default_value` (typically with `as_array=False`).

    The declared discrete state is initialized with either the provided default
    value or zeros of the correct shape and dtype.

    Args:
        shape (ShapeLike, optional):
            The shape of the discrete state. Defaults to None.
        default_value (Array, optional):
            The initial value of the discrete state. A `Parameter` is resolved
            to its current value first. Defaults to None.
        dtype (DTypeLike, optional):
            The data type of the discrete state. Defaults to None.
        as_array (bool, optional):
            If True, treat the default_value as an array-like (cast if necessary).
            Otherwise, it will be stored as the default state without modification.

    Raises:
        AssertionError:
            If as_array is True and neither shape nor default_value is provided.

    Notes:
        (1) Only one of `shape` and `default_value` should be provided. If
        `default_value` is provided, it will be used as the initial value of the
        discrete state. If `shape` is provided, the initial value will be a
        zero vector of the given shape and specified dtype.

        (2) Use `declare_periodic_update` to declare an update event that
        modifies the discrete state at a recurring interval.
    """

    # Resolve a Parameter to its concrete value before any casting.
    if isinstance(default_value, Parameter):
        default_value = default_value.get()

    if as_array:
        default_value = utils.make_array(
            default_value, dtype=dtype, shape=shape
        )

    # Tree-map the default value to ensure that it is an array-like with the
    # correct shape and dtype. This is necessary because the default value
    # may be a list, tuple, or other PyTree-structured object.
    default_value = tree_util.tree_map(cnp.asarray, default_value)

    # Overwrites any previously declared default discrete state.
    self._default_discrete_state = default_value

declare_mode_output(name=None)

Declare a mode output port in the system.

This method creates a new block-level output port which returns the component of the system's state corresponding to the discrete "mode" or "stage".

Parameters:

Name Type Description Default
name str

The name of the output port. Defaults to None.

None

Returns:

Name Type Description
int int

The index of the declared mode output port.

Source code in collimator/framework/leaf_system.py
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
def declare_mode_output(self, name: str = None) -> int:
    """Create an output port that reports the system's discrete mode.

    The new block-level port returns `state.mode`, i.e. the component of the
    state that tracks the discrete "mode" or "stage". It depends only on the
    mode and does not read any input ports.

    Args:
        name (str, optional):
            Name for the new port. Defaults to None.

    Returns:
        int:
            Index of the newly declared mode output port.
    """

    def _emit_mode(time: Scalar, state: LeafState, *inputs, **parameters):
        # Forward the current mode unchanged.
        return state.mode

    port_index = self.declare_output_port(
        _emit_mode,
        name=name,
        prerequisites_of_calc=[DependencyTicket.mode],
        default_value=self._default_mode,
        requires_inputs=False,
    )
    return port_index

declare_output_port(callback, period=None, offset=0.0, name=None, prerequisites_of_calc=None, default_value=None, requires_inputs=True)

Declare an output port in the LeafSystem.

This method accepts a callback function with the block-level signature callback(time, state, *inputs, **parameters) -> value and wraps it to the signature expected by SystemBase.declare_output_port: callback(context) -> value

Parameters:

Name Type Description Default
callback Callable

The callback function defining the output port.

required
period float

If not None, the port will act as a "sample-and-hold", with the callback function used to define a periodic update event that refreshes the value that will be returned by the port. Typically this should match the update period of some associated update event in the system. Defaults to None.

None
offset float

The offset of the periodic update event. Defaults to 0.0. Ignored when period is None.

0.0
name str

The name of the output port. Defaults to None.

None
default_value Array

The default value of the output port, if known. Defaults to None.

None
requires_inputs bool

If True, the callback will eval input ports to gather input values. This will add a bit to compile time, so setting to False where possible is recommended. Defaults to True.

True
prerequisites_of_calc List[DependencyTicket]

The dependency tickets for the output port computation. Defaults to None, in which case the port is assumed to depend on the inputs if requires_inputs is True, and on nothing otherwise.

None

Returns:

Name Type Description
int int

The index of the declared output port.

Source code in collimator/framework/leaf_system.py
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
def declare_output_port(
    self,
    callback: Callable,
    period: float | Parameter = None,
    offset: float | Parameter = 0.0,
    name: str = None,
    prerequisites_of_calc: List[DependencyTicket] = None,
    default_value: Array | Parameter = None,
    requires_inputs: bool = True,
) -> int:
    """Declare an output port on this LeafSystem.

    `callback` uses the block-level signature
        `callback(time, state, *inputs, **parameters) -> value`
    and is adapted here to the context-based signature that
    SystemBase.declare_output_port expects:
        `callback(context) -> value`

    Args:
        callback (Callable):
            Block-level function that computes the port value.
        period (float, optional):
            When not None, the port acts as a "sample-and-hold": `callback` is
            run by a periodic update event whose result is stored in the state
            cache, and the port returns that stored value. This should usually
            match the period of a related update event in the system.
            Defaults to None.
        offset (float, optional):
            Offset of that periodic update event. Defaults to 0.0. Ignored when
            `period` is None.
        name (str, optional):
            Name of the output port. Defaults to None.
        default_value (Array, optional):
            Default value of the port, if known. Defaults to None.
        requires_inputs (bool, optional):
            Whether the callback needs the input port values gathered for it.
            Gathering inputs adds to compile time, so pass False where possible.
            Defaults to True.
        prerequisites_of_calc (List[DependencyTicket], optional):
            Dependency tickets for computing the port value. When None, the
            port is assumed to depend on the inputs if `requires_inputs` is True
            and on nothing otherwise.

    Returns:
        int: Index of the newly declared output port.
    """

    # Resolve Parameter-valued arguments to their current concrete values.
    default_value, period, offset = (
        arg.get() if isinstance(arg, Parameter) else arg
        for arg in (default_value, period, offset)
    )

    if default_value is not None:
        default_value = cnp.array(default_value)

    # Only claim an input dependency when the callback actually reads inputs;
    # this keeps algebraic-loop detection from raising false alarms.
    if prerequisites_of_calc is None:
        prerequisites_of_calc = [
            DependencyTicket.u if requires_inputs else DependencyTicket.nothing
        ]

    event = None
    cache_index = None

    if period is None:
        port_callback = self.wrap_callback(callback, collect_inputs=requires_inputs)
    else:
        # Sample-and-hold: a periodic event evaluates `callback` and stores the
        # result in state.cache; the port itself only reads the stored value.
        # The slot this port owns in state.cache:
        cache_index = len(self._default_cache)

        def port_callback(context: ContextBase) -> Array:
            held_state = context[self.system_id].state
            return held_state.cache[cache_index]

        def _sample(
            time: Scalar, state: LeafState, *inputs, **parameters
        ) -> LeafState:
            sampled = callback(time, state, *inputs, **parameters)
            return state.with_cached_value(cache_index, sampled)

        sample_callback = self.wrap_callback(_sample, collect_inputs=requires_inputs)

        event = DiscreteUpdateEvent(
            system_id=self.system_id,
            event_data=PeriodicEventData(
                period=period, offset=offset, active=False
            ),
            name=f"{self.name}:output_{cache_index}",
            callback=sample_callback,
            passthrough=self._passthrough,
        )

        self._default_cache.append(default_value)

        # Here `prerequisites_of_calc` describes the update event rather than
        # the cache-reading port callback; it still lets dependencies of the
        # update event be tracked through the output port.

    return super().declare_output_port(
        port_callback,
        name=name,
        prerequisites_of_calc=prerequisites_of_calc,
        default_value=default_value,
        event=event,
        cache_index=cache_index,
    )

declare_periodic_update(callback, period, offset, enable_tracing=None)

Declare a periodic discrete update event.

The event will be triggered at regular intervals defined by the period and offset parameters. The callback should have the signature callback(time, state, *inputs, **params) -> xd_plus, where xd_plus is the updated value of the discrete state.

This callback should be written to compute the "plus" value of the discrete state component given the "minus" values of all state components and inputs.

Parameters:

Name Type Description Default
callback Callable

The callback function defining the update.

required
period Scalar

The period at which the update event occurs.

required
offset Scalar

The offset at which the first occurrence of the event is triggered.

required
enable_tracing bool

If True, enable tracing for this event. Defaults to None.

None
Source code in collimator/framework/leaf_system.py
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
def declare_periodic_update(
    self,
    callback: Callable,
    period: Scalar | Parameter,
    offset: Scalar | Parameter,
    enable_tracing: bool = None,
):
    """Register a discrete update event that fires on a fixed schedule.

    The event recurs every `period` time units, starting at `offset`.
    `callback` has the block-level signature
    `callback(time, state, *inputs, **params) -> xd_plus` and must return the
    new ("plus") discrete state computed from the "minus" values of all state
    components and inputs. The returned value replaces the discrete state.

    Args:
        callback (Callable):
            Function computing the updated discrete state.
        period (Scalar):
            Time between consecutive occurrences of the event.
        offset (Scalar):
            Time of the first occurrence of the event.
        enable_tracing (bool, optional):
            Whether tracing is enabled for this event. None (the default) is
            treated as True.
    """

    # Resolve Parameter-valued timing arguments to concrete values.
    period, offset = (
        value.get() if isinstance(value, Parameter) else value
        for value in (period, offset)
    )

    compute_xd = self.wrap_callback(callback)

    def _apply_update(context: ContextBase) -> LeafState:
        # Evaluate the user update, then swap the result into this system's state.
        new_xd = compute_xd(context)
        own_state = context[self.system_id].state
        return own_state.with_discrete_state(new_xd)

    tracing = True if enable_tracing is None else enable_tracing

    update_event = DiscreteUpdateEvent(
        system_id=self.system_id,
        name=f"{self.name}:periodic_update",
        event_data=PeriodicEventData(period=period, offset=offset, active=False),
        callback=_apply_update,
        passthrough=self._passthrough,
        enable_tracing=tracing,
    )
    self._state_update_events.append(update_event)

declare_zero_crossing(guard, reset_map=None, start_mode=None, end_mode=None, direction='crosses_zero', terminal=False, name=None, enable_tracing=None)

Declare an event triggered by a zero-crossing of a guard function.

Optionally, the system can also transition between discrete modes. If start_mode and end_mode are specified, the system will transition from start_mode to end_mode when the event is triggered according to guard. This event will be active conditionally on state.mode == start_mode and when triggered will result in applying the reset map. In addition, the mode will be updated to end_mode.

If start_mode and end_mode are not specified, the event will always be active and will not result in a mode transition.

The guard function should have the signature

guard(time, state, *inputs, **parameters) -> float

and the reset map should have the signature of an unrestricted update

reset_map(time, state, *inputs, **parameters) -> state

Parameters:

Name Type Description Default
guard Callable

The guard function which triggers updates on zero crossing.

required
reset_map Callable

The reset map which is applied when the event is triggered. If None (default), no reset is applied.

None
start_mode int

The mode or stage of the system in which the guard will be actively monitored. If None (default), the event will always be active.

None
end_mode int

The mode or stage of the system to which the system will transition when the event is triggered. It must be specified together with start_mode, though it can be the same as start_mode; specifying only one of the two raises an AssertionError.

None
direction str

The direction of the zero crossing. Options are "crosses_zero" (default), "positive_then_non_positive", "negative_then_non_negative", and "edge_detection". All except edge detection operate on continuous signals; edge detection operates on boolean signals and looks for a jump from False to True or vice versa.

'crosses_zero'
terminal bool

If True, the event will halt simulation if and when the zero-crossing occurs. If this event is triggered the reset map will still be applied as usual prior to termination. Defaults to False.

False
name str

The name of the event. Defaults to None.

None
enable_tracing bool

If True, enable tracing for this event. Defaults to None.

None
Notes

By default the system state does not have a "mode" component, so in order to declare "state transitions" with non-null start and end modes, the user must first call declare_default_mode to set the default mode to be some integer (initial condition for the system).

Source code in collimator/framework/leaf_system.py
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
def declare_zero_crossing(
    self,
    guard: Callable,
    reset_map: Callable = None,
    start_mode: int = None,
    end_mode: int = None,
    direction: str = "crosses_zero",
    terminal: bool = False,
    name: str = None,
    enable_tracing: bool = None,
):
    """Declare an event triggered by a zero-crossing of a guard function.

    Optionally, the system can also transition between discrete modes.
    If `start_mode` and `end_mode` are specified, the system will transition
    from `start_mode` to `end_mode` when the event is triggered according to `guard`.
    This event will be active conditionally on `state.mode == start_mode` and when
    triggered will result in applying the reset map. In addition, the mode will be
    updated to `end_mode`.

    If `start_mode` and `end_mode` are not specified, the event will always be active
    and will not result in a mode transition.

    The guard function should have the signature:
        `guard(time, state, *inputs, **parameters) -> float`

    and the reset map should have the signature of an unrestricted update:
        `reset_map(time, state, *inputs, **parameters) -> state`

    Args:
        guard (Callable):
            The guard function which triggers updates on zero crossing.
        reset_map (Callable, optional):
            The reset map which is applied when the event is triggered. If None
            (default), no reset is applied.
        start_mode (int, optional):
            The mode or stage of the system in which the guard will be
            actively monitored. If None (default), the event will always be
            active.
        end_mode (int, optional):
            The mode or stage of the system to which the system will transition
            when the event is triggered. Must be given together with
            `start_mode` (it may equal `start_mode`); if only one of the two is
            specified, an AssertionError is raised.
        direction (str, optional):
            The direction of the zero crossing. Options are "crosses_zero"
            (default), "positive_then_non_positive", "negative_then_non_negative",
            and "edge_detection".  All except edge detection operate on continuous
            signals; edge detection operates on boolean signals and looks for a
            jump from False to True or vice versa.
        terminal (bool, optional):
            If True, the event will halt simulation if and when the zero-crossing
            occurs. If this event is triggered the reset map will still be applied
            as usual prior to termination. Defaults to False.
        name (str, optional):
            The name of the event. Defaults to None.
        enable_tracing (bool, optional):
            If True, enable tracing for this event. None (default) is treated
            as True.

    Raises:
        AssertionError:
            If either mode is given but the system has no default mode, or if
            `start_mode` and `end_mode` are not both integers.

    Notes:
        By default the system state does not have a "mode" component, so in
        order to declare "state transitions" with non-null start and end modes,
        the user must first call `declare_default_mode` to set the default mode
        to be some integer (initial condition for the system).
    """

    logger.debug(
        f"Declaring transition for {self.name} with guard {guard} and reset map {reset_map}"
    )

    if enable_tracing is None:
        enable_tracing = True

    if start_mode is not None or end_mode is not None:
        assert (
            self._default_mode is not None
        ), "System has no mode: call `declare_default_mode` before transitions."
        assert isinstance(start_mode, int) and isinstance(
            end_mode, int
        ), "`start_mode` and `end_mode` must both be integers when either is given."

    # Wrap the reset map with a mode update if necessary
    def _reset_and_update_mode(
        time: Scalar, state: LeafState, *inputs, **parameters
    ) -> LeafState:
        if reset_map is not None:
            state = reset_map(time, state, *inputs, **parameters)

        # If the start and end modes are declared, update the mode. The log is
        # emitted only here, where a mode change actually happens.
        if start_mode is not None:
            logger.debug(f"Updating mode from {state.mode} to {end_mode}")
            state = state.with_mode(end_mode)

        return state

    _wrapped_guard = self.wrap_callback(guard)
    _wrapped_reset = _wrap_reset_map(
        self, _reset_and_update_mode, _wrapped_guard, terminal
    )

    event = ZeroCrossingEvent(
        system_id=self.system_id,
        guard=_wrapped_guard,
        reset_map=_wrapped_reset,
        passthrough=self._passthrough,
        direction=direction,
        is_terminal=terminal,
        name=name,
        event_data=ZeroCrossingEventData(active=True, triggered=False),
        enable_tracing=enable_tracing,
        active_mode=start_mode,
    )

    event_index = len(self._zero_crossing_events)
    self._zero_crossing_events.append(event)

    # Record the transition in the transition map (for debugging or analysis)
    if start_mode is not None:
        if start_mode not in self.transition_map:
            self.transition_map[start_mode] = []
        self.transition_map[start_mode].append((event_index, event))

wrap_callback(callback, collect_inputs=True)

Wrap an update function to unpack local variables and block inputs.

The callback should have the signature callback(time, state, *inputs, **params) -> result and will be wrapped to have the signature callback(context) -> result, as expected by the event handling logic.

This is used internally for declaration methods like declare_periodic_update so that users can write more intuitive block-level update functions without worrying about the "context", and have them automatically wrapped to have the right interface. It can also be called directly by users to wrap their own update functions, for example to create a callback function for declare_output_port.

The context and state are strictly immutable, so the callback should not attempt to change any values in the context or state. Even in cases where it is impossible to enforce this (e.g. a state component is a list, which is always mutable in Python), the callback should be careful to avoid direct modification of the context or state, which may lead to unexpected behavior or JAX tracer errors.

Parameters:

Name Type Description Default
callback Callable

The (pure) function to be wrapped. See above for expected signature.

required
collect_inputs bool

If True, the callback will eval input ports to gather input values. Normally this should be True, but it can be set to False if the return value depends only on the state but not inputs, for instance. This helps reduce the number of expressions that need to be JIT compiled. Default is True.

True

Returns:

Name Type Description
Callable Callable

The wrapped function, with signature callback(context) -> result.

Source code in collimator/framework/leaf_system.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def wrap_callback(
    self, callback: Callable, collect_inputs: bool = True
) -> Callable:
    """Adapt a block-level function to the context-based callback interface.

    `callback` is written as `callback(time, state, *inputs, **params) -> result`.
    The returned function has the signature `callback(context) -> result`, which
    is what the event handling logic expects. It pulls the time, this system's
    state and parameters, and optionally the input values, out of the context.

    Declaration helpers such as `declare_periodic_update` use this internally
    so that users never have to handle the "context" directly. Users may also
    call it themselves, for example to build a callback for
    `declare_output_port`.

    The context and state are immutable, and `callback` must not try to modify
    them. Even where that cannot be enforced (e.g. a state component that is a
    Python list), in-place modification can cause unexpected behavior or JAX
    tracer errors.

    Args:
        callback (Callable):
            Pure function to adapt; see above for its expected signature.
        collect_inputs (bool):
            When True (the default), the input ports are evaluated and their
            values passed positionally after `state`. Pass False when the
            result depends only on the state and not on the inputs; this
            reduces the number of expressions that must be JIT compiled.

    Returns:
        Callable:
            A function with signature `callback(context) -> result`.
    """

    def _context_callback(context: ContextBase) -> LeafStateComponent:
        block_inputs = self.collect_inputs(context) if collect_inputs else ()
        local_ctx: LeafContext = context[self.system_id]
        return callback(
            context.time, local_ctx.state, *block_inputs, **local_ctx.parameters
        )

    return _context_callback

Parameter dataclass

Source code in collimator/framework/parameter.py
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
@dataclasses.dataclass
class Parameter:
    """A named or anonymous value whose current value is tracked by ParameterCache.

    `get()`/`set()` delegate to `ParameterCache`. Unary operators return a new
    `Parameter` wrapping a `ParameterExpr`, and binary arithmetic and comparison
    operators delegate to `_op`, so expressions built from parameters stay
    symbolic. Note that `==`/`!=` are overloaded this way too and do not return
    plain booleans. Hashing is identity-based so that instances can serve as keys
    in `ParameterCache.__dependents__`.
    """

    # Raw value: a literal/array, another Parameter, a ParameterExpr, a string,
    # or a tuple.
    value: Union[ParameterExpr, "Parameter", ArrayLike, str, tuple]
    # Optional dtype; when set it is applied in __int__ and __float__.
    dtype: DTypeLike = None
    shape: ShapeLike = None

    # name is used by reference submodels, model parameters and init script
    # variables so that they can be referred to in other fields
    # (we need this for serialization).
    name: str = None

    # For complex parameter values, we can specify a Python expression as string
    # This is useful for expressions like "np.eye(p)" where p is a parameter.
    is_python_expr: bool = False
    globals_: dict = None
    locals_: dict = None

    def get(self):
        """Return this parameter's value as provided by `ParameterCache.get`."""
        return ParameterCache.get(self)

    def set(self, value: Union["Parameter", ArrayLike, str, tuple]):
        """Replace this parameter's value via `ParameterCache.replace`."""
        ParameterCache.replace(self, value)

    def __post_init__(self):
        # Register this parameter as a dependent of every Parameter it refers
        # to: directly, inside a ParameterExpr, or inside a list/tuple.
        if isinstance(self.value, Parameter):
            ParameterCache.add_dependent(self.value, self)
        if isinstance(self.value, ParameterExpr):
            for val in self.value:
                if isinstance(val, Parameter):
                    ParameterCache.add_dependent(val, self)
        if isinstance(self.value, (list, tuple)):
            _add_dependents(self.value, self)

        # Start with no dependents of our own.
        ParameterCache.__dependents__[self] = set()

    def __add__(self, other):
        return _op(Ops.ADD, self, other)

    def __radd__(self, other):
        return _op(Ops.ADD, other, self)

    def __sub__(self, other):
        return _op(Ops.SUB, self, other)

    def __rsub__(self, other):
        return _op(Ops.SUB, other, self)

    def __mul__(self, other):
        return _op(Ops.MUL, self, other)

    def __rmul__(self, other):
        return _op(Ops.MUL, other, self)

    def __truediv__(self, other):
        return _op(Ops.DIV, self, other)

    def __rtruediv__(self, other):
        return _op(Ops.DIV, other, self)

    def __floordiv__(self, other):
        return _op(Ops.FLOORDIV, self, other)

    def __rfloordiv__(self, other):
        return _op(Ops.FLOORDIV, other, self)

    def __mod__(self, other):
        return _op(Ops.MOD, self, other)

    def __rmod__(self, other):
        return _op(Ops.MOD, other, self)

    def __pow__(self, other):
        return _op(Ops.POW, self, other)

    def __rpow__(self, other):
        return _op(Ops.POW, other, self)

    def __neg__(self):
        p = Parameter(value=ParameterExpr([Ops.NEG, self]))
        ParameterCache.add_dependent(self, p)
        return p

    def __pos__(self):
        p = Parameter(value=ParameterExpr([Ops.POS, self]))
        ParameterCache.add_dependent(self, p)
        return p

    def __abs__(self):
        p = Parameter(value=ParameterExpr([Ops.ABS, self]))
        ParameterCache.add_dependent(self, p)
        return p

    # Comparisons build symbolic expressions rather than returning bools.
    def __eq__(self, other):
        return _op(Ops.EQ, self, other)

    def __ne__(self, other):
        return _op(Ops.NE, self, other)

    def __lt__(self, other):
        return _op(Ops.LT, self, other)

    def __le__(self, other):
        return _op(Ops.LE, self, other)

    def __gt__(self, other):
        return _op(Ops.GT, self, other)

    def __ge__(self, other):
        return _op(Ops.GE, self, other)

    def __del__(self):
        ParameterCache.remove(self)

    def __hash__(self):
        # Identity hash: __eq__ is symbolic, so value-based hashing is not possible.
        return id(self)

    def __str__(self):

        if self.name is not None:
            return self.name

        return self.value_as_str()

    def __int__(self):
        value = self.get()
        if self.dtype is not None:
            # Apply the declared dtype first, then coerce to a builtin int:
            # Python requires __int__ to return an int, and NumPy/JAX scalars
            # such as np.int64 are not int instances.
            value = self.dtype(value)
        return int(value)

    def __float__(self):
        value = self.get()
        if self.dtype is not None:
            # Same reasoning as __int__: __float__ must return a builtin float,
            # which e.g. np.int32 or np.float32 scalars are not.
            value = self.dtype(value)
        return float(value)

    # FIXME: this is not working as expected - it will break some tests
    # def __bool__(self):
    #     return bool(self.get())

    def __complex__(self):
        return complex(self.get())

    def value_as_str(self):
        """Used for serialization of parameters. This string is displayed in the
        UI text fields for the parameters."""
        try:
            return _value_as_str(self.value)
        except ValueError as e:
            raise ParameterError(
                self, message=f"Invalid parameter value: {self.value}") from e

value_as_str()

Used for serialization of parameters. This string is displayed in the UI text fields for the parameters.

Source code in collimator/framework/parameter.py
546
547
548
549
550
551
552
553
def value_as_str(self):
    """Serialize the parameter value to a string.

    The resulting string is what the UI shows in the parameter's text field.

    Raises:
        ParameterError: if ``_value_as_str`` rejects the value with a
            ``ValueError``; that error is chained as the cause.
    """
    try:
        serialized = _value_as_str(self.value)
    except ValueError as err:
        raise ParameterError(
            self, message=f"Invalid parameter value: {self.value}"
        ) from err
    return serialized

ShapeMismatchError

Bases: StaticError

Block parameters or input/outputs have mismatched shapes.

Source code in collimator/framework/error.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
class ShapeMismatchError(StaticError):
    """Block parameters or input/outputs have mismatched shapes.

    Args:
        expected_shape: the shape that was required, if known.
        actual_shape: the shape that was found, if known.
        **kwargs: forwarded unchanged to the parent error's ``__init__``.
    """

    def __init__(self, expected_shape=None, actual_shape=None, **kwargs):
        super().__init__(**kwargs)
        self.expected_shape = expected_shape
        self.actual_shape = actual_shape

    def __str__(self):
        # Test against None explicitly: the scalar shape `()` is falsy but is a
        # meaningful shape to report, and array-valued shapes have no
        # unambiguous truth value.
        if self.expected_shape is not None or self.actual_shape is not None:
            return (
                f"Shape mismatch: expected {self.expected_shape}, "
                f"got {self.actual_shape}" + self._context_info()
            )
        return f"Shape mismatch{self._context_info()}"

StaticError

Bases: CollimatorError

Wraps a Python exception to record the offending block id. The original exception is found in the `__cause__` attribute.

See collimator.framework.context_factory._check_types for use.

This is called 'static' (as opposed to say 'runtime') meaning this is for wrapping errors detected prior to running a simulation.

Source code in collimator/framework/error.py
161
162
163
164
165
166
167
168
169
170
class StaticError(CollimatorError):
    """Error detected before a simulation is run (as opposed to at runtime).

    Wraps a Python exception to record the offending block id; the original
    exception is available through the standard ``__cause__`` attribute.

    See collimator.framework.context_factory._check_types for use.
    """

SystemBase dataclass

Basic building block for simulation in collimator.

NOTE: Type hints in SystemBase indicate the union between what would be returned by a LeafSystem and a Diagram. See type hints of the subclasses for the specific argument and return types.

Source code in collimator/framework/system_base.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
@dataclasses.dataclass
class SystemBase:
    """Basic building block for simulation in collimator.

    NOTE: Type hints in SystemBase indicate the union between what would be returned
    by a LeafSystem and a Diagram. See type hints of the subclasses for the specific
    argument and return types.
    """

    # Generated unique ID for this system
    system_id: Hashable = dataclasses.field(default_factory=next_system_id, init=False)
    name: str = None  # Human-readable name for this system (optional)
    ui_id: str = None  # UUID of the block when loaded from JSON (optional)

    # Immediate parent of the current system (can only be a Diagram).
    # If None, _this_ is the root system.
    parent: Diagram = None

    def __post_init__(self):
        # Fall back to a generated name built from the class name and system id.
        if self.name is None:
            self.name = f"{type(self).__name__}_{self.system_id}_"

        # Every "cache source" owned by this system: input ports, output ports,
        # time-derivative computations, and any custom cached data declared by
        # the user (the ModelicaFMU block is one example).
        self.callbacks: List[SystemCallback] = []

        # Positions in `self.callbacks` of each port, in port order. Input port
        # `i` is `self.callbacks[self.input_port_indices[i]]`; the `input_ports`
        # and `output_ports` properties wrap this lookup.
        self.input_port_indices: List[int] = []
        self.output_port_indices: List[int] = []

        # Characteristic time scale of the system. Subclasses may override it or
        # it may be assigned directly. For now only the simulator's zero-crossing
        # isolation makes use of it.
        self.characteristic_time = 1.0

        # Dependency graph mapping the prerequisites of each calculation. It stays
        # `None` until built; the `dependency_graph` property builds it on demand.
        self._dependency_graph: DependencyGraph = None

        # Attribute storage used when serializing the system.
        self._instance_parameters: dict[str, Parameter] = {}

        # Optional declared numeric parameters (shape, dtype and default value)
        # for a LeafSystem. They seed the context, so they double as initial
        # values unless overridden. Declaring them, instead of keeping them as
        # plain attributes, moves them into the context where JAX can trace them:
        # they can then be differentiated, vmapped, or modified without
        # recompiling the simulation.
        self._default_parameters: dict[str, Array] = {}

        # (input_port, output_port) pairs that are feedthrough. `None` means
        # "unknown"; unless set manually (here or by overriding
        # `get_feedthrough`), it is computed from the dependency graph during
        # algebraic-loop detection.
        self.feedthrough_pairs: List[Tuple[int, int]] = None

        # Lazily built, pre-sorted output update events. Created on first access
        # of the corresponding property; normally only needed for the root system.
        self._cache_update_events: EventCollection = None

    def __hash__(self) -> Hashable:
        """Hash the system by its generated `system_id`."""
        key = self.system_id
        return hash(key)

    def pprint(self, output=print) -> None:
        """Pretty-print the summary text produced by `_pprint_helper`.

        Args:
            output: callable that receives the stripped summary string; defaults
                to the built-in `print`.

        Returns:
            None. The text is emitted through `output`, not returned.
        """
        output(self._pprint_helper().strip())

    def _pprint_helper(self, prefix="") -> str:
        """Return this system's line of the `pprint` tree, preceded by `prefix`."""
        label = f"{self.name}(id={self.system_id})"
        return f"{prefix}|-- {label}\n"

    def post_simulation_finalize(self) -> None:
        """Hook for finalizing the system once a simulation has completed.

        The base implementation does nothing. It is only meant to be overridden
        by special blocks that must clean up resources, such as closing files.
        """
        return None

    @property
    def instance_parameters(self) -> dict[str, Parameter]:
        """Parameters stored on this instance for system serialization."""
        stored = self._instance_parameters
        return stored

    @instance_parameters.setter
    def instance_parameters(self, value):
        # Replace the stored mapping wholesale (no merging with the old one).
        self._instance_parameters = value

    #
    # Simulation interface
    #
    @property
    @abc.abstractmethod
    def has_feedthrough_side_effects(self) -> bool:
        """Check if the system includes any feedthrough calls to `io_callback`.

        Abstract; subclasses must provide it. (`property` + `abc.abstractmethod`
        replaces the deprecated `abc.abstractproperty`.)
        """
        # This is a tricky one to explain and is almost always False except for a
        # PythonScript block that is not JAX traced.  Basically, if the output of
        # the system is used as an ODE right-hand-side, will it fail in the case where
        # the ODE solver defines a custom VJP?  This happens in diffrax, so for example
        # if a PythonScript block is used to compute the ODE right-hand-side, it will
        # fail with "Effects not supported in `custom_vjp`"

    @property
    @abc.abstractmethod
    def has_ode_side_effects(self) -> bool:
        """Check if the ODE RHS for the system includes any calls to `io_callback`.

        Abstract; subclasses must provide it. (`property` + `abc.abstractmethod`
        replaces the deprecated `abc.abstractproperty`.)
        """
        # This flag indicates that the system `has_feedthrough_side_effects` AND that
        # signal is used as an ODE right-hand-side.  This is used to determine whether
        # a JAX ODE solver can be used to integrate the system.

    @property
    @abc.abstractmethod
    def has_continuous_state(self) -> bool:
        """Abstract flag: whether the system has continuous state.

        Implemented by subclasses. (`property` + `abc.abstractmethod` replaces
        the deprecated `abc.abstractproperty`.)
        """

    @property
    @abc.abstractmethod
    def has_discrete_state(self) -> bool:
        """Abstract flag: whether the system has discrete state.

        Implemented by subclasses. (`property` + `abc.abstractmethod` replaces
        the deprecated `abc.abstractproperty`.)
        """

    @property
    @abc.abstractmethod
    def has_zero_crossing_events(self) -> bool:
        """Abstract flag: whether the system declares zero-crossing events.

        Implemented by subclasses. (`property` + `abc.abstractmethod` replaces
        the deprecated `abc.abstractproperty`.)
        """

    def eval_time_derivatives(self, context: ContextBase) -> StateComponent:
        """Return the continuous-time derivatives of this system's state.

        The result is computed from the _root_ context and must share the PyTree
        structure of the continuous state. The base implementation returns None.

        Custom overrides are possible. Usually, however, a LeafSystem should
        declare its continuous state with `declare_continuous_state` and pass a
        callback that computes the derivatives. Diagrams assemble their
        derivatives automatically from the callbacks of every child system that
        has continuous state.

        Args:
            context (ContextBase): root context of this system

        Returns:
            StateComponent:
                The continuous time derivatives, or None if the system has no
                continuous state.
        """
        derivatives = None
        return derivatives

    @abc.abstractmethod
    def eval_zero_crossing_updates(
        self,
        context: ContextBase,
        events: EventCollection,
    ) -> State:
        """Evaluate reset maps associated with zero-crossing events.

        Abstract; subclasses must implement it.

        Args:
            context (ContextBase):
                The context for the system, containing the current state and parameters.
            events (EventCollection):
                The collection of events to be evaluated (for example zero-crossing or
                periodic events for this system).

        Returns:
            State: The complete state with all updates applied.

        Notes:
            (1) Following the Drake definition, "unrestricted" updates are allowed to
            modify any component of the state: continuous, discrete, or mode.  These
            updates are evaluated in the order in which they were declared, so it is
            _possible_ (but should be strictly avoided) for multiple events to modify the
            same state component at the same time.

            Each update computes its results given the _current_ state of the system
            (the "minus" values) and returns the _updated_ state (the "plus" values).
            The update functions cannot access any information about the "plus" values of
            their own state or the state of any other block.  This could change in the
            future, but for now it ensures consistency with Drake's discrete semantics.

            More specifically, since all unrestricted updates can modify the entire state,
            any time there are multiple unrestricted updates, the resulting states are
            ALWAYS in conflict.  For example, suppose a system has two unrestricted
            updates, `event1` and `event2`.  At time t_n, `event1` is active and `event2`
            is inactive.  First, `event1` is evaluated, and the state is updated.  Then
            `event2` is evaluated, but the state is not updated.  Which one is valid?
            Obviously, the `event1` return is valid, but how do we communicate this to JAX?
            The situation is more complicated if both `event1` and `event2` happen to be
            active.  In this case the states have to be "merged" somehow.  In the worst
            case, these two will modify the same components of the state in different ways.

            The implementation updates the state in a local copy of the context (since both
            are immutable).  This allows multiple unrestricted updates, but leaves open the
            possibility of multiple active updates modifying the state in conflicting ways.
            This should strictly be avoided by the implementer of the LeafSystem.  If it is
            at all unclear how to do this, it may be better to split the system into
            multiple blocks to be safe.

            (2) The events are evaluated conditionally on being marked "active"
            (indicating that their guard function triggered), so the entire event
            collection can be passed without filtering to active events. This is necessary
            to make the function calls work with JAX tracing, which does not allow for
            variable-sized arguments or returns.
        """

    def handle_discrete_update(
        self, events: EventCollection, context: ContextBase
    ) -> ContextBase:
        """Compute and apply active discrete updates.

        Given the _root_ context, evaluate the discrete updates, which must have the
        same PyTree structure as the discrete states of this system. This should be
        a pure function, so that it does not modify any aspect of the context in-place
        (even though it is difficult to strictly prevent this in Python).

        This will evaluate the set of events that result from declaring state or output
        update events on systems using `LeafSystem.declare_periodic_update` and
        `LeafSystem.declare_output_port` with an associated periodic update rate.

        This is intended for internal use by the simulator and should not normally need
        to be invoked directly by users. Events are evaluated conditionally on being
        marked "active", so the entire event collection can be passed without filtering
        to active events. This is necessary to make the function calls work with JAX
        tracing, which does not allow for variable-sized arguments or returns.

        For a discrete system updating at a particular rate, the update rule for a
        particular block is:

        ```
        x[n+1] = f(t[n], x[n], u[n])
        y[n]   = g(t[n], x[n], u[n])
        ```

        Additionally, the value y[n] is held constant until the next update from the
        point of view of other continuous-time or asynchronous discrete-time blocks.

        Because each output `y[n]` may in general depend on the input `u[n]` evaluated
        _at the same time_, the composite discrete update function represents a
        system of equations.  However, since algebraic loops are prohibited, the events
        can be ordered and executed sequentially to ensure that the updates are applied
        in the correct order.  This is implemented in
        `SystemBase.sorted_event_callbacks`.

        Multirate systems work in the same way, except that the events are evaluated
        conditionally on whether the current time corresponds to an update time for each
        event.

        Args:
            events (EventCollection): collection of discrete update events
            context (ContextBase): root context for this system

        Returns:
            ContextBase:
                updated context with all active updates applied to the discrete state
        """
        # Lazy %-style arguments (as in `handle_zero_crossings`) so the message is
        # only formatted when DEBUG logging is actually enabled.
        logger.debug(
            "Handling %d discrete update events at t=%s",
            events.num_events,
            context.time,
        )
        if events.has_events:
            # Sequentially apply all updates; each event sees the context produced
            # by the events handled before it.
            for event in events:
                system_id = event.system_id
                state = event.handle(context)
                local_context = context[system_id].with_state(state)
                context = context.with_subcontext(system_id, local_context)

        return context

    def handle_zero_crossings(
        self, events: EventCollection, context: ContextBase
    ) -> ContextBase:
        """Compute and apply the active zero-crossing events.

        Meant for the simulator's internal use; users should rarely need to call
        it. Only events marked "active" take effect, so the whole collection may
        be passed unfiltered. This keeps argument and return sizes fixed, as JAX
        tracing requires.

        Args:
            events (EventCollection): collection of zero-crossing events
            context (ContextBase): root context for this system

        Returns:
            ContextBase: the context after all active zero-crossing events are applied
        """
        logger.debug(
            "Handling %d state transition events at t=%s",
            events.num_events,
            context.time,
        )
        if not events.num_events > 0:
            return context

        # All active events are folded into one new state, then swapped in.
        updated_state = self.eval_zero_crossing_updates(context, events)
        return context.with_state(updated_state)

    # Events that happen at some regular interval
    @property
    def periodic_events(self) -> FlatEventCollection:
        """Cache (output) update events followed by discrete state update events."""
        cache_events = self.cache_update_events
        state_events = self.state_update_events
        return cache_events + state_events

    # Events that update the discrete state of the system
    @property
    @abc.abstractmethod
    def state_update_events(self) -> FlatEventCollection:
        """Abstract: events that update the discrete state of the system.

        Implemented by subclasses. (`property` + `abc.abstractmethod` replaces
        the deprecated `abc.abstractproperty`.)
        """

    # Events that refresh sample-and-hold outputs of the system
    @property
    def cache_update_events(self) -> FlatEventCollection:
        """Events that refresh sample-and-hold outputs, in execution order.

        The ordered list is built on first access from `sorted_event_callbacks`
        (keeping only callbacks that carry an event) and memoized in
        `self._cache_update_events`. Each access returns a new
        `FlatEventCollection` over a tuple of the memoized events.
        """
        if self._cache_update_events is None:
            collected = []
            for callback in self.sorted_event_callbacks:
                if callback.event is not None:
                    collected.append(callback.event)
            self._cache_update_events = collected

        return FlatEventCollection(tuple(self._cache_update_events))

    @property
    @abc.abstractmethod
    def _flat_callbacks(self) -> List[SystemCallback]:
        """Get a flat list of callbacks for this and all child systems.

        Abstract; subclasses must provide it. (`property` + `abc.abstractmethod`
        replaces the deprecated `abc.abstractproperty`.)
        """

    @property
    def sorted_event_callbacks(self) -> List[SystemCallback]:
        """Return every callback in `_flat_callbacks`, in execution order.

        Each SystemCallback has a tracker in the dependency graph. The trackers
        are ordered with `sort_trackers` and mapped back to their callbacks via
        `tracker.cache_source`. Callers such as `cache_update_events` filter the
        result down to event-carrying callbacks.
        """
        unsorted = [callback.tracker for callback in self._flat_callbacks]
        ordered = sort_trackers(unsorted)
        return [t.cache_source for t in ordered]

    # Events that are triggered by a "guard" function and may induce a "reset" map
    @property
    @abc.abstractmethod
    def zero_crossing_events(self) -> EventCollection:
        """Abstract: events triggered by a "guard" function that may apply a "reset" map.

        Implemented by subclasses. (`property` + `abc.abstractmethod` replaces
        the deprecated `abc.abstractproperty`.)
        """

    @abc.abstractmethod
    def determine_active_guards(self, context: ContextBase) -> EventCollection:
        """Determine active guards for zero-crossing events.

        Abstract; subclasses must implement it.

        This method is responsible for evaluating and determining which
        zero-crossing events are active based on the current system mode
        and other conditions.  This can be overridden to flag active/inactive
        guards on a block-specific basis, for instance in a StateMachine-type
        block. By default all guards are marked active at this point unless
        the zero-crossing event was declared with a non-default `start_mode`, in
        which case the guard is activated conditionally on the current mode.

        For example, in a system with finite state transitions, where a transition from
        mode A to mode B is triggered by a guard function g_AB and the inverse
        transition is triggered by a guard function g_BA, this function would activate
        g_AB if the system is in mode A and g_BA if the system is in mode B. The other
        guard function would be inactive.  If the zero-crossing event is not associated
        with a starting mode, it is considered to be always active.

        Args:
            context (ContextBase):
                The root context containing the overall state and parameters.

        Returns:
            EventCollection:
                A collection of zero-crossing events with active/inactive status
                updated based on the current system mode and other conditions.
        """

    #
    # I/O ports
    #
    @property
    def input_ports(self) -> List[InputPort]:
        """Input ports of this system, in declaration order."""
        callbacks = self.callbacks
        return [callbacks[idx] for idx in self.input_port_indices]

    def get_input_port(self, name: str) -> tuple[InputPort, int]:
        """Look up an input port by name.

        Returns:
            The first matching port together with its index in `input_ports`.

        Raises:
            ValueError: if no input port has the requested name.
        """
        matches = (
            (port, index)
            for index, port in enumerate(self.input_ports)
            if port.name == name
        )
        found = next(matches, None)
        if found is None:
            raise ValueError(f"System {self.name} has no input port named {name}")
        return found

    @property
    def num_input_ports(self) -> int:
        """Number of input ports declared on this system."""
        indices = self.input_port_indices
        return len(indices)

    @property
    def output_ports(self) -> List[OutputPort]:
        """Output ports of this system, ordered as in `output_port_indices`."""
        callbacks = self.callbacks
        return [callbacks[idx] for idx in self.output_port_indices]

    def get_output_port(self, name: str) -> OutputPort:
        """Look up an output port by name (the first match wins).

        Raises:
            ValueError: if no output port has the requested name.
        """
        found = next(
            (candidate for candidate in self.output_ports if candidate.name == name),
            None,
        )
        if found is None:
            raise ValueError(f"System {self.name} has no output port named {name}")
        return found

    @property
    def num_output_ports(self) -> int:
        """Number of output ports declared on this system."""
        indices = self.output_port_indices
        return len(indices)

    def eval_input(self, context: ContextBase, port_index: int = 0) -> Array:
        """Return the current value arriving at one input port.

        The port's callback is evaluated, which "pulls" the value from the
        connected upstream output port.

        Args:
            context (ContextBase): root context for this system
            port_index (int, optional): position in `self.input_ports`, e.g. the
                value returned by `declare_input_port`. Defaults to 0.

        Returns:
            Array: current input values
        """
        port = self.input_ports[port_index]
        return port.eval(context)

    def collect_inputs(self, context: ContextBase) -> List[Array]:
        """Evaluate every input port of this system, in port order.

        Args:
            context (ContextBase): root context for this system

        Returns:
            List[Array]: one current value per input port
        """
        values = []
        for index in range(self.num_input_ports):
            values.append(self.eval_input(context, index))
        return values

    def _eval_input_port(self, context: ContextBase, port_index: int) -> Array:
        """Evaluate an upstream input port given the _root_ context.

        Intended for internal use as a callback function. Users and developers
        should typically call `eval_input` in order to get this information. That
        method will call the callback function associated with the input port,
        which will have a reference to this method.

        Args:
            context (ContextBase): root context for this system
            port_index (int): index of the input port to evaluate on the target system

        Returns:
            Array: current input values

        Raises:
            UpstreamEvalError: if this system is the root (has no parent), so the
                evaluation is deferred to graph analysis.
        """
        port_locator = self.input_ports[port_index].locator

        if self.parent is None:
            # This is currently the root system.  Typically root input ports should not be evaluated,
            #  but we can get here during subsystem construction (e.g. type inference).  In that case,
            #  we should just defer evaluation and rely on the graph analysis to determine that
            #  everything is connected correctly.
            # See https://collimator.atlassian.net/browse/WC-51.
            # This should not happen during simulation or root context construction.
            # Lazy %-style arguments: the message is only built if DEBUG is enabled.
            logger.debug(
                "    ---> %s is the root system, deferring evaluation of %s[%s]",
                self.name,
                port_locator[0].name,
                port_locator[1],
            )
            raise UpstreamEvalError(port_locator=(self, "in", port_index))

        # The `eval_subsystem_input_port` method is only defined for Diagrams, but
        # the parent system is guaranteed to be a Diagram if this is not the root.
        # If it is the root, it should not have any (un-fixed) input ports.
        return self.parent.eval_subsystem_input_port(context, port_locator)

    #
    # Declaration utilities
    #
    def _next_input_port_name(self, name: str = None) -> str:
        """Return `name` if given (it must be non-empty), else the default `u_<n>`."""
        if name is None:
            return f"u_{self.num_input_ports}"
        assert name != ""
        return name

    def _next_output_port_name(self, name: str = None) -> str:
        """Return `name` if given (it must be non-empty), else the default `y_<n>`."""
        if name is None:
            return f"y_{self.num_output_ports}"
        assert name != ""
        return name

    def declare_input_port(
        self,
        name: str = None,
        prerequisites_of_calc: List[DependencyTicket] = None,
    ) -> int:
        """Add an input port to the system.

        Returns the corresponding index into the system input_port_indices list.
        Note that this is different from the callbacks index - typically it
        will make more sense to retrieve the port via
        `system.input_ports[port_index]` than through `system.callbacks`.

        Args:
            name (str, optional): name of the new port. Defaults to None, which will
                use the default naming scheme for the system (e.g. "u_0")
            prerequisites_of_calc (List[DependencyTicket], optional): list of
                dependencies for the callback function. Defaults to None.

        Returns:
            int: port index of the newly created port in `input_ports`

        Raises:
            AssertionError: if `name` is the empty string or the system already
                has an input port with that name.
        """
        port_index = self.num_input_ports
        port_name = self._next_input_port_name(name)

        # Check that the name is unique (done once, before building the port).
        for port in self.input_ports:
            assert (
                port.name != port_name
            ), f"System {self.name} already has an input port named {port.name}"

        def _callback(context: ContextBase) -> Array:
            # Given the root context, evaluate the input port using the helper function.
            # `port_index` is local to this call of `declare_input_port`, so each
            # port's closure keeps its own index.
            return self._eval_input_port(context, port_index)

        callback_index = len(self.callbacks)
        port = InputPort(
            _callback,
            system=self,
            callback_index=callback_index,
            name=port_name,
            index=port_index,
            prerequisites_of_calc=prerequisites_of_calc,
        )

        assert isinstance(port, InputPort)
        assert port.system is self
        assert port.name != ""

        self.input_port_indices.append(callback_index)
        self.callbacks.append(port)

        return port_index

    def declare_output_port(
        self,
        callback: Callable,
        name: str = None,
        prerequisites_of_calc: List[DependencyTicket] = None,
        default_value: Array = None,
        event: DiscreteUpdateEvent = None,
        cache_index: int = None,
    ) -> int:
        """Add an output port to the system.

        This output port could represent any function of the context available to
        the system, so a callback function is required.  This function should have
        the form
            `callback(context: ContextBase) -> Array`
        SystemBase implementations have some specific convenience wrappers, e.g.:
            `LeafSystem.declare_continuous_state_output`
            `Diagram.export_output`

        Common cases are:
        - Feedthrough blocks: gather inputs and return some function of the
            inputs (e.g. a gain)
        - Stateful blocks: use LeafSystem.declare_(...)_state_output_port to
            return the value of a particular state
        - Diagrams: create and export a diagram-level port to the parent system using
            the callback function associated with the system-level port

        Returns the corresponding index into the system output_port_indices list
        Note that this is different from the callbacks index - typically it
        will make more sense to retrieve via system.output_ports[port_index].

        Args:
            callback (Callable): computes the value of the output port given
                the root context.
            name (str, optional): name of the new port. Defaults to None, which will
                use the default naming scheme for the system (e.g. "y_0")
            prerequisites_of_calc (List[DependencyTicket], optional): list of
                dependencies for the callback function. Defaults to None, which will
                use the default dependencies for the system (all sources).  This may
                conservatively flag the system as having algebraic loops, so it is
                better to be specific here when possible.  This is done automatically
                in the wrapper functions like `LeafSystem.declare_(...)_output_port`
            default_value (Array, optional): A default array-like value used to seed
                the context and perform type inference, when this is known up front.
                Defaults to None, which will use information propagation through the
                graph along with type promotion to determine an appropriate value.
            event (DiscreteUpdateEvent, optional): A discrete update event associated
                with this output port that will periodically refresh the value that
                will be returned by the callback function. This makes the port act as
                a sample-and-hold rather than a direct function evaluation.
            cache_index (int, optional): Index into the cache state component
                corresponding to the output port result, if the output port is of
                periodically-updated sample-and-hold type.

        Returns:
            int: port index of the newly created port
        """
        port_index = self.num_output_ports
        port_name = self._next_output_port_name(name)

        for port in self.output_ports:
            assert (
                port.name != port_name
            ), f"System {self.name} already has an output port named {port.name}"

        if prerequisites_of_calc is None:
            prerequisites_of_calc = [DependencyTicket.all_sources]

        callback_index = len(self.callbacks)
        port = OutputPort(
            callback,
            system=self,
            callback_index=callback_index,
            name=port_name,
            index=port_index,
            prerequisites_of_calc=prerequisites_of_calc,
            default_value=default_value,
            event=event,
            cache_index=cache_index,
        )

        assert isinstance(port, OutputPort)
        assert port.system is self
        assert port.name != ""

        # Check that name is unique
        for p in self.output_ports:
            assert (
                p.name != port.name
            ), f"System {self.name} already has an output port named {port.name}"

        logger.debug(f"Adding output port {port} to {self.name}")
        self.output_port_indices.append(callback_index)
        self.callbacks.append(port)
        logger.debug(
            f"    ---> {self.name} now has {len(self.output_ports)} output ports: {self.output_ports}"
        )
        logger.debug(
            f"    ---> {self.name} now has {len(self.callbacks)} cache sources: {self.callbacks}"
        )

        return port_index

    @abc.abstractmethod
    def get_feedthrough(self) -> List[Tuple[int, int]]:
        """Report which input ports feed directly through to which output ports.

        This method is abstract, so every concrete system type must provide it.
        Implementations usually derive the answer from the dependency tracking
        system. A subclass may compute it explicitly instead, for example when
        automatic dependency tracking would flag feedthrough too conservatively.

        Returns:
            List[Tuple[int, int]]:
                Pairs (u, v) meaning that output port v depends directly on input
                port u, i.e. there is a feedthrough path through the system. u is
                an index into the system's input port list and v is an index into
                its output port list.
        """
        ...

    #
    # Initialization
    #
    def create_context(self, **kwargs) -> ContextBase:
        """Build a fresh context for this system.

        The context holds all of the variable data used for simulation, analysis
        and optimization, such as state and parameter values. Construction is
        delegated to `self.context_factory`, which receives every keyword argument
        unchanged.

        Returns:
            ContextBase: newly created context for this system
        """
        factory = self.context_factory
        return factory(**kwargs)

    def check_types(self, context: ContextBase, error_collector: ErrorCollector = None):
        """Hook for system-specific static analysis.

        The base implementation checks nothing and returns None. Subclasses can
        override it to inspect ``context``; ``error_collector`` is passed along
        for their use. Both arguments are ignored here.
        """
        return None

    # `abc.abstractproperty` has been deprecated since Python 3.3. Stacking
    # `property` on top of `abc.abstractmethod` is the documented replacement, and
    # it still keeps the class abstract until a subclass overrides this property.
    @property
    @abc.abstractmethod
    def context_factory(self) -> ContextFactory:
        """Factory object for creating contexts for this system.

        Subclasses must provide this property. `create_context` calls the returned
        factory with its keyword arguments. The factory should not be called
        directly - use `system.create_context` instead.
        """
        pass

    @property
    def dependency_graph(self) -> DependencyGraph:
        """The system's dependency graph, built lazily on first access.

        If no graph exists yet, construction is delegated to
        `create_dependency_graph` before the stored graph is returned.
        """
        needs_build = self._dependency_graph is None
        if needs_build:
            self.create_dependency_graph()
        graph = self._dependency_graph
        return graph

    # `abc.abstractproperty` has been deprecated since Python 3.3. Stacking
    # `property` on top of `abc.abstractmethod` is the documented replacement, and
    # it still keeps the class abstract until a subclass overrides this property.
    @property
    @abc.abstractmethod
    def dependency_graph_factory(self) -> DependencyGraphFactory:
        """Factory object for creating dependency graphs for this system.

        Subclasses must provide this property. The factory is invoked with no
        arguments by `create_dependency_graph` and `update_dependency_graph`.
        Do not call it directly - use `system.create_dependency_graph` instead.
        """
        pass

    def create_dependency_graph(self):
        """Build the dependency graph unless one already exists.

        If a graph is already stored, it is left untouched.
        """
        if self._dependency_graph is not None:
            return
        self._dependency_graph = self.dependency_graph_factory()

    def update_dependency_graph(self):
        """Rebuild the dependency graph, but only if one has already been built.

        When no graph exists yet, this is a no-op. The graph is then created
        lazily on first access to `dependency_graph`.
        """
        if self._dependency_graph is None:
            return
        self._dependency_graph = self.dependency_graph_factory()

    def initialize_static_data(self, context: ContextBase) -> ContextBase:
        """Post-creation hook for context data that should not be traced by JAX.

        Override this for custom auxiliary data or type inference; the
        `ZeroOrderHold` implementation is referenced as an example. The hook only
        runs during context initialization, so it may modify the context (or the
        system itself) directly. It is normally invoked by the ContextFactory and
        should not be called elsewhere.

        Args:
            context (ContextBase): partially initialized context for this system.

        Returns:
            ContextBase: the context to continue initialization with. The base
            implementation returns ``context`` unchanged.
        """
        # Base behaviour: nothing static to set up.
        return context

    @property
    def ports(self) -> dict[str, PortBase]:
        """All ports of this system, keyed by port name.

        Input ports are inserted first, then output ports. If an input and an
        output port share a name, the output port takes that key.
        """
        by_name = {}
        for port in self.input_ports + self.output_ports:
            by_name[port.name] = port
        return by_name

    # Convenience functions for errors and UI logs

    @property
    def name_path(self) -> list[str]:
        """Human-readable path of block names leading to this system.

        The root system and blocks directly under it yield just ``[self.name]``.
        Deeper systems extend their parent's path with their own name. Entries
        are taken verbatim from each ``name``, so an unset name appears as-is in
        the list; this property itself never returns None.
        """
        parent = self.parent
        if parent is None or parent.parent is None:
            # Root system, or a top-level block sitting directly under the root.
            return [self.name]
        return parent.name_path + [self.name]

    @property
    def name_path_str(self) -> str:
        """Dot-separated form of `name_path`, e.g. ``"subsystem.block"``."""
        path_parts = self.name_path
        return ".".join(path_parts)

    @property
    def ui_id_path(self) -> Union[list[str], None]:
        """UI node-ID path to this system, or None if it cannot be determined.

        The root system and blocks directly under it yield ``[self.ui_id]``.
        Deeper systems extend their parent's path. None is returned when this
        system has no ``ui_id``, or when the parent's path is itself None (i.e.
        some ancestor below the root is missing its ID).
        """
        own_id = self.ui_id
        if own_id is None:
            return None
        parent = self.parent
        if parent is None or parent.parent is None:
            # Root system or top-level block: the path starts here.
            return [own_id]
        prefix = parent.ui_id_path
        if prefix is None:
            # An ancestor lacks an ID, so the full path is unknown.
            return None
        return prefix + [own_id]

    #
    # Serialization
    #
    @property
    def default_parameters(self) -> dict[str, Parameter]:
        """Mapping from parameter name to the `Parameter` declared for it.

        The underlying dictionary itself is returned (not a copy), so changes
        made through it are visible to the system.
        """
        params = self._default_parameters
        return params

    def declare_parameter(
        self,
        name: str,
        default_value: Array | Parameter = None,
        shape: ShapeLike = None,
        dtype: DTypeLike = None,
        as_array: bool = True,
    ):
        """Register a numeric parameter on this system.

        Declared parameters are recorded in the system's ``default_parameters``
        and later used to seed the context. The declared default therefore
        doubles as the initial value unless it is explicitly overridden. Keeping
        parameters in the context, rather than as plain attributes of the system,
        lets JAX trace them. They can then be differentiated, vmapped, or changed
        without re-compiling the simulation.

        Args:
            name (str):
                Name of the parameter. It must not already be declared on this
                system.
            default_value (Union[Array, Parameter], optional):
                Default value of the parameter. A `Parameter` instance is stored
                as-is, and ``shape``, ``dtype`` and ``as_array`` are then ignored.
                This form is mainly used internally for serialization and is not
                normally needed when implementing LeafSystems. Defaults to None.
            shape (ShapeLike, optional):
                Shape of the parameter. Defaults to None.
            dtype (DTypeLike, optional):
                Data type of the parameter. Defaults to None.
            as_array (bool, optional):
                When True, ``default_value`` is converted with ``utils.make_array``
                (using ``dtype`` and ``shape``) before being wrapped in a
                `Parameter`. When False, it is wrapped without modification.
                Defaults to True.

        Raises:
            AssertionError:
                If a parameter with the given name is already declared.
            BlockParameterError:
                If building or registering the parameter fails. The original
                exception is chained as the cause, and its traceback is printed.

        Notes:
            Typically only one of ``shape`` and ``default_value`` is given. When
            ``default_value`` is given, it becomes the parameter's initial value.
            Per the original documentation, giving only ``shape`` yields a zero
            array of that shape and dtype. This depends on ``utils.make_array``
            (TODO confirm).
        """
        assert (
            name not in self._instance_parameters
        ), f"Parameter {name} already declared"

        try:
            if not isinstance(default_value, Parameter):
                # Plain values are optionally normalised to an array, then wrapped.
                if as_array:
                    default_value = utils.make_array(
                        default_value, dtype=dtype, shape=shape
                    )
                new_param = Parameter(value=default_value, dtype=dtype, shape=shape)
            else:
                # Already a Parameter: keep it verbatim.
                new_param = default_value
            self._instance_parameters[name] = new_param

            logger.debug(
                f"Adding parameter {name} to {self.name} with default: {default_value}"
            )

            # The default entry mirrors the per-instance entry just registered.
            self._default_parameters[name] = self._instance_parameters[name]

        except Exception as exc:
            traceback.print_exc()
            raise BlockParameterError(
                "Error declaring parameter",
                system=self,
                parameter_name=name,
            ) from exc

dependency_graph: DependencyGraph property

Retrieve (or create if necessary) the dependency graph for this system.

name_path: list[str] property

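Get the human-readable path to this system as a list of block names. The root system and top-level blocks yield a single-element list; nested systems extend their parent's path.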

name_path_str: str property

Get the human-readable path to this system as a string.

ports: dict[str, PortBase] property

Dictionary of all ports in this system, indexed by name

sorted_event_callbacks: List[SystemCallback] property

Sort and return the event-related callbacks for this system.

ui_id_path: Union[list[str], None] property

Get the uuid node path to this system. None if some IDs are not set.

check_types(context, error_collector=None)

Perform any system-specific static analysis.

Source code in collimator/framework/system_base.py (lines 785–787)
def check_types(self, context: ContextBase, error_collector: ErrorCollector = None):
    """Hook for system-specific static analysis.

    The base implementation checks nothing and returns None. Subclasses can
    override it to inspect ``context``; ``error_collector`` is passed along for
    their use. Both arguments are ignored here.
    """
    return None

collect_inputs(context)

Collect all current inputs for this system.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `context` | `ContextBase` | root context for this system | required |

Returns:

| Type | Description |
| --- | --- |
| `List[Array]` | list of all current input values |

Source code in collimator/framework/system_base.py (lines 528–537)
def collect_inputs(self, context: ContextBase) -> List[Array]:
    """Evaluate every input port of this system against ``context``.

    Args:
        context (ContextBase): root context for this system

    Returns:
        List[Array]: current input values, ordered by input port index
    """
    values = []
    for port_index in range(self.num_input_ports):
        values.append(self.eval_input(context, port_index))
    return values

context_factory()

Factory object for creating contexts for this system.

Should not be called directly - use system.create_context instead.

Source code in collimator/framework/system_base.py (lines 789–795)
# `abc.abstractproperty` has been deprecated since Python 3.3. Stacking
# `property` on top of `abc.abstractmethod` is the documented replacement, and
# it still keeps the class abstract until a subclass overrides this property.
@property
@abc.abstractmethod
def context_factory(self) -> ContextFactory:
    """Factory object for creating contexts for this system.

    Subclasses must provide this property. The factory should not be called
    directly - use `system.create_context` instead.
    """
    pass

create_context(**kwargs)

Create a new context for this system.

The context will contain all variable information used in simulation/analysis/optimization, such as state and parameters.

Returns:

| Type | Description |
| --- | --- |
| `ContextBase` | new context for this system |

Source code in collimator/framework/system_base.py (lines 774–783)
def create_context(self, **kwargs) -> ContextBase:
    """Build a fresh context for this system.

    The context holds all of the variable data used for simulation, analysis
    and optimization, such as state and parameter values. Construction is
    delegated to `self.context_factory`, which receives every keyword argument
    unchanged.

    Returns:
        ContextBase: newly created context for this system
    """
    factory = self.context_factory
    return factory(**kwargs)

create_dependency_graph()

Create a dependency graph for this system.

Source code in collimator/framework/system_base.py (lines 812–815)
def create_dependency_graph(self):
    """Build the dependency graph unless one already exists.

    If a graph is already stored, it is left untouched.
    """
    if self._dependency_graph is not None:
        return
    self._dependency_graph = self.dependency_graph_factory()

declare_input_port(name=None, prerequisites_of_calc=None)

Add an input port to the system.

Returns the corresponding index into the system input_port_indices list. Note that this is different from the callbacks index - typically it will make more sense to retrieve the port via system.input_ports[port_index].

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | `str` | name of the new port. Defaults to None, which will use the default naming scheme for the system (e.g. "u_0") | `None` |
| `prerequisites_of_calc` | `List[DependencyTicket]` | list of dependencies for the callback function. Defaults to None. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `int` | port index of the newly created port in input_ports |

Source code in collimator/framework/system_base.py