Skip to content

Block library

collimator.library

Abs

Bases: FeedthroughBlock

Output the absolute value of the input signal.

Input ports

(0) The input signal.

Output ports

(0) The absolute value of the input signal.

Events

An event is triggered when the input changes sign (crosses zero), which is where the absolute value has a kink. The output itself is never negative.

Source code in collimator/library/primitives.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
class Abs(FeedthroughBlock):
    """Output the absolute value of the input signal.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The absolute value of the input signal.

    Events:
        An event is triggered when the input changes sign (crosses zero), i.e.
        where the absolute value has a kink. The output itself never goes
        negative.
    """

    def __init__(self, *args, **kwargs):
        # The feedthrough operation is the element-wise absolute value.
        super().__init__(cnp.abs, *args, **kwargs)

    def _zero_crossing(self, _time, _state, u):
        # |u| is non-smooth exactly where u changes sign, so the guard is u.
        return u

    def initialize_static_data(self, context):
        # Register a sign-change event so ODE solvers do not step across the
        # kink. For efficiency this is only done once, and only when
        # is_discontinuity reports that the output is fed to an ODE.
        needs_event = (
            not self.has_zero_crossing_events
            and is_discontinuity(self.output_ports[0])
        )
        if needs_event:
            self.declare_zero_crossing(self._zero_crossing, direction="crosses_zero")

        return super().initialize_static_data(context)

Adder

Bases: ReduceBlock

Computes the sum/difference of the input.

The add/subtract operation can be switched by setting the operators parameter. For example, a 3-input block specified as Adder(3, operators="+-+") would add the first and third inputs and subtract the second input.

Input ports

(0..n_in-1) The input signals to add/subtract.

Output ports

(0) The sum/difference of the input signals.

Source code in collimator/library/primitives.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
class Adder(ReduceBlock):
    """Computes the sum/difference of the input.

    The add/subtract operation can be switched by setting the `operators` parameter.
    For example, a 3-input block specified as `Adder(3, operators="+-+")` would add
    the first and third inputs and subtract the second input. When `operators` is
    None, every input is added.

    Input ports:
        (0..n_in-1) The input signals to add/subtract.

    Output ports:
        (0) The sum/difference of the input signals.
    """

    @parameters(static=["operators"])
    def __init__(self, n_in, *args, operators=None, **kwargs):
        # The reduction op is passed as None here; `initialize` installs the
        # actual one via `replace_op`.
        super().__init__(n_in, None, *args, **kwargs)

    def initialize(self, operators):
        """Install the reduction op described by `operators`.

        Raises:
            BlockParameterError: if `operators` contains anything but '+'/'-'.
        """
        if operators is None:
            # Plain summation of all inputs.
            self.replace_op(sum)
            return

        allowed_symbols = {"+", "-"}
        if any(symbol not in allowed_symbols for symbol in operators):
            raise BlockParameterError(
                message=f"Adder block {self.name} has invalid operators {operators}. Can only contain '+' and '-'",
                system=self,
                parameter_name="operators",
            )

        sign_of = {"+": 1, "-": -1}
        coefficients = [sign_of[symbol] for symbol in operators]

        def _signed_sum(inputs):
            # Pair each input with its +1/-1 coefficient (zip stops at the
            # shorter of the two) and add everything up.
            return sum(c * u for c, u in zip(coefficients, inputs))

        self.replace_op(_signed_sum)

Arithmetic

Bases: ReduceBlock

Performs addition, subtraction, multiplication, and division on the input.

The arithmetic operation is determined by setting the operators parameter. For example, a 4-input block specified as Arithmetic(4, operators="+-*/") would add the first input, subtract the second input, multiply by the third input, and divide by the fourth input. Multiplication and division take precedence over addition and subtraction, so this example computes in0 - in1 * in2 / in3.

Input ports

(0..n_in-1) The input signals for the specified arithmetic operations.

Output ports

(0) The result of the specified arithmetic operations on the input signals.

Source code in collimator/library/primitives.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
class Arithmetic(ReduceBlock):
    """Performs addition, subtraction, multiplication, and division on the input.

    The arithmetic operation is determined by setting the `operators` parameter.
    For example, a 4-input block specified as `Arithmetic(4, operators="+-*/")` would:
        - Add the first input,
        - Subtract the second input,
        - Multiply the third input,
        - Divide by the fourth input.

    The operator in front of the first input is applied as a unary operation:
    "-" negates it, "/" takes its reciprocal, and "+" or "*" leave it unchanged.
    The remaining operators are evaluated with the usual precedence ("*" and "/"
    before "+" and "-"), left to right within each precedence level. The example
    above therefore computes `in0 - in1 * in2 / in3`. If `operators` is None,
    all inputs are added.

    Input ports:
        (0..n_in-1) The input signals for the specified arithmetic operations.

    Output ports:
        (0) The result of the specified arithmetic operations on the input signals.

    """

    @parameters(static=["operators"])
    def __init__(self, n_in, *args, operators=None, **kwargs):
        # The reduction op is passed as None here; `initialize` installs the
        # actual one via `replace_op`.
        super().__init__(n_in, None, *args, **kwargs)

    def initialize(self, operators):
        """Validate `operators` and install the corresponding reduction op.

        Raises:
            BlockParameterError: if `operators` contains a character other than
                '+', '-', '*', '/', or (at evaluation time) if the number of
                operators does not match the number of inputs.
        """
        if operators is not None and any(
            char not in {"+", "-", "*", "/"} for char in operators
        ):
            raise BlockParameterError(
                message=f"Arithmetic block {self.name} has invalid operators {operators}. Can only contain '+', '-', '*', '/'.",
                system=self,
                parameter_name="operators",
            )

        binary_ops = {
            "+": cnp.add,
            "-": cnp.subtract,
            "/": cnp.divide,
            "*": cnp.multiply,
        }

        def evaluate_expression(operands, binary_symbols):
            """Evaluate `operands[0] s0 operands[1] s1 ...` with precedence.

            Multiplicative operators bind tighter than additive ones, and each
            precedence level is folded left to right, so `a / b * c` is
            `(a / b) * c` and `a - b - c + d` is `((a - b) - c) + d`.
            """
            # Pass 1: fold "*" and "/" into the running term; every "+" or "-"
            # starts a new term.
            terms = [operands[0]]
            additive_symbols = []
            for symbol, operand in zip(binary_symbols, operands[1:]):
                if symbol in ("*", "/"):
                    terms[-1] = binary_ops[symbol](terms[-1], operand)
                else:
                    additive_symbols.append(symbol)
                    terms.append(operand)

            # Pass 2: combine the terms left to right with "+" and "-".
            result = terms[0]
            for symbol, term in zip(additive_symbols, terms[1:]):
                result = binary_ops[symbol](result, term)
            return result

        def _func(inputs):
            inputs = list(inputs)
            # Like Adder, a missing operator string means "sum every input".
            symbols = "+" * len(inputs) if operators is None else operators
            if len(symbols) != len(inputs):
                # Without this check, extra inputs would be silently ignored
                # and extra operators would fail with an opaque IndexError.
                raise BlockParameterError(
                    message=(
                        f"Arithmetic block {self.name} has {len(inputs)} inputs "
                        f"but {len(symbols)} operators ({symbols})."
                    ),
                    system=self,
                    parameter_name="operators",
                )
            # The first operator acts on the first input alone.
            if symbols[0] == "/":
                inputs[0] = 1.0 / inputs[0]
            if symbols[0] == "-":
                inputs[0] = -inputs[0]
            return evaluate_expression(inputs, symbols[1:])

        self.replace_op(_func)

BatteryCell

Bases: LeafSystem

Dynamic electrochemical Li-ion cell model.

Based on Tremblay and Dessaint (2009).

By using appropriate parameters, the cell model can be used to model a battery pack with the assumption that the cells of the pack behave as a single unit.

Parameters E0, K, A, below are abstract parameters used in the model presented in the reference paper. As described in the reference paper, these parameters can be extracted from typical cell manufacturer datasheets; see section 3. Section 3 also provides a table of example values for these parameters.

Input ports

(0) The current (A) flowing through the cell. Positive is discharge.

Output ports

(0) The voltage across the cell terminals (V) (1) The state of charge of the cell (normalized between 0 and 1)

Parameters:

Name Type Description Default
E0 float

described as "battery constant voltage (V)" by the reference paper.

3.366
K float

described as "polarization constant (V/Ah)" by the reference paper.

0.0076
Q float

battery capacity in Ah

2.3
R float

internal resistance (Ohms)

0.01
A float

described as "exponential zone amplitude (V)" by the reference paper.

0.26422
B float

described as "exponential zone time constant inverse (1/Ah)" by the reference paper.

26.5487
initial_SOC float

initial state of charge, normalized between 0 and 1.

1.0
Source code in collimator/library/battery_cell.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
class BatteryCell(LeafSystem):
    """Dynamic electrochemical Li-ion cell model.

    Based on [Tremblay and Dessaint (2009)](https://doi.org/10.3390/wevj3020289).

    By using appropriate parameters, the cell model can be used to model a battery pack
    with the assumption that the cells of the pack behave as a single unit.

    Parameters E0, K, A, below are abstract parameters used in the model presented in
    the reference paper. As described in the reference paper, these parameters can be
    extracted from typical cell manufacturer datasheets; see section 3. Section 3 also
    provides a table of example values for these parameters.

    Input ports:
        (0) The current (A) flowing through the cell. Positive is discharge.

    Output ports:
        (0) The voltage across the cell terminals (V)
        (1) The state of charge of the cell (normalized between 0 and 1)

    Parameters:
        E0: described as "battery constant voltage (V)" by the reference paper.
        K: described as "polarization constant (V/Ah)" by the reference paper.
        Q: battery capacity in Ah
        R: internal resistance (Ohms)
        tau:
            NOTE(review): accepted and registered as a dynamic parameter, but not
            currently used by this block; the input-current filter coefficients
            are hard-coded in `initialize`.
        A: described as "exponential zone amplitude (V)" by the reference paper.
        B:
            described as "exponential zone time constant inverse (1/Ah)" by the
            reference paper.
        initial_SOC: initial state of charge, normalized between 0 and 1.
    """

    # Continuous state: the state of charge plus the states of the two
    # first-order current filters declared in `initialize`.
    class BatteryStateType(NamedTuple):
        soc: float
        i_star: float
        i_lb: float

    # Coefficients of a scalar first-order filter
    #   dx/dt = A*x + B*u,   y = C*x
    # (see how they are used in `_ode` and `_voltage_output`).
    class FirstOrderFilter(NamedTuple):
        A: float
        B: float
        C: float

    @parameters(dynamic=["E0", "K", "Q", "R", "tau", "A", "B"], static=["initial_SOC"])
    def __init__(
        self,
        E0: float = 3.366,
        K: float = 0.0076,
        Q: float = 2.3,
        R: float = 0.01,
        tau: float = 30.0,
        A: float = 0.26422,
        B: float = 26.5487,
        initial_SOC: float = 1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.declare_input_port()  # Current flowing through the cell

        # Both outputs are computed from the continuous state only
        # (requires_inputs=False). The voltage uses filtered copies of the
        # current instead of the raw input, which avoids a direct feedthrough
        # (algebraic loop) from input to output.
        self.declare_output_port(
            self._voltage_output,
            prerequisites_of_calc=[DependencyTicket.xc],
            requires_inputs=False,
            name="voltage",
        )

        self.declare_output_port(
            self._soc_output,
            prerequisites_of_calc=[DependencyTicket.xc],
            requires_inputs=False,
            name="soc",
        )

    def initialize(self, E0, K, Q, R, tau, A, B, initial_SOC):
        """Build the current filters and declare the continuous state."""
        # Filter for input current.
        # Pole at -0.05 (time constant 1/0.05 = 20 time units); DC gain is
        # C*B/(-A) = 1, so i_star tracks a constant input current.
        self.current_filter = self.FirstOrderFilter(-0.05, 1.0, 0.05)

        # Filter for loop-breaker: much faster pole (time constant 0.1), also
        # unity DC gain.
        self.lb_filter = self.FirstOrderFilter(-10.0, 1.0, 10.0)

        initial_state = self.BatteryStateType(
            soc=initial_SOC,
            i_star=0.0,  # Filtered input current
            i_lb=0.0,  # Filtered current for loop-breaking
        )

        self.declare_continuous_state(
            default_value=initial_state,
            as_array=False,
            ode=self._ode,
        )

    def _ode(self, _time, state, *inputs, **parameters) -> BatteryStateType:
        """Time derivatives of (soc, i_star, i_lb) given the input current."""
        xc = state.continuous_state
        Q = parameters["Q"]

        (u,) = inputs

        # Discharge (u > 0) lowers the SoC. Q is in Ah; Ah_to_As presumably
        # converts it to ampere-seconds so the rate is per second.
        soc_der_unsat = -u / (Q * Ah_to_As)

        # SoC must be between 0 and 1
        llim_violation = (xc.soc <= 0.0) & (soc_der_unsat < 0.0)
        ulim_violation = (xc.soc >= 1.0) & (soc_der_unsat > 0.0)

        # Saturated time derivative
        soc_der = cnp.where(llim_violation | ulim_violation, 0.0, soc_der_unsat)

        # Derivative of istar, the filtered current signal
        i_star_der = self.current_filter.A * xc.i_star + self.current_filter.B * u

        # Derivative of ilb, the filtered current signal for loop-breaking
        i_lb_der = self.lb_filter.A * xc.i_lb + self.lb_filter.B * u

        return self.BatteryStateType(
            soc=soc_der,
            i_star=i_star_der,
            i_lb=i_lb_der,
        )

    def _voltage_output(self, _time, state, *_inputs, **parameters) -> Array:
        """Terminal voltage, clipped to [0, 2*E0].

        NOTE(review): when the SoC is clipped to 0, `Q - i_int` is 0. With
        `i_star == 0` the `vdyn` term becomes 0/0 (NaN), which the final clip
        does not remove. Confirm whether this edge case needs guarding.
        """
        E0 = parameters["E0"]
        Q = parameters["Q"]
        K = parameters["K"]
        A = parameters["A"]
        B = parameters["B"]
        R = parameters["R"]
        xc = state.continuous_state

        # Filtered input current
        i_star = self.current_filter.C * xc.i_star

        # Loop-breaking current
        i_lb = self.lb_filter.C * xc.i_lb

        # Apply limits to state of charge
        soc = cnp.clip(xc.soc, 0.0, 1.0)

        # Undo normalization by Q - this is ∫i*dt, the integral of current
        i_int = Q * (1 - soc)

        # Dynamic polarization term driven by the filtered current. The
        # denominator differs between discharge (i_star >= 0) and charge;
        # the charge-mode 0.1*Q offset presumably follows the reference paper.
        chg_mode_Q_gain = 0.1
        vdyn_den = cnp.where(i_star >= 0, Q - i_int, i_int + chg_mode_Q_gain * Q)
        vdyn = i_star * K * Q / vdyn_den

        # Ohmic drop (using the loop-breaking current), polarization voltage,
        # exponential zone, and dynamic term.
        vbatt_ulim = 2 * E0  # Reasonable upper limit on battery voltage
        vbatt_presat = (
            E0 - R * i_lb - i_int * K * Q / (Q - i_int) + A * cnp.exp(-B * i_int) - vdyn
        )
        return cnp.clip(vbatt_presat, 0.0, vbatt_ulim)

    def _soc_output(self, _time, state, *_inputs, **_parameters) -> Array:
        """Raw (unclipped) state-of-charge state."""
        return state.continuous_state.soc

Chirp

Bases: SourceBlock

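Produces a linear chirp signal, similar to the linear method of

https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.chirp.html

The output is cos(f0*t + (f1 - f0)*t**2 / (2*stop_time) + phi). Unlike SciPy, the phase has no 2π factor, so f0 and f1 are angular frequencies (rad/s). The phase offset phi is given in radians rather than degrees.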

Parameters:

Name Type Description Default
f0 float

Angular frequency (rad/s) at time t=0.

required
f1 float

Angular frequency (rad/s) at time t=stop_time.

required
stop_time float

Time (seconds) at which the frequency reaches f1. The signal itself continues after this time.

required
phi float

Phase offset (radians).

0.0
Input ports

None

Output ports

(0) The chirp signal.

Source code in collimator/library/primitives.py
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
class Chirp(SourceBlock):
    """Produces a linear chirp signal, similar to the linear method of

    https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.chirp.html

    The output is

        cos(f0 * t + (f1 - f0) * t**2 / (2 * stop_time) + phi)

    Its instantaneous frequency (the time derivative of the phase) sweeps
    linearly from f0 at t=0 to f1 at t=stop_time and keeps changing after that.
    Unlike SciPy, the phase has no factor of 2*pi, so f0 and f1 are angular
    frequencies (rad/s). Also unlike SciPy, phi is given in radians rather than
    degrees.

    Parameters:
        f0 (float): Angular frequency (rad/s) at time t=0.
        f1 (float): Angular frequency (rad/s) at time t=stop_time.
        stop_time (float): Time (seconds) at which the frequency reaches f1.
        phi (float): Phase offset (radians).

    Input ports:
        None

    Output ports:
        (0) The chirp signal.
    """

    @parameters(dynamic=["f0", "f1", "stop_time", "phi"])
    def __init__(self, f0, f1, stop_time, phi=0.0, **kwargs):
        # The source function is passed as None here; `initialize` installs it.
        super().__init__(None, **kwargs)

    def initialize(self, f0, f1, stop_time, phi):
        # Phase = f0*t + (f1 - f0)*t**2 / (2*stop_time). The factor of 2 comes
        # from integrating the linear frequency sweep f0 + (f1 - f0)*t/stop_time,
        # and SciPy's linear chirp has the same factor. What differs from SciPy
        # is the absence of 2*pi, so frequencies here are angular (rad/s).
        def _func(time, stop_time, f0, f1, phi):
            f = f0 + (f1 - f0) * time / (2 * stop_time)
            return cnp.cos(f * time + phi)

        self.replace_op(_func)

Clock

Bases: SourceBlock

Source block returning simulation time.

Input ports

None

Output ports

(0) The simulation time.

Parameters:

Name Type Description Default
dtype

The data type of the output signal. The default is None, which uses the current default floating-point precision.

None
Source code in collimator/library/primitives.py
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
class Clock(SourceBlock):
    """Source block whose output is the current simulation time.

    Input ports:
        None

    Output ports:
        (0) The simulation time.

    Parameters:
        dtype:
            The data type of the output signal. The default is None, which uses
            the current default floating-point precision.
    """

    def __init__(self, dtype=None, **kwargs):
        def _current_time(t):
            # Wrap the time value as an array of the requested dtype.
            return cnp.array(t, dtype=dtype)

        super().__init__(_current_time, **kwargs)

Comparator

Bases: LeafSystem

Compare two signals using typical relational operators.

When using == and != operators, the block uses tolerances to determine if the expression is true or false.

Parameters:

Name Type Description Default
operator

one of ("==", "!=", ">=", ">", "<=", "<")

None
atol

the absolute tolerance value used with "==" or "!="

1e-05
rtol

the relative tolerance value used with "==" or "!="

1e-08
Input Ports

(0) The left side operand (1) The right side operand

Output Ports

(0) The result of the comparison (boolean signal)

Events

An event is triggered when the output changes from true to false or vice versa.

Source code in collimator/library/primitives.py
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
class Comparator(LeafSystem):
    """Compare two signals using typical relational operators.

    When using == and != operators, the block uses tolerances to determine if the
    expression is true or false (only for floating-point left operands; other
    dtypes are compared exactly).

    Parameters:
        operator: one of ("==", "!=", ">=", ">", "<=", "<")
        atol: the absolute tolerance value used with "==" or "!="
        rtol: the relative tolerance value used with "==" or "!="

    Input Ports:
        (0) The left side operand
        (1) The right side operand

    Output Ports:
        (0) The result of the comparison (boolean signal)

    Events:
        An event is triggered when the output changes from true to false or vice versa.
        NOTE(review): for the ordering operators only one crossing direction of
        `left - right` is registered. Confirm that is intended.
    """

    @parameters(static=["operator", "atol", "rtol"])
    def __init__(self, atol=1e-5, rtol=1e-8, operator=None, **kwargs):
        super().__init__(**kwargs)
        # Two operands: (0) left side, (1) right side.
        for _ in range(2):
            self.declare_input_port()
        self._output_port_idx = self.declare_output_port()

    def initialize(self, atol, rtol, operator):
        """Select the comparison function and wire up the output port.

        Raises:
            BlockParameterError: if `operator` is not a supported operator.
        """
        comparisons = {
            ">": cnp.greater,
            ">=": cnp.greater_equal,
            "<": cnp.less,
            "<=": cnp.less_equal,
            "==": self._equal,
            "!=": self._ne,
        }

        if operator not in comparisons:
            valid_options = ",".join(comparisons)
            message = (
                f"Comparator block '{self.name}' has invalid selection "
                f"'{operator}' for parameter 'operator'. Valid options: "
                f"{valid_options}"
            )
            raise BlockParameterError(
                message=message, system=self, parameter_name="operator"
            )

        # Tolerances are read by _equal / _ne when the output is evaluated.
        self.rtol = rtol
        self.atol = atol

        compare = comparisons[operator]

        def _compute_output(_time, _state, *inputs, **_params):
            return compare(*inputs)

        self.configure_output_port(
            self._output_port_idx,
            _compute_output,
            prerequisites_of_calc=[port.ticket for port in self.input_ports],
        )
        self.evt_direction = self._process_operator(operator)

    def _equal(self, x, y):
        # Tolerance-based equality for floats, exact equality otherwise.
        if not cnp.issubdtype(x.dtype, cnp.floating):
            return x == y
        return cnp.isclose(x, y, self.rtol, self.atol)

    def _ne(self, x, y):
        # Negation of the tolerance-based equality for floats.
        if not cnp.issubdtype(x.dtype, cnp.floating):
            return x != y
        return cnp.logical_not(cnp.isclose(x, y, self.rtol, self.atol))

    def _zero_crossing(self, _time, _state, *inputs, **_params):
        # The comparison flips where left - right crosses zero.
        left, right = inputs[0], inputs[1]
        return left - right

    def _process_operator(self, operator):
        # Map the relational operator to the zero-crossing direction to watch.
        if operator in ("<", "<="):
            return "positive_then_non_positive"
        if operator in (">", ">="):
            return "negative_then_non_negative"
        return "crosses_zero"

    def initialize_static_data(self, context):
        # Register a zero-crossing event so ODE solvers cannot integrate across
        # the jump in the boolean output. For efficiency this is only done once,
        # and only when is_discontinuity reports that the output feeds an ODE.
        already_registered = self.has_zero_crossing_events
        if not already_registered and is_discontinuity(self.output_ports[0]):
            self.declare_zero_crossing(
                self._zero_crossing, direction=self.evt_direction
            )

        return super().initialize_static_data(context)

Constant

Bases: LeafSystem

A source block that emits a constant value.

Parameters:

Name Type Description Default
value

The constant value of the block.

required
Input ports

None

Output ports

(0) The constant value.

Source code in collimator/library/primitives.py
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
class Constant(LeafSystem):
    """Source block whose single output is a fixed value.

    Parameters:
        value: The constant value of the block.

    Input ports:
        None

    Output ports:
        (0) The constant value.
    """

    @parameters(dynamic=["value"])
    def __init__(self, value, *args, **kwargs):
        super().__init__(**kwargs)
        self._output_port_idx = self.declare_output_port(name="out_0")

    def initialize(self, value):
        def _emit_value(time, state, *inputs, **parameters):
            # Read the value from the parameters at evaluation time instead of
            # capturing `value` here. Presumably this lets updates to the
            # dynamic parameter show up in the output.
            return parameters["value"]

        # The output depends on nothing (no inputs, no state).
        self.configure_output_port(
            self._output_port_idx,
            _emit_value,
            requires_inputs=False,
            prerequisites_of_calc=[DependencyTicket.nothing],
        )

ContinuousTimeInfiniteHorizonKalmanFilter

Bases: LeafSystem

Continuous-time Infinite Horizon Kalman Filter for the following system:

dot_x =  A x + B u + G w
y   = C x + D u + v

E(w) = E(v) = 0
E(ww') = Q
E(vv') = R
E(wv') = N = 0
Input ports

(0) u : continuous-time control vector (1) y : continuous-time measurement vector

Output ports

(0) x_hat : continuous-time state vector estimate

Parameters:

Name Type Description Default
A

ndarray State transition matrix

required
B

ndarray Input matrix

required
C

ndarray Output matrix

required
D

ndarray Feedthrough matrix

required
G

ndarray Process noise matrix

required
Q

ndarray Process noise covariance matrix

required
R

ndarray Measurement noise covariance matrix

required
x_hat_0

ndarray Initial state estimate

required
Source code in collimator/library/state_estimators/continuous_time_infinite_horizon_kalman_filter.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
class ContinuousTimeInfiniteHorizonKalmanFilter(LeafSystem):
    """
    Continuous-time Infinite Horizon Kalman Filter for the following system:

    ```
    dot_x =  A x + B u + G w
    y   = C x + D u + v

    E(w) = E(v) = 0
    E(ww') = Q
    E(vv') = R
    E(wv') = N = 0
    ```

    The steady-state gain `L` is computed once at construction with
    `control.lqe`, and the estimate evolves as

    ```
    dot_x_hat = (A - L C) x_hat + (B - L D) u + L y
    ```

    Input ports:
        (0) u : continuous-time control vector
        (1) y : continuous-time measurement vector

    Output ports:
        (0) x_hat : continuous-time state vector estimate

    Parameters:
        A: ndarray
            State transition matrix
        B: ndarray
            Input matrix
        C: ndarray
            Output matrix
        D: ndarray
            Feedthrough matrix
        G: ndarray
            Process noise matrix
        Q: ndarray
            Process noise covariance matrix
        R: ndarray
            Measurement noise covariance matrix
        x_hat_0: ndarray
            Initial state estimate (its shape also sets the shape of the state)
    """

    def __init__(self, A, B, C, D, G, Q, R, x_hat_0, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.A = A
        self.B = B
        self.C = C
        self.D = D
        self.G = G
        self.Q = Q
        self.R = R

        self.nx, self.nu = B.shape
        self.ny = C.shape[0]

        # Steady-state Kalman gain. The Riccati solution and the closed-loop
        # eigenvalues that lqe also returns are not needed here.
        L, _, _ = control.lqe(A, G, C, Q, R)

        # Pre-compute the matrices of the estimator ODE
        #   dx_hat/dt = A x_hat + B u + L (y - C x_hat - D u)
        self.A_minus_LC = A - cnp.matmul(L, C)
        self.B_minus_LD = B - cnp.matmul(L, D)
        self.L = L

        self.declare_input_port()  # u
        self.declare_input_port()  # y

        self.declare_continuous_state(
            ode=self._ode, shape=x_hat_0.shape, default_value=x_hat_0, as_array=True
        )  # continuous state for x_hat

        self.declare_continuous_state_output()

    def _ode(self, time, state, *inputs, **params):
        """Time derivative of the state estimate x_hat."""
        x_hat = state.continuous_state

        u, y = inputs

        # Allow scalar control/measurement signals.
        u = cnp.atleast_1d(u)
        y = cnp.atleast_1d(y)

        dot_x_hat = (
            cnp.dot(self.A_minus_LC, x_hat)
            + cnp.dot(self.B_minus_LD, u)
            + cnp.dot(self.L, y)
        )

        return dot_x_hat

    #######################################
    # Make filter for a continuous plant  #
    #######################################
    @staticmethod
    @with_resolved_parameters
    def for_continuous_plant(
        plant,
        x_eq,
        u_eq,
        Q,
        R,
        G=None,
        x_hat_bar_0=None,
        name=None,
    ):
        """
        Obtain a continuous-time Infinite Horizon Kalman Filter system for a
        continuous-time plant after linearization at equilibrium point (x_eq, u_eq)

        The input plant contains the deterministic forms of the forward and observation
        operators:

        ```
            dx/dt = f(x,u)
            y = g(x,u)
        ```

        Note: Only plants with one vector-valued input and one vector-valued output
        are currently supported. Furthermore, the plant LeafSystem/Diagram should have
        only one vector-valued integrator.

        A plant with disturbances of the following form is then considered:

        ```
            dx/dt = f(x,u) + G w
            y = g(x,u) +  v
        ```

        where:

            `w` represents the process noise,
            `v` represents the measurement noise,

        and

        ```
            E(w) = E(v) = 0
            E(ww') = Q
            E(vv') = R
            E(wv') = N = 0
        ```

        This plant with disturbances is linearized (only `f` and `g`) around the
        equilibrium point to obtain:

        ```
            d/dt (x_bar) = A x_bar + B u_bar + G w    --- (C1)
            y_bar = C x_bar + D u_bar + v             --- (C2)
        ```

        where,

        ```
            x_bar = x - x_eq
            u_bar = u - u_eq
            y_bar = y - y_eq
            y_eq = g(x_eq, u_eq)
        ```

        A continuous-time Kalman Filter estimator for the system of equations (C1) and
        (C2) is returned. This filter operates on the `x_bar`, `u_bar`, and `y_bar`
        variables.

        The returned system will have

        Input ports:
            (0) u_bar : continuous-time control vector relative to equilibrium point
            (1) y_bar : continuous-time measurement vector relative to equilibrium point

        Output ports:
            (0) x_hat_bar : continuous-time state vector estimate relative to
                            equilibrium point

        Parameters:
            plant : a `Plant` object which can be a LeafSystem or a Diagram.
            x_eq: ndarray
                Equilibrium state vector for linearization
            u_eq: ndarray
                Equilibrium control vector for linearization
            Q: ndarray
                Process noise covariance matrix.
            R: ndarray
                Measurement noise covariance matrix.
            G: ndarray
                Process noise matrix. If `None`, `G=B` is assumed making disturbances
                additive to control vector `u`, i.e. `u_disturbed = u_orig + w`.
            x_hat_bar_0: ndarray
                Initial state estimate relative to equilibrium point.
                If None, a zero vector is assumed.

        Returns:
            A tuple `(y_eq, kf)`: the plant output at the equilibrium point and
            the Kalman filter system for the linearized plant.
        """

        y_eq, linear_plant = linearize_plant(plant, x_eq, u_eq)

        A, B, C, D = linear_plant.A, linear_plant.B, linear_plant.C, linear_plant.D

        nx = B.shape[0]

        if G is None:
            G = B

        if x_hat_bar_0 is None:
            x_hat_bar_0 = cnp.zeros(nx)

        # Instantiate a Kalman Filter instance for the linearized plant
        kf = ContinuousTimeInfiniteHorizonKalmanFilter(
            A,
            B,
            C,
            D,
            G,
            Q,
            R,
            x_hat_bar_0,
            name=name,
        )

        return y_eq, kf

for_continuous_plant(plant, x_eq, u_eq, Q, R, G=None, x_hat_bar_0=None, name=None) staticmethod

Obtain a continuous-time Infinite Horizon Kalman Filter system for a continuous-time plant after linearization at equilibrium point (x_eq, u_eq)

The input plant contains the deterministic forms of the forward and observation operators:

    dx/dt = f(x,u)
    y = g(x,u)

Note: Only plants with one vector-valued input and one vector-valued output are currently supported. Furthermore, the plant LeafSystem/Diagram should have only one vector-valued integrator.

A plant with disturbances of the following form is then considered:

    dx/dt = f(x,u) + G w
    y = g(x,u) +  v

where:

`w` represents the process noise,
`v` represents the measurement noise,

and

    E(w) = E(v) = 0
    E(ww') = Q
    E(vv') = R
    E(wv') = N = 0

This plant with disturbances is linearized (only f and g) around the equilibrium point to obtain:

    d/dt (x_bar) = A x_bar + B u_bar + G w    --- (C1)
    y_bar = C x_bar + D u_bar + v             --- (C2)

where,

    x_bar = x - x_eq
    u_bar = u - u_eq
    y_bar = y - y_eq
    y_eq = g(x_eq, u_eq)

A continuous-time Kalman Filter estimator for the system of equations (C1) and (C2) is returned. This filter is in the x_bar, u_bar, and y_bar states.

The returned system will have

Input ports

(0) u_bar : continuous-time control vector relative to equilibrium point (1) y_bar : continuous-time measurement vector relative to equilibrium point

Output ports

(0) x_hat_bar : continuous-time state vector estimate relative to equilibrium point

Parameters:

Name Type Description Default
plant

a Plant object which can be a LeafSystem or a Diagram.

required
x_eq

ndarray Equilibrium state vector for linearization

required
u_eq

ndarray Equilibrium control vector for linearization

required
Q

ndarray Process noise covariance matrix.

required
R

ndarray Measurement noise covariance matrix.

required
G

ndarray Process noise matrix. If None, G=B is assumed making disturbances additive to control vector u, i.e. u_disturbed = u_orig + w.

None
x_hat_bar_0

ndarray Initial state estimate relative to equilibrium point. If None, a zero vector is assumed.

None
Source code in collimator/library/state_estimators/continuous_time_infinite_horizon_kalman_filter.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
@staticmethod
@with_resolved_parameters
def for_continuous_plant(
    plant,
    x_eq,
    u_eq,
    Q,
    R,
    G=None,
    x_hat_bar_0=None,
    name=None,
):
    """
    Obtain a continuous-time Infinite Horizon Kalman Filter system for a
    continuous-time plant after linearization at equilibrium point (x_eq, u_eq)

    The input plant contains the deterministic forms of the forward and observation
    operators:

    ```
        dx/dt = f(x,u)
        y = g(x,u)
    ```

    Note: Only plants with one vector-valued input and one vector-valued output
    are currently supported. Furthermore, the plant LeafSystem/Diagram should have
    only one vector-valued integrator.

    A plant with disturbances of the following form is then considered:

    ```
        dx/dt = f(x,u) + G w
        y = g(x,u) +  v
    ```

    where:

        `w` represents the process noise,
        `v` represents the measurement noise,

    and

    ```
        E(w) = E(v) = 0
        E(ww') = Q
        E(vv') = R
        E(wv') = N = 0
    ```

    This plant with disturbances is linearized (only `f` and `g`) around the
    equilibrium point to obtain:

    ```
        d/dt (x_bar) = A x_bar + B u_bar + G w    --- (C1)
        y_bar = C x_bar + D u_bar + v             --- (C2)
    ```

    where,

    ```
        x_bar = x - x_eq
        u_bar = u - u_eq
        y_bar = y - y_eq
        y_eq = g(x_eq, u_eq)
    ```

    A continuous-time Kalman Filter estimator for the system of equations (C1) and
    (C2) is returned. This filter operates on the deviation variables `x_bar`,
    `u_bar`, and `y_bar`.

    The returned system will have the following ports:

    Input ports:
        (0) u_bar : continuous-time control vector relative to equilibrium point
        (1) y_bar : continuous-time measurement vector relative to equilibrium point

    Output ports:
        (0) x_hat_bar : continuous-time state vector estimate relative to
                        equilibrium point

    Parameters:
        plant : a `Plant` object which can be a LeafSystem or a Diagram.
        x_eq: ndarray
            Equilibrium state vector for linearization
        u_eq: ndarray
            Equilibrium control vector for linearization
        Q: ndarray
            Process noise covariance matrix.
        R: ndarray
            Measurement noise covariance matrix.
        G: ndarray
            Process noise matrix. If `None`, `G=B` is assumed making disturbances
            additive to control vector `u`, i.e. `u_disturbed = u_orig + w`.
        x_hat_bar_0: ndarray
            Initial state estimate relative to equilibrium point.
            If None, a zero vector (of the state dimension) is assumed.
        name: str, optional
            Name passed to the returned Kalman Filter system.

    Returns:
        y_eq: ndarray
            The equilibrium output reported by the linearization, i.e. `y_eq`
            above; subtract it from measurements to form `y_bar`.
        kf: ContinuousTimeInfiniteHorizonKalmanFilter
            The Kalman Filter for the linearized system (C1)-(C2).
    """

    y_eq, linear_plant = linearize_plant(plant, x_eq, u_eq)

    A, B, C, D = linear_plant.A, linear_plant.B, linear_plant.C, linear_plant.D

    # B has shape (nx, nu); only the state dimension is needed below.
    nx = B.shape[0]

    if G is None:
        # Disturbances enter through the control channel: u_disturbed = u + w.
        G = B

    if x_hat_bar_0 is None:
        # Start the estimate at the equilibrium point (zero deviation).
        x_hat_bar_0 = cnp.zeros(nx)

    # Instantiate a Kalman Filter instance for the linearized plant
    kf = ContinuousTimeInfiniteHorizonKalmanFilter(
        A,
        B,
        C,
        D,
        G,
        Q,
        R,
        x_hat_bar_0,
        name=name,
    )

    return y_eq, kf

CoordinateRotation

Bases: LeafSystem

Computes the rotation of a 3D vector between coordinate systems.

Given sufficient information to construct a rotation matrix C_AB from orthogonal coordinate system B to orthogonal coordinate system A, along with an input vector x_B expressed in B-axes, this block will compute the matrix-vector product x_A = C_AB @ x_B.

Note that depending on the type of rotation representation, this matrix may not be explicitly computed. The types of rotations supported are Quaternion, Euler Angles, and Direction Cosine Matrix (DCM).

By default, the rotations have the following convention:

  • Quaternion: The rotation is represented by a 4-component quaternion q. The rotation is carried out by the product p_A = q⁻¹ * p_B * q, where q⁻¹ is the quaternion inverse of q, * is the quaternion product, and p_A and p_B are the quaternion extensions of the vectors x_A and x_B, i.e. p_A = [0, x_A] and p_B = [0, x_B].

  • Roll-Pitch-Yaw (Euler Angles): The rotation is represented by the set of Euler angles ϕ (roll), θ (pitch), and ψ (yaw), in the "1-2-3" convention for intrinsic rotations. The resulting rotation matrix C_AB(ϕ, θ, ψ) is the same as the product of the three single-axis rotation matrices C_AB = Cz(ψ) * Cy(θ) * Cx(ϕ).

    For example, if B represents a fixed "world" frame with axes xyz and A is a body-fixed frame with axes XYZ, then C_AB represents a rotation from the world frame to the body frame, in the following sequence:

    1. Right-hand rotation about the world frame x-axis by ϕ (roll), resulting in the intermediate frame x'y'z' with x' = x.
    2. Right-hand rotation about the intermediate frame y'-axis by θ (pitch), resulting in the intermediate frame x''y''z'' with y'' = y'.
    3. Right-hand rotation about the intermediate frame z''-axis by ψ (yaw), resulting in the body frame XYZ with Z = z''.
  • Direction Cosine Matrix: The rotation is directly represented as a 3x3 matrix C_AB. The rotation is carried out by the matrix-vector product x_A = C_AB @ x_B.

Input ports

(0): The input vector x_B expressed in the B-axes.

(1): (if enable_external_rotation_definition=True) The rotation representation (quaternion, Euler angles, or cosine matrix) that defines the rotation from B to A (or A to B if inverse=True).

Output ports

(0): The output vector x_A expressed in the A-axes.

Parameters:

Name Type Description Default
rotation_type str

The type of rotation representation to use. Must be one of ("quaternion", "roll_pitch_yaw", "DCM"); the value is case-sensitive.

required
enable_external_rotation_definition

If True, the block will have an additional input port (1) for the rotation representation (quaternion, Euler angles, or cosine matrix). Otherwise the rotation must be provided as a block parameter.

True
inverse

If True, the block will compute the inverse transformation, i.e. if the matrix representation of the rotation is C_AB from frame B to frame A, the block will compute the inverse transformation C_BA = C_AB⁻¹ = C_AB.T

False
quaternion Array

The quaternion representation of the rotation if enable_external_rotation_definition=False.

None
roll_pitch_yaw Array

The Euler angles representation of the rotation if enable_external_rotation_definition=False.

None
direction_cosine_matrix Array

The direction cosine matrix representation of the rotation if enable_external_rotation_definition=False.

None
Source code in collimator/library/rotations.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
class CoordinateRotation(LeafSystem):
    """Computes the rotation of a 3D vector between coordinate systems.

    Given sufficient information to construct a rotation matrix `C_AB` from orthogonal
    coordinate system `B` to orthogonal coordinate system `A`, along with an input
    vector `x_B` expressed in `B`-axes, this block will compute the matrix-vector
    product `x_A = C_AB @ x_B`.

    Note that depending on the type of rotation representation, this matrix may not be
    explicitly computed.  The types of rotations supported are Quaternion, Euler
    Angles, and Direction Cosine Matrix (DCM).

    By default, the rotations have the following convention:

    - __Quaternion:__ The rotation is represented by a 4-component quaternion `q`.
        The rotation is carried out by the product `p_A = q⁻¹ * p_B * q`, where
        `q⁻¹` is the quaternion inverse of `q`, `*` is the quaternion product, and
        `p_A` and `p_B` are the quaternion extensions of the vectors `x_A` and `x_B`,
        i.e. `p_A = [0, x_A]` and `p_B = [0, x_B]`.

    - __Roll-Pitch-Yaw (Euler Angles):__ The rotation is represented by the set of Euler angles
        ϕ (roll), θ (pitch), and ψ (yaw), in the "1-2-3" convention for intrinsic
        rotations. The resulting rotation matrix `C_AB(ϕ, θ, ψ)` is the same as the product of
        the three single-axis rotation matrices `C_AB = Cz(ψ) * Cy(θ) * Cx(ϕ)`.

        For example, if `B` represents a fixed "world" frame with axes `xyz` and `A`
        is a body-fixed frame with axes `XYZ`, then `C_AB` represents a rotation from
        the world frame to the body frame, in the following sequence:

        1. Right-hand rotation about the world frame `x`-axis by `ϕ` (roll), resulting
            in the intermediate frame `x'y'z'` with `x' = x`.
        2. Right-hand rotation about the intermediate frame `y'`-axis by `θ` (pitch),
            resulting in the intermediate frame `x''y''z''` with `y'' = y'`.
        3. Right-hand rotation about the intermediate frame `z''`-axis by `ψ` (yaw),
            resulting in the body frame `XYZ` with `Z = z''`.

    - __Direction Cosine Matrix:__ The rotation is directly represented as a
        3x3 matrix `C_AB`. The rotation is carried out by the matrix-vector product
        `x_A = C_AB @ x_B`.

    Input ports:
        (0): The input vector `x_B` expressed in the `B`-axes.

        (1): (if `enable_external_rotation_definition=True`) The rotation
            representation (quaternion, Euler angles, or cosine matrix) that defines
            the rotation from `B` to `A` (or `A` to `B` if `inverse=True`).

    Output ports:
        (0): The output vector `x_A` expressed in the `A`-axes.

    Parameters:
        rotation_type (str): The type of rotation representation to use. Must be one of
            ("quaternion", "roll_pitch_yaw", "DCM"); the value is case-sensitive.
        enable_external_rotation_definition: If `True`, the block will have an
            additional input port (1) for the rotation representation (quaternion,
            Euler angles, or cosine matrix).  Otherwise the rotation must be
            provided as a block parameter.
        inverse: If `True`, the block will compute the inverse transformation, i.e.
            if the matrix representation of the rotation is `C_AB` from frame `B` to
            frame `A`, the block will compute the inverse transformation
            `C_BA = C_AB⁻¹ = C_AB.T`
        quaternion (Array, optional): The quaternion representation of the rotation
            if `enable_external_rotation_definition=False`.
        roll_pitch_yaw (Array, optional): The Euler angles representation of the
            rotation if `enable_external_rotation_definition=False`.
        direction_cosine_matrix (Array, optional): The direction cosine matrix
            representation of the rotation if `enable_external_rotation_definition=False`.

    Raises:
        BlockParameterError: if `rotation_type` is invalid, or if external rotation
            is disabled and the matching static rotation is missing or misshaped.
    """

    @parameters(
        static=[
            "quaternion",
            "roll_pitch_yaw",
            "direction_cosine_matrix",
            "rotation_type",
            "enable_external_rotation_definition",
            "inverse",
        ]
    )
    def __init__(
        self,
        rotation_type,
        enable_external_rotation_definition=True,
        quaternion=None,
        roll_pitch_yaw=None,
        direction_cosine_matrix=None,
        inverse=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Configuration read by `_check_config`, `_apply`, and `check_types`.
        self.external_rotation = enable_external_rotation_definition
        self.rotation_type = rotation_type
        self.inverse = inverse

        # Port 0: the vector to rotate.
        self.vector_input_index = self.declare_input_port()

        # Note: all of the possible rotation specifications are passed as parameters
        # to make the serialization work, but only one is valid at a time. This makes
        # sense from the UI, but is a bit strange when working directly with the code.
        # In any case, the typical use case is to have the external rotation port
        # enabled, so all of these should usually be None.  If more than one is
        # provided (which can happen for instance via hidden parameters in the JSON)
        # then only the rotation corresponding to the `rotation_type` will be used, and
        # the rest will be ignored.
        rotation = self._check_config(
            rotation_type,
            quaternion,
            roll_pitch_yaw,
            direction_cosine_matrix,
        )

        if enable_external_rotation_definition:
            # Port 1: the rotation representation, supplied at runtime.
            self.rotation_input_index = self.declare_input_port()

        else:
            # Store the validated static rotation as a parameter. External rotation
            # is disabled in this branch, so `rotation` is the array returned by
            # `_check_config` (it is only None in the external case).
            self.declare_dynamic_parameter("rotation", rotation)

        # The output callback is supplied by `initialize` via
        # `configure_output_port`; the output depends on every input port.
        self._output_port_idx = self.declare_output_port(
            prerequisites_of_calc=[port.ticket for port in self.input_ports],
        )

    def initialize(
        self,
        rotation_type,
        enable_external_rotation_definition,
        quaternion,
        roll_pitch_yaw,
        direction_cosine_matrix,
        inverse,
        rotation=None,
    ):
        """(Re)configure the block from its parameter values.

        Re-validates the configuration and installs the output callback on the
        output port declared in `__init__`. The `rotation` argument is accepted
        for interface compatibility but not used: the static rotation is
        recomputed from the representation-specific parameters.

        Raises:
            ValueError: if `enable_external_rotation_definition` differs from the
                value given at construction (the input-port layout is fixed then).
            BlockParameterError: if `_check_config` rejects the configuration.
        """
        if enable_external_rotation_definition != self.external_rotation:
            raise ValueError("Cannot change external rotation definition.")

        self.rotation_type = rotation_type
        self.inverse = inverse

        # Validate on every (re)initialization, in both modes, so that a changed
        # `rotation_type` is rejected here with a BlockParameterError instead of
        # surfacing later as a KeyError inside `_apply`. The result is None when
        # external rotation is enabled.
        static_rotation = self._check_config(
            rotation_type,
            quaternion,
            roll_pitch_yaw,
            direction_cosine_matrix,
        )

        if not self.external_rotation:

            def _output_func(_time, _state, *inputs, **parameters):
                vector = inputs[self.vector_input_index]
                return self._apply(static_rotation, vector)

        else:

            # The rotation is read from input port 1 at every evaluation.
            def _output_func(_time, _state, *inputs, **parameters):
                vector = inputs[self.vector_input_index]
                external_rotation = inputs[self.rotation_input_index]
                return self._apply(external_rotation, vector)

        self.configure_output_port(
            self._output_port_idx,
            _output_func,
            prerequisites_of_calc=[port.ticket for port in self.input_ports],
        )

    def _check_config(
        self, rotation_type, quaternion, roll_pitch_yaw, direction_cosine_matrix
    ):
        """Validate the rotation configuration.

        `rotation_type` is always checked. When external rotation is disabled, the
        static representation selected by `rotation_type` must also be present and
        have the right shape: (4,) for "quaternion", (3,) for "roll_pitch_yaw",
        (3, 3) for "DCM". The other representations are ignored.

        Returns:
            The selected static rotation as an array (via `cnp.asarray`), or None
            when external rotation is enabled.

        Raises:
            BlockParameterError: on an invalid `rotation_type`, a missing static
                rotation, or a static rotation with the wrong shape.
        """
        if rotation_type not in ("quaternion", "roll_pitch_yaw", "DCM"):
            message = f"Invalid rotation type: {rotation_type}."
            raise BlockParameterError(
                message=message, system=self, parameter_name="rotation_type"
            )

        if self.external_rotation:
            # Input type checking will be done by `check_types`
            return None

        if rotation_type == "quaternion":
            if quaternion is None:
                message = (
                    "A static quaternion must be provided if external rotation "
                    + "definition is disabled."
                )
                raise BlockParameterError(
                    message=message, system=self, parameter_name="quaternion"
                )
            rotation = cnp.asarray(quaternion)
            if rotation.shape != (4,):
                message = (
                    "The quaternion must have shape (4,), but has shape "
                    + f"{rotation.shape}."
                )
                raise BlockParameterError(
                    message=message, system=self, parameter_name="quaternion"
                )

        elif rotation_type == "roll_pitch_yaw":
            if roll_pitch_yaw is None:
                message = (
                    "A static roll-pitch-yaw sequence must be provided if external "
                    + "rotation definition is disabled."
                )
                raise BlockParameterError(
                    message=message, system=self, parameter_name="roll_pitch_yaw"
                )
            rotation = cnp.asarray(roll_pitch_yaw)
            if rotation.shape != (3,):
                message = (
                    "The Euler angles must have shape (3,), but has shape "
                    + f"{rotation.shape}."
                )
                raise BlockParameterError(
                    message=message, system=self, parameter_name="roll_pitch_yaw"
                )

        elif rotation_type == "DCM":
            if direction_cosine_matrix is None:
                message = (
                    "A static direction cosine matrix must be provided if external "
                    + "rotation definition is disabled."
                )
                raise BlockParameterError(
                    message=message,
                    system=self,
                    parameter_name="direction_cosine_matrix",
                )
            rotation = cnp.asarray(direction_cosine_matrix)
            if rotation.shape != (3, 3):
                message = (
                    "The direction cosine matrix must have shape (3, 3), but has shape "
                    + f"{rotation.shape}."
                )
                raise BlockParameterError(
                    message=message,
                    system=self,
                    parameter_name="direction_cosine_matrix",
                )

        return rotation

    def _apply(self, rotation: Array, vector: Array) -> Array:
        """Rotate `vector` by the rotation described by `rotation`.

        `rotation` is the raw representation (quaternion, roll-pitch-yaw angles, or
        DCM), interpreted according to `self.rotation_type`. The rotation is
        inverted first when `self.inverse` is True.
        """
        rot = {
            "quaternion": Rotation.from_quat,
            "roll_pitch_yaw": partial(Rotation.from_euler, EULER_SEQ),
            "DCM": Rotation.from_matrix,
        }[self.rotation_type](rotation)

        if self.inverse:
            rot = rot.inv()

        return rot.apply(vector)

    def check_types(
        self,
        context,
        error_collector: ErrorCollector = None,
    ):
        """Check the shapes of the block inputs.

        The vector input must have shape (3,). When external rotation is enabled,
        the rotation input must have shape (4,), (3,), or (3, 3) for the
        "quaternion", "roll_pitch_yaw", and "DCM" rotation types respectively.
        Errors are raised inside `ErrorCollector.context(error_collector)`, which
        presumably collects them when a collector is provided.
        """
        vec = self.input_ports[self.vector_input_index].eval(context)

        with ErrorCollector.context(error_collector):
            if vec.shape != (3,):
                raise ShapeMismatchError(
                    system=self,
                    expected_shape=(3,),
                    actual_shape=vec.shape,
                )

        if self.external_rotation:
            rot = self.input_ports[self.rotation_input_index].eval(context)

            with ErrorCollector.context(error_collector):
                if self.rotation_type == "quaternion" and rot.shape != (4,):
                    raise ShapeMismatchError(
                        system=self,
                        expected_shape=(4,),
                        actual_shape=rot.shape,
                    )
                elif self.rotation_type == "roll_pitch_yaw" and rot.shape != (3,):
                    raise ShapeMismatchError(
                        system=self,
                        expected_shape=(3,),
                        actual_shape=rot.shape,
                    )
                elif self.rotation_type == "DCM" and rot.shape != (3, 3):
                    raise ShapeMismatchError(
                        system=self,
                        expected_shape=(3, 3),
                        actual_shape=rot.shape,
                    )

CoordinateRotationConversion

Bases: LeafSystem

Converts between different representations of rotations.

See CoordinateRotation block documentation for descriptions of the different rotation representations supported. This block supports conversion between quaternion, roll-pitch-yaw (Euler angles), and direction cosine matrix (DCM).

Note that conversions are reversible in terms of the abstract rotation, although creating a quaternion from a direction cosine matrix (and therefore creating a quaternion from roll-pitch-yaw sequence) results in an arbitrary sign assignment.

Input ports

(0): The input rotation representation.

Output ports

(0): The output rotation representation.

Parameters:

Name Type Description Default
conversion_type str

The type of rotation conversion to perform. Must be one of ("quaternion_to_RPY", "quaternion_to_DCM", "RPY_to_quaternion", "RPY_to_DCM", "DCM_to_quaternion", "DCM_to_RPY")

required
Source code in collimator/library/rotations.py
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
class CoordinateRotationConversion(LeafSystem):
    """Converts between different representations of rotations.

    See CoordinateRotation block documentation for descriptions of the different
    rotation representations supported. This block supports conversion between
    quaternion, roll-pitch-yaw (Euler angles), and direction cosine matrix (DCM).

    Note that conversions are reversible in terms of the abstract rotation, although
    creating a quaternion from a direction cosine matrix (and therefore creating a
    quaternion from roll-pitch-yaw sequence) results in an arbitrary sign assignment.

    Input ports:
        (0): The input rotation representation.

    Output ports:
        (0): The output rotation representation.

    Parameters:
        conversion_type (str): The type of rotation conversion to perform.
            Must be one of ("quaternion_to_RPY", "quaternion_to_DCM",
            "RPY_to_quaternion", "RPY_to_DCM", "DCM_to_quaternion", "DCM_to_RPY").
            Expected input shapes: quaternion (4,), RPY (3,), DCM (3, 3).
    """

    @parameters(static=["conversion_type"])
    def __init__(self, conversion_type, **kwargs):
        super().__init__(**kwargs)
        # Single input port (index 0) carrying the rotation to convert.
        self.declare_input_port()
        # The output callback is supplied by `initialize`.
        self._output_port_idx = self.declare_output_port(requires_inputs=True)

    def initialize(self, conversion_type):
        """Validate `conversion_type` and configure the output callback.

        Raises:
            BlockParameterError: if `conversion_type` is not one of the supported
                conversion names.
        """
        # Single source of truth for the supported conversions: validation below
        # is derived from these keys, so the two cannot drift apart.
        conversions = {
            "quaternion_to_RPY": quat_to_euler,
            "quaternion_to_DCM": quat_to_dcm,
            "RPY_to_quaternion": euler_to_quat,
            "RPY_to_DCM": euler_to_dcm,
            "DCM_to_quaternion": dcm_to_quat,
            "DCM_to_RPY": dcm_to_euler,
        }

        # The isinstance guard keeps non-string (possibly unhashable) values on the
        # BlockParameterError path instead of a TypeError from the dict lookup.
        if (
            not isinstance(conversion_type, str)
            or conversion_type not in conversions
        ):
            message = f"Invalid rotation conversion type: {conversion_type}."
            raise BlockParameterError(
                message=message, system=self, parameter_name="conversion_type"
            )

        _func = conversions[conversion_type]

        def _output(_time, _state, *inputs, **_parameters):
            (u,) = inputs
            return _func(u)

        self.configure_output_port(
            self._output_port_idx,
            _output,
            requires_inputs=True,
        )

        # Serialization
        self.conversion_type = conversion_type

    def check_types(
        self,
        context,
        error_collector: ErrorCollector = None,
    ):
        """Check that the input has the shape its source representation needs.

        The source representation is the part of `conversion_type` before
        "_to_": quaternion inputs must be (4,), RPY inputs (3,), DCM inputs (3, 3).
        """
        rot = self.input_ports[0].eval(context)

        # Expected input shape keyed by source representation.
        input_shapes = {
            "quaternion": (4,),
            "RPY": (3,),
            "DCM": (3, 3),
        }

        with ErrorCollector.context(error_collector):
            source = self.conversion_type.split("_to_", 1)[0]
            expected_shape = input_shapes.get(source)
            if expected_shape is not None and rot.shape != expected_shape:
                raise ShapeMismatchError(
                    system=self,
                    expected_shape=expected_shape,
                    actual_shape=rot.shape,
                )

CrossProduct

Bases: ReduceBlock

Compute the cross product between the inputs.

See NumPy docs for details: https://numpy.org/doc/stable/reference/generated/numpy.cross.html

Input ports

(0) The first input vector.
(1) The second input vector.

Output ports

(0) The cross product of the inputs.

Source code in collimator/library/primitives.py
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
class CrossProduct(ReduceBlock):
    """Compute the cross product of the two input vectors.

    The reduction is delegated to `cnp.cross`; for the exact semantics see the
    NumPy documentation:
    https://numpy.org/doc/stable/reference/generated/numpy.cross.html

    Input ports:
        (0) The first input vector.
        (1) The second input vector.

    Output ports:
        (0) The cross product of the inputs.
    """

    def __init__(self, *args, **kwargs):
        # The leading `2` is the number of inputs handed to ReduceBlock; the
        # reduction receives their values as one sequence and splats it into
        # `cnp.cross`.
        super().__init__(
            2,
            lambda operands: cnp.cross(*operands),
            *args,
            **kwargs,
        )

CustomJaxBlock

Bases: LeafSystem

JAX implementation of the PythonScript block.

A few important notes and changes/limitations to this JAX implementation:

- For this block all code must be written using the JAX-supported subset of Python:
  * Numerical operations should use jax.numpy = jnp instead of numpy = np.
  * Standard control flow is not supported (if/else, for, while, etc.). Instead use lax.cond, lax.fori_loop, lax.while_loop, etc. (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#structured-control-flow-primitives). Where possible, NumPy-style operations like jnp.where or jnp.select should be preferred to lax control flow primitives.
  * Functions must be pure and arrays treated as immutable (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates).

  Provided these assumptions hold, the code can be JIT compiled, differentiated, run on GPU, etc.
- Variable scoping: the init_code and step_code are executed in the same scope, so variables declared in the init_code will be available in the step_code and can be modified in that scope. Internally, everything declared in init_code is treated as a single state-like cache entry. However, variables declared in the step_code will NOT persist between evaluations. Users should think of step_code as a normal Python function where locally declared variables will disappear on leaving the scope.
- Persistent variables (outputs and anything declared in init_code) must have static shapes and dtypes. This means that you cannot declare x = 0.0 in init_code and then later assign x = jnp.zeros(4) in step_code.

These changes mean that many older PythonScript blocks may not be backwards compatible.

Input ports

Variable number of input ports, one for each input variable declared in inputs. The order of the input ports is the same as the order of the input variables.

Output ports

Variable number of output ports, one for each output variable declared in outputs. The order of the output ports is the same as the order of the output variables.

Parameters:

Name Type Description Default
dt float

The discrete time step of the block, or None if the block is in agnostic time mode.

None
init_script str

A string containing Python code that will be executed once when the block is initialized. This code can be used to declare persistent variables that will be available in the step_code.

''
user_statements str

A string containing Python code that will be executed once per time step (or per output port evaluation, in agnostic mode). This code can use the persistent variables declared in init_script and the block inputs.

''
finalize_script str

A string containing Python code that will be executed once when the block is finalized. This code can use the persistent variables declared in init_script and the block inputs. (Not yet supported.)

''
accelerate_with_jax bool

If True, the block will be JIT compiled. If False, the block will be executed in pure Python. This parameter exists for compatibility with UI options; when creating pure Python blocks from code (e.g. for testing), explicitly create the CustomPythonBlock class.

True
time_mode str

One of "discrete" or "agnostic". If "discrete", the block step code will be evaluated at periodic intervals specified by "dt". If "agnostic", the block step code will be evaluated once per output port evaluation, and the block will not have a discrete time step.

'discrete'
inputs List[str]

A list of input variable names. The order of the input ports is the same as the order of the input variables.

None
outputs Mapping[str, Tuple[DTypeLike, ShapeLike]]

A dictionary mapping output variable names to a tuple of dtype and shape. The order of the output ports is the same as the order of the output variables.

None
static_parameters Mapping[str, Array]

A dictionary mapping parameter names to values. Parameters are treated as immutable and cannot be modified in the step code. Static parameters can't be used in ensemble simulations or optimization workflows.

None
dynamic_parameters Mapping[str, Array]

A dictionary mapping parameter names to values. Parameters are treated as immutable and cannot be modified in the step code. Dynamic parameters can be arrays or scalars, but must have static shapes and dtypes in order to support JIT compilation.

None
Source code in collimator/library/custom.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
class CustomJaxBlock(LeafSystem):
    """JAX implementation of the PythonScript block.

    A few important notes and changes/limitations to this JAX implementation:
    - For this block all code must be written using the JAX-supported subset of Python:
        * Numerical operations should use `jax.numpy = jnp` instead of `numpy = np`
        * Standard control flow is not supported (if/else, for, while, etc.). Instead
            use `lax.cond`, `lax.fori_loop`, `lax.while_loop`, etc.
            https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#structured-control-flow-primitives
            Where possible, NumPy-style operations like `jnp.where` or `jnp.select` should
            be preferred to lax control flow primitives.
        * Functions must be pure and arrays treated as immutable.
            https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates
        Provided these assumptions hold, the code can be JIT compiled, differentiated,
        run on GPU, etc.
    - Variable scoping: the `init_code` and `step_code` are executed in the same scope,
        so variables declared in the `init_code` will be available in the `step_code`
        and can be modified in that scope. Internally, everything declared in
        `init_code` is treated as a single state-like cache entry.
        However, variables declared in the `step_code` will NOT persist between
        evaluations. Users should think of `step_code` as a normal Python function
        where locally declared variables will disappear on leaving the scope.
    - Persistent variables (outputs and anything declared in `init_code`) must have
        static shapes and dtypes. This means that you cannot declare `x = 0.0` in
        `init_code` and then later assign `x = jnp.zeros(4)` in `step_code`.
    - User code is compiled with `optimize=2`, so `assert` statements and
        docstrings inside the scripts are stripped.

    These changes mean that many older PythonScript blocks may not be backwards compatible.

    Input ports:
        Variable number of input ports, one for each input variable declared in `inputs`.
        The order of the input ports is the same as the order of the input variables.

    Output ports:
        Variable number of output ports, one for each output variable declared in `outputs`.
        The order of the output ports is the same as the order of the output variables.

    Parameters:
        dt (float): The discrete time step of the block, or None if the block is
            in agnostic time mode. Required in discrete time mode.
        init_script (str): A string containing Python code that will be executed
            once when the block is initialized. This code can be used to declare
            persistent variables that will be available in the `step_code`.
        user_statements (str): A string containing Python code that will be executed
            once per time step (or per output port evaluation, in agnostic mode).
            This code can use the persistent variables declared in `init_script` and
            the block inputs.
        finalize_script (str): A string containing Python code intended to be executed
            once when the block is finalized. Not supported by this block: a
            non-empty value raises `PythonScriptError`.
        accelerate_with_jax (bool): If True, the block will be JIT compiled. If False,
            the block will be executed in pure Python.  This parameter exists for
            compatibility with UI options; when creating pure Python blocks from code
            (e.g. for testing), explicitly create the CustomPythonBlock class.
        time_mode (str): One of "discrete" or "agnostic". If "discrete", the block
            step code will be evaluated at periodic intervals specified by "dt".
            If "agnostic", the block step code will be evaluated once per output
            port evaluation, and the block will not have a discrete time step.
        inputs (List[str]): A list of input variable names. The order of the input
            ports is the same as the order of the input variables.
        outputs (List[str]): A list of output variable names. The order of the
            output ports is the same as the order of the output variables. In
            discrete time mode every output must also be initialized in
            `init_script`, since that value fixes its initial value, shape and dtype.
        static_parameters (Mapping[str, Array]): A dictionary mapping parameter names to
            values. Parameters are treated as immutable and cannot be modified in
            the step code. Static parameters can't be used in ensemble simulations or
            optimization workflows.
        dynamic_parameters (Mapping[str, Array]): A dictionary mapping parameter names to
            values. Parameters are treated as immutable and cannot be modified in
            the step code. Dynamic parameters can be arrays or scalars, but must have static
            shapes and dtypes in order to support JIT compilation.
    """

    @declare_parameters(
        static=[
            "dt",
            "init_script",
            "user_statements",
            "finalize_script",
            "accelerate_with_jax",
            "time_mode",
        ]
    )
    def __init__(
        self,
        dt: float = None,
        init_script: str = "",
        user_statements: str = "",
        finalize_script: str = "",  # not supported for JAX block; must be empty
        accelerate_with_jax: bool = True,
        time_mode: str = "discrete",  # [discrete, agnostic]
        inputs: List[str] = None,  # [name]
        outputs: List[str] = None,
        dynamic_parameters: Mapping[str, Array] = None,
        static_parameters: Mapping[str, Array] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        dynamic_parameters = dynamic_parameters if dynamic_parameters else {}
        static_parameters = static_parameters if static_parameters else {}

        if time_mode not in ["discrete", "agnostic"]:
            raise BlockInitializationError(
                f"Invalid time mode '{time_mode}' for PythonScript block", system=self
            )

        if time_mode == "discrete" and dt is None:
            raise BlockInitializationError(
                "When in discrete time mode, dt is required for block", system=self
            )

        self.time_mode = time_mode

        if inputs is None:
            inputs = []
        if outputs is None:
            outputs = []

        self.dt = dt

        # Note: 'optimize' level could be lowered in debug mode.  optimize=2
        # strips `assert` statements and docstrings from the user code.
        try:
            self.init_code = compile(
                init_script, filename="<init>", mode="exec", optimize=2
            )
        except BaseException as e:
            raise PythonScriptError(
                f"Syntax error in init_script for PythonScript block '{self.name_path_str}': {e}",
                system=self,
            ) from e

        try:
            self.step_code = compile(
                user_statements, filename="<step>", mode="exec", optimize=2
            )
        except BaseException as e:
            raise PythonScriptError(
                f"Syntax error in user_statements for PythonScript block '{self.name_path_str}': {e}",
                system=self,
            ) from e

        # The pure-Python subclass handles finalize_script itself; for the JAX
        # block a non-empty script is rejected up front.
        if finalize_script != "" and not isinstance(self, CustomPythonBlock):
            raise PythonScriptError(
                f"PythonScript block '{self.name_path_str}' has finalize_script "
                "but this is not supported at the moment.",
                system=self,
                parameter_name="finalize_script",
            )

        # Declare parameters.  Lists are converted to arrays; arrays and scalars
        # are declared "as_array".
        for param_name, value in dynamic_parameters.items():
            if isinstance(value, list):
                value = cnp.asarray(value)
            as_array = isinstance(value, cnp.ndarray) or cnp.isscalar(value)
            self.declare_dynamic_parameter(param_name, value, as_array=as_array)

        for param_name, value in static_parameters.items():
            self.declare_static_parameter(param_name, value)

        # Run the init_script (must happen after parameters are declared, since
        # their default values are injected into the init environment).
        persistent_env = self.exec_init()

        # Declare an input port for each of the input variables
        self.input_names = inputs
        for name in inputs:
            self.declare_input_port(name)

        # Declare a cache component for each of the output variables
        self._create_cache_type(outputs)

        if time_mode == "discrete":
            self._configure_discrete(dt, outputs, persistent_env)
        else:
            self._configure_agnostic(outputs, persistent_env)

    def initialize(
        self,
        dt: float = None,
        init_script: str = "",
        user_statements: str = "",
        finalize_script: str = "",  # presently ignored for JAX block
        accelerate_with_jax: bool = True,
        time_mode: str = "discrete",  # [discrete, agnostic]
        **parameters,
    ):
        # Intentionally a no-op: all configuration is consumed in __init__.
        pass

    def _initialize_outputs(self, outputs, persistent_env):
        """Build the initial cache value for discrete mode from `init_script`.

        Raises:
            PythonScriptError: if an output is not initialized in `init_script`.
        """
        default_outputs = {name: None for name in outputs}

        for name in outputs:
            # If the initial value is set explicitly in the init script,
            # override the default value.  We don't need to do this for
            # agnostic configuration since the outputs will be calculated
            # every evaluation anyway.
            if name in persistent_env:
                value = cnp.asarray(persistent_env[name])
                default_outputs[name] = value

                # Also update the persistent environment so that the data types
                # are consistent with the state.
                persistent_env[name] = value

            # Otherwise throw an error, since we don't know what the initial values
            # should be, or even what shape/dtype they should have.
            else:
                msg = (
                    f"Output variable '{name}' not explicitly initialized in "
                    "init_script for PythonScript block in 'Discrete' time mode. "
                    "Either initialize the variable as an array with the correct "
                    "shape and dtype, or make the block time mode 'Agnostic'."
                )
                raise PythonScriptError(message=msg, system=self)

        return self.CacheType(
            persistent_env=persistent_env,
            **default_outputs,
        )

    def _configure_discrete(self, dt, outputs, persistent_env):
        """Declare a periodic cache update and one output port per output."""
        default_values = self._initialize_outputs(outputs, persistent_env)

        # The step function acts as a periodic update that will update all components
        # of the discrete state.  With offset=dt the first update is presumably at
        # t=dt; until then the cache holds the values set by init_script.
        self.step_callback_index = self.declare_cache(
            self.exec_step,
            period=dt,
            offset=dt,
            requires_inputs=True,
            default_value=default_values,
        )

        cache = self.callbacks[self.step_callback_index]

        # Get the index into the state cache (different in general from the index
        # into the callback list, since not all callbacks are cached).
        self.step_cache_index = cache.cache_index

        # Factory so each port's closure binds its own name (avoids the
        # late-binding closure pitfall inside the loop below).
        def _make_callback(o_port_name):
            def _output(time, state, *inputs, **parameters):
                return getattr(state.cache[self.step_cache_index], o_port_name)

            return _output

        # Declare output ports for each state variable
        for o_port_name in outputs:
            self.declare_output_port(
                _make_callback(o_port_name),
                name=o_port_name,
                prerequisites_of_calc=[cache.ticket],
                requires_inputs=False,
                period=dt,
                offset=0.0,
            )

    def _configure_agnostic(self, outputs, persistent_env):
        """Declare output ports that re-run the step code on every evaluation."""

        # Create a callback to evaluate the step code and extract the
        # output. Note that this is inefficient since the step code will
        # be evaluated once _for each output port_, but it's the only way
        # to do this unless (until) we implement some variety of block
        # or function pre-ordering.
        def _make_callback(o_port_name):
            def _output(time, state, *inputs, **parameters):
                xd = self.exec_step(time, state, *inputs, **parameters)
                return getattr(xd, o_port_name)

            return _output

        # Declare output ports for each state variable
        for o_port_name in outputs:
            self.declare_output_port(
                jit(_make_callback(o_port_name)),
                name=o_port_name,
                requires_inputs=True,
            )

        # This callback doesn't need to do anything since it's never
        # actually called - the cache here just stores the initial environment
        # and the output ports are evaluated directly.  This should be changed
        # to avoid re-evaluation with multiple output ports once we can do full
        # function ordering.
        def _cache_callback(time, state, *inputs, **parameters):
            return state.cache[self.step_cache_index]

        # Since this is the return type for `exec_step` we have to declare all
        # the output ports as entries in the namedtuple, even though those values
        # won't actually be cached in "agnostic" time mode.  This is just so that
        # both "discrete" and "agnostic" modes can share the same code.
        default_values = self.CacheType(
            persistent_env=persistent_env,
            **{o_port_name: None for o_port_name in outputs},
        )
        self.step_callback_index = self.declare_cache(
            _cache_callback,
            default_value=default_values,
            requires_inputs=False,
            prerequisites_of_calc=[inport.ticket for inport in self.input_ports],
        )

        cache = self.callbacks[self.step_callback_index]
        self.step_cache_index = cache.cache_index

    def _create_cache_type(self, outputs):
        """Create the namedtuple type holding the outputs plus `persistent_env`."""
        # Store the output names for type inference and casting.  Normalized to
        # a list so any sequence of names is accepted.
        self.output_names = list(outputs)

        # Also store the dictionary of local environment variables as a cache entry
        # This is the only persistent state of the system (besides outputs) - anything
        # declared in the "step" function will be forgotten at the end of the step

        self.CacheType = namedtuple(
            "CacheType", [*self.output_names, "persistent_env"]
        )

    @property
    def local_env_base(self):
        # Define a starting point for the local code execution environment.
        # We have to include __main__ so that the code behaves like a module.
        # This allows for code like this:
        #   imports ...
        #   a = 1
        #   def f(b):
        #       return a+b
        #   out_0 = f(2)
        #
        # without getting an 'a not defined' error.
        # A fresh dict is returned on every access.
        return {
            "__main__": {},
        }

    def exec_init(self) -> dict[str, Array]:
        """Execute `init_script` and return the traceable persistent environment.

        Untraceable bindings (e.g. user-defined functions or classes) are stored
        on `self.static_env` instead.

        Raises:
            PythonScriptError: if the init script raises.
        """
        # Before executing the init code, we have to build up the local environment:
        # the module-like base environment plus the default values of the
        # dynamic parameters.
        default_parameters = {
            name: param.get() for name, param in self.dynamic_parameters.items()
        }

        local_env = {
            **self.local_env_base,
            **default_parameters,
        }

        # Similar to above where we included __main__ so the code behaves as a module,
        # local_env (with __main__) serves as 1] the globals, since that is what allows
        # the code to be executed as a module, and 2] the locals, since that is where
        # the new bindings will be written; those must be retained because the step
        # code may depend on them.
        try:
            _default_exec(
                self.init_code,
                local_env,
                logger_=logger,
                system=self,
                code_name="init",
            )

        except BaseException as e:
            logger.error(
                "PythonScript block '%s' init script failed",
                self.name_path_str,
                **logdata(block=self),
            )
            raise PythonScriptError(system=self) from e

        # persistent_env contains bindings for parameters and for values from init_script
        persistent_env, static_env = _filter_non_traceable(local_env)

        # Since this is called during block initialization and not any JIT-compiled code,
        # we can safely store any untraceable variables as block attributes.  For example,
        # this may contain custom functions, classes, etc.
        self.static_env = static_env

        return persistent_env

    def exec_step(self, time: float, state: LeafState, *inputs, **parameters):
        """Run the step code once and return the updated `CacheType` value.

        `time` is accepted for callback-signature compatibility but is not
        exposed to the user code.

        Raises:
            PythonScriptError: if the step code raises, or if it leaves an
                output variable unassigned.
        """
        # Before executing the step code, we have to build up the local environment.
        # This includes the persistent variables (anything declared in `init_code`),
        # block inputs, user-defined parameters, and specified modules.

        # Retrieve the variables declared in `init_code` from the cache
        persistent_env = state.cache[self.step_cache_index].persistent_env

        # Inputs are in order of port declaration, so they match `self.input_names`
        input_env = dict(zip(self.input_names, inputs))

        # Create a dictionary of all the information that the step function will need.
        # Later entries take precedence over earlier ones.
        local_env = {
            **self.static_env,
            **self.local_env_base,
            **persistent_env,
            **input_env,
            **parameters,
        }

        # Execute the step code in the local environment
        try:
            _default_exec(
                self.step_code,
                local_env,
                logger_=logger,
                inputs=input_env,
                system=self,
                code_name="step",
            )

        except PythonScriptError:
            raise
        except BaseException as e:
            logger.error(
                "PythonScript block '%s' step failed.",
                self.name_path_str,
                **logdata(block=self),
            )
            raise PythonScriptError(system=self) from e

        # Every output must be bound after the step, either by the step code
        # itself or inherited from the persistent environment.  Report missing
        # ones clearly instead of failing with a bare KeyError below.
        missing = [name for name in self.output_names if name not in local_env]
        if missing:
            missing_str = ", ".join(f"'{name}'" for name in missing)
            raise PythonScriptError(
                message=(
                    f"Output variable(s) {missing_str} not assigned by "
                    f"user_statements for PythonScript block '{self.name_path_str}'."
                ),
                system=self,
            )

        # Updated persistent variables are read back from the local environment
        new_persistent_env = {key: local_env[key] for key in persistent_env}

        # Make sure the outputs have a consistent data type, and keep any copy
        # stored in the persistent environment in sync with it.
        outputs = {}
        for name in self.output_names:
            value = cnp.asarray(local_env[name])
            outputs[name] = value
            if name in new_persistent_env:
                new_persistent_env[name] = value

        return self.CacheType(persistent_env=new_persistent_env, **outputs)

    def check_types(
        self,
        context: ContextBase,
        error_collector: ErrorCollector = None,
    ):
        """Test-compile the init and step code to check for errors."""
        try:
            # exec_step ignores `time`, so None is passed.  No parameters are
            # passed; their default values presumably reach the step code through
            # the persistent environment built in exec_init.
            inputs = self.collect_inputs(context)
            jit(self.exec_step)(None, context[self.system_id].state, *inputs)
        except BaseException as exc:
            with ErrorCollector.context(error_collector):
                name_error = _caused_by_nameerror(exc)
                if name_error and name_error.name == "time":
                    raise PythonScriptTimeNotSupportedError(system=self) from exc
                if isinstance(exc, PythonScriptError):
                    raise
                raise PythonScriptError(system=self) from exc

check_types(context, error_collector=None)

Test-compile the init and step code to check for errors.

Source code in collimator/library/custom.py
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
def check_types(
    self,
    context: ContextBase,
    error_collector: ErrorCollector = None,
):
    """Test-compile the init and step code to check for errors.

    The step function is traced once with the current inputs and state. Any
    failure is reported through `error_collector`: a reference to the
    unsupported `time` variable becomes `PythonScriptTimeNotSupportedError`.
    Any other failure is surfaced as a `PythonScriptError`.
    """
    try:
        block_inputs = self.collect_inputs(context)
        compiled_step = jit(self.exec_step)
        # `time` is passed as None because the step code never sees it.
        compiled_step(None, context[self.system_id].state, *block_inputs)
    except BaseException as exc:
        with ErrorCollector.context(error_collector):
            cause = _caused_by_nameerror(exc)
            references_time = bool(cause) and cause.name == "time"
            if references_time:
                raise PythonScriptTimeNotSupportedError(system=self) from exc
            if not isinstance(exc, PythonScriptError):
                raise PythonScriptError(system=self) from exc
            # Already a script error: propagate it untouched.
            raise

CustomPythonBlock

Bases: CustomJaxBlock

Container for arbitrary user-defined Python code.

Implemented to support legacy PythonScript blocks.

Not traceable (no JIT compilation or autodiff). The internal implementation and behavior of this block differs vastly from the JAX-compatible block as this block stores state directly within the Python instance. Objects and modules can be kept as discrete state.

Note that in "agnostic" mode, the step code will be evaluated once per output port evaluation. Because locally defined environment variables (in the init script) are preserved between evaluations, any mutation of these variables will be preserved. This can lead to unexpected behavior and should be avoided. Stateful behavior should be implemented using discrete state variables instead.

Source code in collimator/library/custom.py
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
class CustomPythonBlock(CustomJaxBlock):
    """Container for arbitrary user-defined Python code.

    Implemented to support legacy PythonScript blocks.

    Not traceable (no JIT compilation or autodiff). The internal implementation
    and behavior of this block differ vastly from the JAX-compatible block as
    this block stores state directly within the Python instance. Objects
    and modules can be kept as discrete state.

    Note that in "agnostic" mode, the step code will be evaluated _once per
    output port evaluation_. Because locally defined environment variables
    (in the init script) are preserved between evaluations, any mutation of
    these variables will be preserved. This can lead to unexpected behavior
    and should be avoided. Stateful behavior should be implemented using
    discrete state variables instead.
    """

    # Executor used to run user code; replaceable for all instances via
    # `set_exec_fn`. The double underscore name-mangles it to
    # `_CustomPythonBlock__exec_fn`.
    __exec_fn = _default_exec

    def __init__(
        self,
        dt: float = None,
        init_script: str = "",
        user_statements: str = "",
        finalize_script: str = "",  # presently ignored
        inputs: List[str] = None,  # [name]
        outputs: List[str] = None,
        accelerate_with_jax: bool = False,
        time_mode: str = "discrete",
        static_parameters: Mapping[str, Array] = None,
        **kwargs,
    ):
        self._static_data_initialized = False
        self._parameters = static_parameters or {}
        # User-code environment carried between init and step executions.
        self._persistent_env = {}

        # Will populate return type information during static initialization
        self.result_shape_dtypes = None
        self.return_dtypes = None

        super().__init__(
            dt=dt,
            init_script=init_script,
            user_statements=user_statements,
            finalize_script=finalize_script,
            inputs=inputs,
            outputs=outputs,
            accelerate_with_jax=accelerate_with_jax,
            time_mode=time_mode,
            static_parameters=self._parameters,
            **kwargs,
        )

        if time_mode == "agnostic" and cnp.active_backend == "jax":
            logger.warning(
                "System %s is in agnostic time mode but is not traced with JAX. Be "
                "advised that the step code will be evaluated once per output port "
                "evaluation. Any mutation of the local environment should be strictly "
                "avoided as it will likely lead to unexpected behavior.",
                self.name_path_str,
            )

    def initialize(self, **kwargs):
        # Intentionally a no-op for this block.
        pass

    @property
    def has_feedthrough_side_effects(self) -> bool:
        # See explanation in `SystemBase.has_ode_side_effects`.
        return self.time_mode == "agnostic"

    @staticmethod
    def set_exec_fn(exec_fn: callable):
        """Replace the class-wide executor used to run user code."""
        CustomPythonBlock.__exec_fn = exec_fn

    @property
    def local_env_base(self):
        """Return a fresh base environment for executing user code."""
        # Define a starting point for the local code execution environment.
        return {
            "__main__": {},
            "true": True,
            "false": False,
        }

    def exec_init(self) -> None:
        """Run the init script and keep its environment for later step calls.

        The environment starts from `local_env_base`, then static parameters,
        then the current values of dynamic parameters (later entries win).

        Raises:
            PythonScriptError: if the init script is interrupted or fails.
        """
        default_parameters = {
            name: param.get() for name, param in self.dynamic_parameters.items()
        }

        local_env = {
            **self.local_env_base,
            **self._parameters,
            **default_parameters,
        }

        exec_fn = functools.partial(
            CustomPythonBlock.__exec_fn,
            code=self.init_code,
            env=local_env,
            logger_=logger,
            system=self,
            code_name="init",
        )

        try:
            io_callback(exec_fn, None)
        except KeyboardInterrupt as e:
            logger.error(
                "Python block '%s' init script execution was interrupted.",
                self.name,
                **logdata(block=self),
            )
            raise PythonScriptError(
                message="Python block init script execution was interrupted.",
                system=self,
            ) from e
        except PythonScriptError:
            logger.error("%s: init script failed.", self.name, **logdata(block=self))
            raise
        except BaseException as e:
            logger.error("%s: init script failed.", self.name, **logdata(block=self))
            raise PythonScriptError(system=self) from e
        # The executor mutates `local_env` in place; keep it for step calls.
        self._persistent_env = local_env

        return None

    def exec_step(self, time, state, *inputs, **parameters):
        """Run the step script through `io_callback` and return the outputs.

        Returns:
            A `CacheType` with one entry per output name and
            `persistent_env=None` (the environment lives on the instance).

        Raises:
            PythonScriptError: if called before `initialize_static_data`, or if
                the step script fails.
        """
        if not self._static_data_initialized:
            # return_dtypes is inferred in initialize_static_data()
            raise PythonScriptError(
                "Trying to execute step code before static data has been initialized",
                system=self,
            )
        logger.debug(
            "Executing step for %s with state=%s, inputs=%s",
            self.name,
            state,
            inputs,
        )

        # Inputs are in order of port declaration, so they match `self.input_names`
        input_env = dict(zip(self.input_names, inputs))

        base_copy = self.local_env_base.copy()
        local_env = {
            **base_copy,
            **self._persistent_env,
            **parameters,
        }

        exec_fn = functools.partial(
            CustomPythonBlock.__exec_fn,
            code=self.step_code,
            env=local_env,
            logger_=logger,
            return_vars=self.output_names,
            return_dtypes=self.return_dtypes,
            system=self,
            code_name="step",
        )

        def wrapped_exec_fn(inputs):
            try:
                return exec_fn(inputs=inputs)
            except KeyboardInterrupt:
                logger.error(
                    "Python block '%s' step script execution was interrupted.",
                    self.name,
                    **logdata(block=self),
                )
                raise
            except NameError as e:
                err_msg = (
                    f"Python block '{self.name}' step script execution failed with a NameError on"
                    + f" missing variable '{e.name}'."
                    + " All names used in this script should be declared in the init script."
                    + f" The execution environment contains the following names: {', '.join(list(local_env.keys()))}"
                )
                logger.error(err_msg)
                logger.error("NameError: %s", e, **logdata(block=self))
                raise PythonScriptError(system=self) from e
            except PythonScriptError:
                logger.error("%s: exec_step failed.", self.name, **logdata(block=self))
                raise
            except BaseException as e:
                logger.error("%s: exec_step failed.", self.name, **logdata(block=self))
                raise PythonScriptError(system=self) from e

        return_vars = io_callback(
            wrapped_exec_fn,
            self.result_shape_dtypes,
            inputs=input_env,
        )

        # Keep local env for next step but only if defined in init_script
        # NOTE: If this restriction turns out to be counterproductive, we can
        # remove it and remove the NameError handling above as well. The thinking
        # here is that this could help avoiding stuff like `if time == 0: x = 0`
        # See https://collimator.atlassian.net/browse/WC-98
        self._persistent_env = {
            key: local_env[key] for key in self._persistent_env if key in local_env
        }

        # Output values are returned positionally in `self.output_names` order.
        xd = {name: return_vars[i] for i, name in enumerate(self.output_names)}

        return self.CacheType(persistent_env=None, **xd)

    def _initialize_outputs(self, outputs, _persistent_env):
        """Compute default outputs and record their shapes/dtypes.

        Overrides the base implementation since `persistent_env` will be None
        in this case. Instead, the instance attribute where the environment is
        actually maintained is passed to the base implementation.
        """
        default_values = super()._initialize_outputs(outputs, self._persistent_env)
        default_outputs = default_values._asdict()
        self._persistent_env = default_outputs.pop("persistent_env")

        # Determine return data types
        self._initialize_result_shape_dtypes(
            [default_outputs[output] for output in outputs]
        )

        return self.CacheType(
            persistent_env=None,
            **default_outputs,
        )

    def _initialize_result_shape_dtypes(self, outputs):
        """Record a `jax.ShapeDtypeStruct` and dtype for each output value."""
        self.result_shape_dtypes = []
        self.return_dtypes = []
        for value in outputs:
            self.result_shape_dtypes.append(
                jax.ShapeDtypeStruct(value.shape, value.dtype)
            )
            self.return_dtypes.append(value.dtype)

    def initialize_static_data(self, context):
        """Infer output shapes/dtypes (if still unknown) and mark the block ready.

        If the data types are not yet known, the step code is executed once
        directly (not through `io_callback`) with no dtype conversion.
        """
        # Deliberately call LeafSystem's implementation directly, bypassing
        # the parent class's override.
        context = LeafSystem.initialize_static_data(self, context)

        if self.result_shape_dtypes is not None:
            # These data types are already known (block is in discrete mode)
            self._static_data_initialized = True
            return context

        inputs = self.collect_inputs(context)
        input_env = dict(zip(self.input_names, inputs))

        base_copy = self.local_env_base.copy()
        local_env = {
            **base_copy,
            **self._persistent_env,
        }

        # Will not do any type conversion
        return_dtypes = [None for _ in self.output_names]

        exec_fn = functools.partial(
            CustomPythonBlock.__exec_fn,
            self.step_code,
            local_env,
            logger_=logger,
            return_vars=self.output_names,
            return_dtypes=return_dtypes,
            system=self,
            code_name="step",
        )

        return_vars = exec_fn(inputs=input_env)

        self._initialize_result_shape_dtypes(return_vars)

        self._static_data_initialized = True

        return context

    def check_types(
        self,
        context: ContextBase,
        error_collector=None,
    ):
        # Type checking is skipped for arbitrary Python code.
        pass

DataSource

Bases: SourceBlock

Produces outputs from an imported .csv file.

The block's output(s) must be synchronized with simulation time. This can be achieved by two mechanisms:

  1. Each data row in the file is accompanied by a time value. The time value for each row is provided as a column in the data file. For this option, the values in the time column must be strictly increasing, with no duplicates, from the first data row to the last. The block will check that this condition is satisfied at compile time. The column with the time values is identified by the column index. This option assumes the left most column is index 0, counting up to the right. To select this option, set time_samples_as_column to True, and provide the index of the column in time_column.

  2. The time value for each data row is defined using a fixed time step between each row. For this option, the sampling_interval parameter defines the time step. The block then computes the time values for each data row starting with zero for the first row. Note that by definition, this results in a strictly increasing set. To select this option, set time_samples_as_column to False, and provide the sampling_interval value.

When block output(s) are requested at a simulation time that falls between time values for adjacent data rows, there are two options for how the block should compute the interpolation:

  1. Zero Order Hold: the block returns data from the row with the lower time value.

  2. Linear: the block performs a linear interpolation between the lower and higher time value data rows.

There are several mechanisms for selecting which data columns are included in the block output(s). All options are applied using the data_columns parameter:

  1. Column name: enter a string that matches a column name in the header. For this option, header_as_first_row must be set to True. For this option, it is only possible to select a single column for the output. The block will output a scalar.

  2. Column index: enter an integer index for the desired column. This option again assumes the left most column is index 0, counting up to the right. This option assumes the same column index regardless of whether time_samples_as_column is True or False, therefore it is possible to select the same column for time and output. With this option, the block will output a scalar.

  3. Slice: enter a slice used to identify a set of sequential columns to be used as the desired data for output. The slice works like a NumPy slice. For example, if the file has 10 columns, 3:8 will result in the block returning a vector of length 5, containing, in order, columns 3,4,5,6,7. Note that like NumPy, the second integer in the slice is excluded in the set of indices. Both bounds must be given explicitly as non-negative integers (e.g. 2:-1, -3:-1, and 3: are not allowed).

The block has reached the end of data when the simulation time is greater than the time value of the last row of data. Beyond that point the output(s) depend on the extrapolation parameter: with "hold" (the default) they hold the values of the last row of data; with "zero" they are zero.

Parameters:

Name Type Description Default
file_name str

The name of the imported file which contains the data.

required
header_as_first_row bool

Check this box if the first row is meant to be a header.

False
time_samples_as_column bool

Check this box to select a column from the file to use as the time values. Uncheck it to provide time as a fixed time step between rows.

False
time_column str

Only used when time_samples_as_column is True. This is the index of the column to be used as time.

'0'
sampling_interval float

Only used when time_samples_as_column is False. Provide the fixed time step value here.

1.0
data_columns str

Enter name, index, or slice to select columns from the data file.

'1'
extrapolation str

the extrapolation method. One of "hold" or "zero".

'hold'
interpolation str

the interpolation method. One of "zero_order_hold" or "linear".

'zero_order_hold'
Source code in collimator/library/data_source.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
class DataSource(SourceBlock):
    """Produces outputs from an imported .csv file.

    The block's output(s) must be synchronized with simulation time. This can be
    achieved by two mechanisms:

    1. Each data row in the file is accompanied by a time value. The time value
        for each row is provided as a column in the data file. For this option,
        the values in the time column must be strictly increasing, with no duplicates,
        from the first data row to the last. The block will check that this condition
        is satisfied at compile time. The column with the time values is identified by
        the column index. This option assumes the left most column is index 0, counting
        up to the right. To select this option, set `time_samples_as_column` to True,
        and provide the index of the column in `time_column`.

    2. The time value for each data row is defined using a fixed time step between
        each row. For this option, the `sampling_interval` parameter defines the time
        step. The block then computes the time values for each data row starting with
        zero for the first row. Note that by definition, this results in a strictly
        increasing set. To select this option, set `time_samples_as_column` to False,
        and provide the `sampling_interval` value.

    When block output(s) are requested at a simulation time that falls between time
    values for adjacent data rows, there are two options for how the block should
    compute the interpolation:

    1. Zero Order Hold: the block returns data from the row with the lower time value.

    2. Linear: the block performs a linear interpolation between the lower and higher
        time value data rows.

    There are several mechanisms for selecting which data columns are included in the
    block output(s). All options are applied using the `data_columns` parameter:

    1. Column name: enter a string that matches a column name in the header. For
        this option, `header_as_first_row` must be set to True. For this option, it
        is only possible to select a single column for the output. The block will
        output a scalar.

    2. Column index: enter an integer index for the desired column. This option
        again assumes the left most column is index 0, counting up to the right. This
        option assumes the same column index regardless of whether
        `time_samples_as_column` is True or False, therefore it is possible to select
        the same column for time and output. With this option, the block will output
        a scalar.

    3. Slice: enter a slice used to identify a set of sequential columns to be used
        as the desired data for output. The slice works like a NumPy slice. For
        example, if the file has 10 columns, `3:8` will result in the block returning
        a vector of length 5, containing, in order, columns 3,4,5,6,7. Note that
        like NumPy, the second integer in the slice is excluded in the set of
        indices. Both bounds must be given explicitly as non-negative integers
        (e.g. `2:-1`, `-3:-1`, and `3:` are not allowed).

    The block has reached the end of data when the simulation time is greater than
    the time value of the last row of data. Beyond that point the output(s) depend
    on `extrapolation`: with "hold" (the default) they hold the values of the last
    row of data; with "zero" they are zero.

    Parameters:
        file_name:
            The name of the imported file which contains the data.
        header_as_first_row:
            Check this box if the first row is meant to be a header.
        time_samples_as_column:
            Check this box to select a column from the file to use as the time values.
            Uncheck it to provide time as a fixed time step between rows.
        time_column:
            Only used when `time_samples_as_column` is True. This is the index of
            the column to be used as time.
        sampling_interval: Only used when `time_samples_as_column` is False. Provide
            the fixed time step value here.
        data_columns:
            Enter name, index, or slice to select columns from the data file.
        extrapolation: the extrapolation method.  One of "hold" or "zero".
        interpolation: the interpolation method.  One of "zero_order_hold" or "linear".

    Raises:
        ValueError: if the requested data columns yield no data.
    """

    @parameters(
        static=[
            "file_name",
            "data_columns",
            "extrapolation",
            "header_as_first_row",
            "interpolation",
            "sampling_interval",
            "time_column",
            "time_samples_as_column",
        ]
    )
    def __init__(
        self,
        file_name: str,
        data_columns: str = "1",  # slice, e.g. 3:4
        extrapolation: str = "hold",
        header_as_first_row: bool = False,
        interpolation: str = "zero_order_hold",
        sampling_interval: float = 1.0,
        time_column: str = "0",  # @am. could be an int
        time_samples_as_column: bool = False,
        **kwargs,
    ):
        # FIXME: move to block_interface.py
        kwargs.pop("data_integration_id", None)

        super().__init__(self._callback, **kwargs)

        times, data = load_csv(
            str(file_name),
            str(data_columns),
            bool(header_as_first_row),
            float(sampling_interval),
            str(time_column),
            bool(time_samples_as_column),
        )

        times = cnp.array(times)
        data = cnp.array(data)

        if data.size == 0:
            raise ValueError(
                f"DataSource {self.name_path_str} could not get the requested data columns."
            )

        max_i_zoh = len(times) - 1
        max_i_interp = len(times) - 2
        output_dim = data.shape[1]
        self._scalar_output = output_dim == 1

        def get_below_row_idx(time, max_i):
            """Locate the row whose time bounds `time` from below.

            `time` is first clipped into [times[0], times[-1]] so that there is
            no extrapolation of the data itself. Then the index of the largest
            value in `times[: max_i + 1]` that is less than or equal to the
            clipped time is returned, together with the clipped time.
            """
            time_clipped = cnp.clip(time, times[0], times[-1])
            index = cnp.searchsorted(times[: max_i + 1], time_clipped, side="right")
            return index - 1, time_clipped

        def _func_zoh(time):
            i, _ = get_below_row_idx(time, max_i_zoh)
            if extrapolation != "zero":
                return data[i, :]
            return cnp.where(time > times[-1], cnp.zeros(output_dim), data[i, :])

        def _func_interp(time):
            """Linearly interpolate between the two rows bounding `time`.

            Computes y = (yp2-yp1)/(xp2-xp1)*(x-xp1) + yp1, where the rows
            ap1 and ap2 provide the yp1's and yp2's and the xp1's and xp2's
            are the corresponding time values.
            """
            i, time_clipped = get_below_row_idx(time, max_i_interp)
            ap1 = data[i, :]
            ap2 = data[i + 1, :]

            if extrapolation != "zero":
                return (ap2 - ap1) / (times[i + 1] - times[i]) * (
                    time_clipped - times[i]
                ) + ap1

            return cnp.where(
                time > times[-1],
                cnp.zeros(output_dim),
                (ap2 - ap1) / (times[i + 1] - times[i]) * (time_clipped - times[i])
                + ap1,
            )

        # wrap output function to return scalar when only one column selected.
        def _wrap_func(_func):
            def _ds_wrapped_func(time):
                output = _func(time)
                return output[0]

            return _ds_wrapped_func

        if interpolation == "zero_order_hold":
            _func = _func_zoh
        else:
            _func = _func_interp

        if self._scalar_output:
            _func = _wrap_func(_func)

        # Call JIT to massively improve the performance, especially when
        # calling create_context/check_types... including when backend is numpy.
        self._func = cnp.jit(_func)

    def _callback(self, time):
        """Evaluate the (JIT-compiled) data lookup at simulation `time`."""
        return self._func(time)

__init__(file_name, data_columns='1', extrapolation='hold', header_as_first_row=False, interpolation='zero_order_hold', sampling_interval=1.0, time_column='0', time_samples_as_column=False, **kwargs)

Source code in collimator/library/data_source.py
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
@parameters(
    static=[
        "file_name",
        "data_columns",
        "extrapolation",
        "header_as_first_row",
        "interpolation",
        "sampling_interval",
        "time_column",
        "time_samples_as_column",
    ]
)
def __init__(
    self,
    file_name: str,
    data_columns: str = "1",  # slice, e.g. 3:4
    extrapolation: str = "hold",
    header_as_first_row: bool = False,
    interpolation: str = "zero_order_hold",
    sampling_interval: float = 1.0,
    time_column: str = "0",  # @am. could be an int
    time_samples_as_column: bool = False,
    **kwargs,
):
    """Load the CSV data and build the JIT-compiled time -> output lookup.

    Raises:
        ValueError: if the requested data columns yield no data.
    """
    # FIXME: move to block_interface.py
    kwargs.pop("data_integration_id", None)

    super().__init__(self._callback, **kwargs)

    times, data = load_csv(
        str(file_name),
        str(data_columns),
        bool(header_as_first_row),
        float(sampling_interval),
        str(time_column),
        bool(time_samples_as_column),
    )

    times = cnp.array(times)
    data = cnp.array(data)

    if data.size == 0:
        raise ValueError(
            f"DataSource {self.name_path_str} could not get the requested data columns."
        )

    max_i_zoh = len(times) - 1
    max_i_interp = len(times) - 2
    output_dim = data.shape[1]
    self._scalar_output = output_dim == 1

    def get_below_row_idx(time, max_i):
        """Locate the row whose time bounds `time` from below.

        `time` is first clipped into [times[0], times[-1]] so that there is
        no extrapolation of the data itself. Then the index of the largest
        value in `times[: max_i + 1]` that is less than or equal to the
        clipped time is returned, together with the clipped time.
        """
        time_clipped = cnp.clip(time, times[0], times[-1])
        index = cnp.searchsorted(times[: max_i + 1], time_clipped, side="right")
        return index - 1, time_clipped

    def _func_zoh(time):
        i, _ = get_below_row_idx(time, max_i_zoh)
        if extrapolation != "zero":
            return data[i, :]
        return cnp.where(time > times[-1], cnp.zeros(output_dim), data[i, :])

    def _func_interp(time):
        """Linearly interpolate between the two rows bounding `time`.

        Computes y = (yp2-yp1)/(xp2-xp1)*(x-xp1) + yp1, where the rows
        ap1 and ap2 provide the yp1's and yp2's and the xp1's and xp2's
        are the corresponding time values.
        """
        i, time_clipped = get_below_row_idx(time, max_i_interp)
        ap1 = data[i, :]
        ap2 = data[i + 1, :]

        if extrapolation != "zero":
            return (ap2 - ap1) / (times[i + 1] - times[i]) * (
                time_clipped - times[i]
            ) + ap1

        return cnp.where(
            time > times[-1],
            cnp.zeros(output_dim),
            (ap2 - ap1) / (times[i + 1] - times[i]) * (time_clipped - times[i])
            + ap1,
        )

    # wrap output function to return scalar when only one column selected.
    def _wrap_func(_func):
        def _ds_wrapped_func(time):
            output = _func(time)
            return output[0]

        return _ds_wrapped_func

    if interpolation == "zero_order_hold":
        _func = _func_zoh
    else:
        _func = _func_interp

    if self._scalar_output:
        _func = _wrap_func(_func)

    # Call JIT to massively improve the performance, especially when
    # calling create_context/check_types... including when backend is numpy.
    self._func = cnp.jit(_func)

DeadZone

Bases: FeedthroughBlock

Generates zero output within a specified range.

Applies the following function:

         [ input,       input <= -half_range
output = | 0,           -half_range < input < half_range
         [ input        input >= half_range

Parameters:

Name Type Description Default
half_range

The range of the dead zone. Must be > 0.

1.0
Input ports

(0) The input signal.

Output ports

(0) The input signal modified by the dead zone.

Events

An event is triggered when the signal enters or exits the dead zone in either direction.

Source code in collimator/library/primitives.py
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
class DeadZone(FeedthroughBlock):
    """Generates zero output within a specified range.

    Applies the following function:
    ```
             [ input,       input <= -half_range
    output = | 0,           -half_range < input < half_range
             [ input        input >= half_range
    ```

    Parameters:
        half_range: The range of the dead zone.  Must be > 0.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The input signal modified by the dead zone.

    Events:
        An event is triggered when the signal enters or exits the dead zone
        in either direction. These events are only declared when
        `is_discontinuity` reports the output port as such.
    """

    @parameters(dynamic=["half_range"])
    def __init__(self, half_range=1.0, **kwargs):
        super().__init__(self._dead_zone, **kwargs)
        if half_range <= 0:
            raise BlockParameterError(
                message=f"DeadZone block {self.name} has invalid half_range {half_range}. Must be > 0.",
                system=self,
                parameter_name="half_range",
            )

    def initialize(self, half_range):
        pass

    def _dead_zone(self, x, **params):
        # Strict inequality: values exactly at +/-half_range pass through.
        return cnp.where(abs(x) < params["half_range"], x * 0, x)

    def _lower_limit_event_value(self, _time, _state, *inputs, **params):
        # Crosses zero when the input crosses -half_range.
        (u,) = inputs
        return u + params["half_range"]

    def _upper_limit_event_value(self, _time, _state, *inputs, **params):
        # Crosses zero when the input crosses +half_range.
        (u,) = inputs
        return u - params["half_range"]

    def initialize_static_data(self, context):
        # Add zero-crossing events so ODE solvers can't try to integrate
        # through a discontinuity. As in `Abs`, only do this when the output
        # port is flagged by `is_discontinuity` (a bare port object is always
        # truthy, so checking the port itself would always declare events).
        if not self.has_zero_crossing_events and is_discontinuity(self.output_ports[0]):
            self.declare_zero_crossing(
                self._lower_limit_event_value, direction="crosses_zero"
            )
            self.declare_zero_crossing(
                self._upper_limit_event_value, direction="crosses_zero"
            )

        return super().initialize_static_data(context)

Demultiplexer

Bases: LeafSystem

Split a vector signal into its components.

Input ports

(0) The vector signal to split.

Output ports

(0..n_out-1) The components of the input signal.

Source code in collimator/library/primitives.py
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
class Demultiplexer(LeafSystem):
    """Split a vector signal into its components.

    Input ports:
        (0) The vector signal to split.

    Output ports:
        (0..n_out-1) The components of the input signal.

    Parameters:
        n_out:
            Number of output ports; output `i` carries `input[i]`.
    """

    def __init__(self, n_out, **kwargs):
        super().__init__(**kwargs)

        self.declare_input_port()

        # A factory function gives each output callback its own binding of `i`.
        # A closure defined directly in the loop body would late-bind `i`, and
        # every port would read the final index.
        def _declare_output(i):
            def _compute_output(_time, _state, *inputs, **_params):
                (input_vec,) = inputs
                return input_vec[i]

            self.declare_output_port(
                _compute_output,
                prerequisites_of_calc=[self.input_ports[0].ticket],
            )

        # Plain Python ints keep the component indices static and avoid
        # allocating an index array.
        for i in range(n_out):
            _declare_output(i)

Derivative

Bases: LTISystem

Causal estimate of the derivative of a signal in continuous time.

This is implemented as a state-space system with matrices (A, B, C, D), which are then used to create a (first-order) LTISystem. Note that this only supports single-input, single-output derivative blocks.

The derivative is implemented as a filter with a filter coefficient of N, which is used to construct the following proper transfer function:

    H(s) = Ns / (s + N)

As N -> ∞, the transfer function approaches a pure differentiator. However, this system becomes increasingly stiff and difficult to integrate, so it is recommended to select a value of N based on the time scales of the system.

From the transfer function, scipy.signal.tf2ss is used to convert to state-space form and create an LTISystem.

Input ports

(0) u: Input (scalar)

Output ports

(0) y: Output (scalar), estimating the time derivative du/dt

Source code in collimator/library/linear_system.py
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
class Derivative(LTISystem):
    """Causal estimate of the derivative of a signal in continuous time.

    The estimate is a first-order filter realized as a state-space system
    (A, B, C, D) and handed to `LTISystem`. Only single-input, single-output
    use is supported.

    With filter coefficient `N`, the realized proper transfer function is:
    ```
        H(s) = Ns / (s + N)
    ```
    Increasing N brings the response closer to a pure differentiator, but also
    makes the system stiffer and harder to integrate. Choose N according to the
    time scales of the system being modeled.

    The state-space matrices are obtained from the transfer function with
    `scipy.signal.tf2ss`.

    Input ports:
        (0) u: Input (scalar)

    Output ports:
        (0) y: Output (scalar), estimating the time derivative du/dt
    """

    # jax.scipy.signal has no tf2ss, so the filter coefficient has to be a
    # static (non-traced) parameter.
    @parameters(static=["filter_coefficient"])
    def __init__(self, filter_coefficient=100, *args, **kwargs):
        A, B, C, D = self._filter_state_space(filter_coefficient)
        super().__init__(A, B, C, D, *args, **kwargs)

    @staticmethod
    def _filter_state_space(N):
        """Return the (A, B, C, D) realization of H(s) = N*s / (s + N)."""
        numerator = [N, 0]
        denominator = [1, N]
        return signal.tf2ss(numerator, denominator)

    def _eval_output(self, time, state, *inputs, **params):
        return self._eval_output_base(self.C, self.D, state, *inputs)

    def ode(self, time, state, u, **params):
        return super().ode(time, state, u, A=self.A, B=self.B)

    def initialize(self, filter_coefficient, **kwargs):
        A, B, C, D = self._filter_state_space(filter_coefficient)
        self._init_state(A, B, C, D)

    def check_types(
        self,
        context,
        error_collector: ErrorCollector = None,
    ):
        (u,) = self.collect_inputs(context)

        # Only scalar inputs are supported; anything else is reported.
        if cnp.ndim(u) == 0:
            return
        with ErrorCollector.context(error_collector):
            raise StaticError(
                message="Derivative must have scalar input.",
                system=self,
            )

DerivativeDiscrete

Bases: LeafSystem

Discrete approximation to the derivative of the input signal w.r.t. time.

By default the block uses a simple backward difference approximation:

y[k] = (u[k] - u[k-1]) / dt

However, the block can also be configured to use a recursive filter for a better approximation. In this case the filter coefficients are determined by the filter_type and filter_coefficient parameters. The filter is a pair of two-element arrays a and b and the filter equation is:

a0*y[k] + a1*y[k-1] = b0*u[k] + b1*u[k-1]

Denoting the filter_coefficient parameter by N, the following filters are available:

- "none": The default, a simple finite difference approximation.
- "forward": A filtered forward Euler discretization. The filter is: a = [1, (N*dt - 1)] and b = [N, -N].
- "backward": A filtered backward Euler discretization. The filter is: a = [(1 + N*dt), -1] and b = [N, -N].
- "bilinear": A filtered bilinear transform discretization. The filter is: a = [(2 + N*dt), (-2 + N*dt)] and b = [2*N, -2*N].

Input ports

(0) The input signal.

Output ports

(0) The approximate derivative of the input signal.

Parameters:

Name Type Description Default
dt

The time step of the discrete approximation.

required
filter_type

One of "none", "forward", "backward", or "bilinear". This determines the type of filter used to approximate the derivative. The default is "none", corresponding to a simple backward difference approximation.

'none'
filter_coefficient

The coefficient in the filter (N in the equations above). This is only used if filter_type is not "none". The default is 1.0.

1.0
Source code in collimator/library/primitives.py
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
class DerivativeDiscrete(LeafSystem):
    """Discrete approximation to the derivative of the input signal w.r.t. time.

    With the default settings a plain backward difference is used:
    ```
    y[k] = (u[k] - u[k-1]) / dt
    ```
    Alternatively, a first-order recursive filter can be selected through the
    `filter_type` and `filter_coefficient` parameters for a better
    approximation. The filter is defined by two-element coefficient arrays `a`
    and `b`:
    ```
    a0*y[k] + a1*y[k-1] = b0*u[k] + b1*u[k-1]
    ```

    Writing `N` for the `filter_coefficient` parameter, the available filters are:
    - "none": The default, a simple finite difference approximation.
    - "forward": A filtered forward Euler discretization, with
        `a = [1, (N*dt - 1)]` and `b = [N, -N]`.
    - "backward": A filtered backward Euler discretization, with
        `a = [(1 + N*dt), -1]` and `b = [N, -N]`.
    - "bilinear": A filtered bilinear transform discretization, with
        `a = [(2 + N*dt), (-2 + N*dt)]` and `b = [2*N, -2*N]`.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The approximate derivative of the input signal.

    Parameters:
        dt:
            The time step of the discrete approximation.
        filter_type:
            One of "none", "forward", "backward", or "bilinear". Selects the
            filter used to approximate the derivative. The default is "none",
            a simple backward difference approximation.
        filter_coefficient:
            The filter coefficient (`N` in the equations above). Ignored when
            `filter_type` is "none". The default is 1.0.
    """

    @parameters(static=["filter_type", "filter_coefficient"])
    def __init__(self, dt, filter_type="none", filter_coefficient=1.0, **kwargs):
        super().__init__(**kwargs)
        self.dt = dt
        self.declare_input_port()
        self._periodic_update_idx = self.declare_periodic_update()
        # The output callback is attached in `initialize`, once the filter
        # coefficients are known.
        input_ticket = self.input_ports[0].ticket
        self.deriv_output = self.declare_output_port(
            period=dt,
            offset=0.0,
            prerequisites_of_calc=[input_ticket],
        )

    def initialize(self, filter_type="none", filter_coefficient=1.0):
        sample_period = self.dt

        # Coefficients (b, a) of the recursive filter
        #   a0*y[k] + a1*y[k-1] = b0*u[k] + b1*u[k-1]
        self.filter = derivative_filter(
            N=filter_coefficient, dt=sample_period, filter_type=filter_type
        )

        # Discrete state holds the previous input sample u[k-1].
        self.declare_discrete_state(default_value=None, as_array=False)

        self.configure_periodic_update(
            self._periodic_update_idx,
            self._update,
            period=sample_period,
            offset=0.0,
        )

        # No previous sample exists at t=0, so the output keeps its initial
        # (zero) value until t=dt, when a derivative estimate becomes possible.
        self.configure_output_port(
            self.deriv_output,
            self._output,
            period=sample_period,
            offset=sample_period,
            prerequisites_of_calc=[self.input_ports[0].ticket],
        )

    def _output(self, _time, state, *inputs, **_params):
        """Evaluate one step of the recursive derivative filter."""
        (u_now,) = inputs
        num, den = self.filter
        y_last = state.cache[self.deriv_output]
        u_last = state.discrete_state
        weighted_sum = num[0] * u_now + num[1] * u_last - den[1] * y_last
        return weighted_sum / den[0]

    def _update(self, time, state, u, **params):
        # Remember the current input so it becomes u[k-1] at the next step.
        return u

    def initialize_static_data(self, context):
        """Infer the size and dtype of the internal states from the input.

        The current input value seeds the default discrete state, and a zero
        of the same shape seeds the cached output.
        """
        # When built inside a subsystem the input may not be connected yet.
        # That is acceptable as long as it is connected by root context
        # creation time. Not ideal long-term:
        #   see https://collimator.atlassian.net/browse/WC-51
        try:
            sample = self.eval_input(context)
            self._default_discrete_state = sample
            block_ctx = context[self.system_id].with_discrete_state(sample)
            self._default_cache[self.deriv_output] = 0 * sample
            block_ctx = block_ctx.with_cached_value(self.deriv_output, 0 * sample)
            context = context.with_subcontext(self.system_id, block_ctx)
        except UpstreamEvalError:
            logger.debug(
                "DerivativeDiscrete.initialize_static_data: UpstreamEvalError. "
                "Continuing without default value initialization."
            )
        return super().initialize_static_data(context)

initialize_static_data(context)

Infer the size and dtype of the internal states

Source code in collimator/library/primitives.py
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
def initialize_static_data(self, context):
    """Infer the size and dtype of the internal states

    Evaluates the block input once and uses it to seed the default discrete
    state (the stored previous sample). It uses `0 * u` to seed the cached
    derivative output, in both the block defaults and the local context. If
    the upstream input cannot be evaluated yet, initialization is skipped and
    a debug message is logged.
    """
    # If building as part of a subsystem, this may not be fully connected yet.
    # That's fine, as long as it is connected by root context creation time.
    # This probably isn't a good long-term solution:
    #   see https://collimator.atlassian.net/browse/WC-51
    try:
        u = self.eval_input(context)
        # Previous-sample state starts equal to the current input.
        self._default_discrete_state = u
        local_context = context[self.system_id].with_discrete_state(u)
        # Cached derivative output starts at zero with the same shape as `u`.
        self._default_cache[self.deriv_output] = 0 * u
        local_context = local_context.with_cached_value(self.deriv_output, 0 * u)
        context = context.with_subcontext(self.system_id, local_context)

    except UpstreamEvalError:
        logger.debug(
            "DerivativeDiscrete.initialize_static_data: UpstreamEvalError. "
            "Continuing without default value initialization."
        )
    return super().initialize_static_data(context)

DirectShootingNMPC

Bases: NonlinearMPCIpopt

Implementation of nonlinear MPC with a direct shooting transcription and IPOPT as the NLP solver.

Input ports

(0) x_0 : current state vector. (1) x_ref : reference state trajectory for the nonlinear MPC. (2) u_ref : reference input trajectory for the nonlinear MPC.

Output ports

(1) u_opt : the optimal control input to be applied at the current time step as determined by the nonlinear MPC.

Parameters:

Name Type Description Default
plant

LeafSystem or Diagram The plant to be controlled.

required
Q

Array State weighting matrix in the cost function.

required
QN

Array Terminal state weighting matrix in the cost function.

required
R

Array Control input weighting matrix in the cost function.

required
N

int The prediction horizon, an integer specifying the number of steps to predict. Note: prediction and control horizons are identical for now.

required
nh

int Number of minor steps to take within an RK4 major step.

required
dt

float: Major time step, a scalar indicating the increment in time for each step in the prediction and control horizons.

required
lb_u

Array Lower bound on the control input vector.

None
ub_u

Array Upper bound on the control input vector.

None
u_optvars_0

Array Initial guess for the control vector optimization variables in the NLP.

None
Source code in collimator/library/nmpc/direct_shooting_ipopt_nmpc.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
class DirectShootingNMPC(NonlinearMPCIpopt):
    """
    Implementation of nonlinear MPC with a direct shooting transcription and IPOPT as
    the NLP solver.

    Input ports:
        (0) x_0 : current state vector.
        (1) x_ref : reference state trajectory for the nonlinear MPC.
        (2) u_ref : reference input trajectory for the nonlinear MPC.

    Output ports:
        (1) u_opt : the optimal control input to be applied at the current time step
                    as determined by the nonlinear MPC.

    Parameters:
        plant: LeafSystem or Diagram
            The plant to be controlled.

        Q: Array
            State weighting matrix in the cost function.

        QN: Array
            Terminal state weighting matrix in the cost function.

        R: Array
            Control input weighting matrix in the cost function.

        N: int
            The prediction horizon, an integer specifying the number of steps to
            predict. Note: prediction and control horizons are identical for now.

        nh: int
            Number of minor steps to take within an RK4 major step.

        dt: float:
            Major time step, a scalar indicating the increment in time for each step in
            the prediction and control horizons.

        lb_u: Array
            Lower bound on the control input vector. Defaults to -1e20 per
            component (effectively unbounded).

        ub_u: Array
            Upper bound on the control input vector. Defaults to 1e20 per
            component (effectively unbounded).

        u_optvars_0: Array
            Initial guess for the control vector optimization variables in the NLP.
            Currently stored but not used.
    """

    def __init__(
        self,
        plant,
        Q,
        QN,
        R,
        N,
        nh,
        dt,
        lb_u=None,
        ub_u=None,
        u_optvars_0=None,
        name=None,
    ):
        self.plant = plant

        self.Q = Q
        self.QN = QN
        self.R = R

        self.N = N
        self.nh = nh
        self.dt = dt

        self.nx = Q.shape[0]
        self.nu = R.shape[0]

        # Missing bounds default to +/-1e20 per input component.
        self.lb_u = -1e20 * jnp.ones(self.nu) if lb_u is None else lb_u
        self.ub_u = 1e20 * jnp.ones(self.nu) if ub_u is None else ub_u

        # Initial guesses are stored but not yet taken into account.
        self.u_optvars_0 = u_optvars_0

        self.ode_rhs = make_ode_rhs(plant, self.nu)

        nlp_structure_ipopt = NMPCProblemStructure(
            self.num_optvars,
            self._objective,
        )

        super().__init__(
            dt,
            self.nu,
            self.num_optvars,
            nlp_structure_ipopt,
            name=name,
        )

    @property
    def num_optvars(self):
        # Only the control sequence u_0..u_{N-1} is optimized.
        return self.N * self.nu

    @property
    def num_constraints(self):
        return 0

    @property
    def bounds_optvars(self):
        # Per-step bounds repeated N times, matching the row-major (N, nu)
        # layout of the flattened control sequence.
        lb = jnp.tile(self.lb_u, self.N)
        ub = jnp.tile(self.ub_u, self.N)
        return (lb, ub)

    @property
    def bounds_constraints(self):
        c_lb = []
        c_ub = []
        return (c_lb, c_ub)

    @partial(jax.jit, static_argnames=("self",))
    def _objective(self, optvars, t0, x0, x_ref, u_ref):
        """Quadratic tracking cost of the trajectory simulated from `x0`."""
        u = optvars.reshape((self.N, self.nu))

        # Forward-simulate x_0..x_N from x0 under the candidate controls.
        x = jnp.zeros((self.N + 1, x0.size))
        x = x.at[0].set(x0)

        def _update_function(idx, x):
            t_major_start = t0 + self.dt * idx
            x_next = rk4_major_step_constant_u(
                t_major_start,
                x[idx],
                u[idx],
                self.dt,
                self.nh,
                self.ode_rhs,
            )
            return x.at[idx + 1].set(x_next)

        x = jax.lax.fori_loop(0, self.N, _update_function, x)

        xdiff = x - x_ref
        udiff = u - u_ref

        # compute sum of quadratic products for x_0 to x_{N-1}
        A = jnp.dot(xdiff[:-1], self.Q)
        qp_x_sum = jnp.sum(xdiff[:-1] * A, axis=None)

        # Compute quadratic product for the x_N
        xN = xdiff[-1]
        qp_x_N = jnp.dot(xN, jnp.dot(self.QN, xN))

        # compute sum of quadratic products for u_0 to u_{N-1}
        B = jnp.dot(udiff, self.R)
        qp_u_sum = jnp.sum(udiff * B, axis=None)

        return qp_x_sum + qp_x_N + qp_u_sum

DirectTranscriptionNMPC

Bases: NonlinearMPCIpopt

Implementation of nonlinear MPC with direct transcription and IPOPT as the NLP solver.

Input ports

(0) x_0 : current state vector. (1) x_ref : reference state trajectory for the nonlinear MPC. (2) u_ref : reference input trajectory for the nonlinear MPC.

Output ports

(1) u_opt : the optimal control input to be applied at the current time step as determined by the nonlinear MPC.

Parameters:

Name Type Description Default
plant

LeafSystem or Diagram The plant to be controlled.

required
Q

Array State weighting matrix in the cost function.

required
QN

Array Terminal state weighting matrix in the cost function.

required
R

Array Control input weighting matrix in the cost function.

required
N

int The prediction horizon, an integer specifying the number of steps to predict. Note: prediction and control horizons are identical for now.

required
nh

int Number of minor steps to take within an RK4 major step.

required
dt

float: Major time step, a scalar indicating the increment in time for each step in the prediction and control horizons.

required
lb_x

Array Lower bound on the state vector.

None
ub_x

Array Upper bound on the state vector.

None
lb_u

Array Lower bound on the control input vector.

None
ub_u

Array Upper bound on the control input vector.

None
x_optvars_0

Array Initial guess for the state vector optimization variables in the NLP.

None
u_optvars_0

Array Initial guess for the control vector optimization variables in the NLP.

None
Source code in collimator/library/nmpc/direct_transcription_ipopt_nmpc.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
class DirectTranscriptionNMPC(NonlinearMPCIpopt):
    """
    Implementation of nonlinear MPC with direct transcription and IPOPT as the NLP
    solver.

    Input ports:
        (0) x_0 : current state vector.
        (1) x_ref : reference state trajectory for the nonlinear MPC.
        (2) u_ref : reference input trajectory for the nonlinear MPC.

    Output ports:
        (1) u_opt : the optimal control input to be applied at the current time step
                    as determined by the nonlinear MPC.

    Parameters:
        plant: LeafSystem or Diagram
            The plant to be controlled.

        Q: Array
            State weighting matrix in the cost function.

        QN: Array
            Terminal state weighting matrix in the cost function.

        R: Array
            Control input weighting matrix in the cost function.

        N: int
            The prediction horizon, an integer specifying the number of steps to
            predict. Note: prediction and control horizons are identical for now.

        nh: int
            Number of minor steps to take within an RK4 major step.

        dt: float:
            Major time step, a scalar indicating the increment in time for each step in
            the prediction and control horizons.

        lb_x: Array
            Lower bound on the state vector. Defaults to -1e20 per component.

        ub_x: Array
            Upper bound on the state vector. Defaults to 1e20 per component.

        lb_u: Array
            Lower bound on the control input vector. Defaults to -1e20 per component.

        ub_u: Array
            Upper bound on the control input vector. Defaults to 1e20 per component.

        x_optvars_0: Array
            Initial guess for the state vector optimization variables in the NLP.
            Currently stored but not used.

        u_optvars_0: Array
            Initial guess for the control vector optimization variables in the NLP.
            Currently stored but not used.
    """

    def __init__(
        self,
        plant,
        Q,
        QN,
        R,
        N,
        nh,
        dt,
        lb_x=None,
        ub_x=None,
        lb_u=None,
        ub_u=None,
        x_optvars_0=None,
        u_optvars_0=None,
        name=None,
    ):
        self.plant = plant

        self.Q = Q
        self.QN = QN
        self.R = R

        self.N = N
        self.nh = nh
        self.dt = dt

        self.nx = Q.shape[0]
        self.nu = R.shape[0]

        # Missing bounds default to +/-1e20 per component.
        self.lb_x = -1e20 * jnp.ones(self.nx) if lb_x is None else lb_x
        self.ub_x = 1e20 * jnp.ones(self.nx) if ub_x is None else ub_x
        self.lb_u = -1e20 * jnp.ones(self.nu) if lb_u is None else lb_u
        self.ub_u = 1e20 * jnp.ones(self.nu) if ub_u is None else ub_u

        # Initial guesses are stored but not yet taken into account.
        self.x_optvars_0 = x_optvars_0
        self.u_optvars_0 = u_optvars_0

        self.ode_rhs = make_ode_rhs(plant, self.nu)

        nlp_structure_ipopt = NMPCProblemStructure(
            self.num_optvars,
            self._objective,
            self._constraints,
        )

        super().__init__(
            dt,
            self.nu,
            self.num_optvars,
            nlp_structure_ipopt,
            name=name,
        )

    @property
    def num_optvars(self):
        # Decision vector layout: [u_0..u_{N-1}, x_0..x_N], each row-major.
        return (self.N + 1) * self.nx + self.N * self.nu

    @property
    def num_constraints(self):
        # One initial-state constraint plus N dynamics-defect constraints.
        return (self.N + 1) * self.nx

    @property
    def bounds_optvars(self):
        lb = jnp.hstack([jnp.tile(self.lb_u, self.N), jnp.tile(self.lb_x, self.N + 1)])
        ub = jnp.hstack([jnp.tile(self.ub_u, self.N), jnp.tile(self.ub_x, self.N + 1)])
        return (lb, ub)

    @property
    def bounds_constraints(self):
        # All constraints are equalities (defects == 0).
        c_lb = jnp.zeros(self.num_constraints)
        c_ub = jnp.zeros(self.num_constraints)
        return (c_lb, c_ub)

    @partial(jax.jit, static_argnames=("self",))
    def _objective(self, optvars, t0, x0, x_ref, u_ref):
        """Quadratic tracking cost of the state and control decision variables."""
        u_and_x_flat = optvars

        u = u_and_x_flat[: self.nu * self.N].reshape((self.N, self.nu))
        x = u_and_x_flat[self.nu * self.N :].reshape((self.N + 1, self.nx))

        xdiff = x - x_ref
        udiff = u - u_ref

        # compute sum of quadratic products for x_0 to x_{N-1}
        A = jnp.dot(xdiff[:-1], self.Q)
        qp_x_sum = jnp.sum(xdiff[:-1] * A, axis=None)

        # Compute quadratic product for the x_N
        xN = xdiff[-1]
        qp_x_N = jnp.dot(xN, jnp.dot(self.QN, xN))

        # compute sum of quadratic products for u_0 to u_{N-1}
        B = jnp.dot(udiff, self.R)
        qp_u_sum = jnp.sum(udiff * B, axis=None)

        return qp_x_sum + qp_x_N + qp_u_sum

    @partial(jax.jit, static_argnames=("self",))
    def _constraints(self, optvars, t0, x0, x_ref, u_ref):
        """Initial-state and dynamics-defect constraints (all must equal zero)."""
        u_and_x_flat = optvars
        u = u_and_x_flat[: self.nu * self.N].reshape((self.N, self.nu))
        x = u_and_x_flat[self.nu * self.N :].reshape((self.N + 1, self.nx))

        x_sim = jnp.zeros((self.N, x0.size))

        # x_sim[i] is the state reached by integrating from decision x[i]
        # under control u[i] for one major step.
        def _update_function(idx, x_sim_l):
            t_major_start = t0 + self.dt * idx
            x_next = rk4_major_step_constant_u(
                t_major_start,
                x[idx],
                u[idx],
                self.dt,
                self.nh,
                self.ode_rhs,
            )
            return x_sim_l.at[idx].set(x_next)

        x_sim = jax.lax.fori_loop(0, self.N, _update_function, x_sim)  # x1, x2, ..., xN

        c0 = x0 - x[0]
        c_others = x[1:] - x_sim

        return jnp.hstack([c0.ravel(), c_others.ravel()])

DiscreteClock

Bases: LeafSystem

Source block that produces the time sampled at a fixed rate.

The block maintains the most recently sampled time as a discrete state, provided to the output port during the following interval. Graphically, a discrete clock sampled at 100 Hz would have the following time series:

  x(t)                  ●━
    |                   ┆
.03 |              ●━━━━○
    |              ┆
.02 |         ●━━━━○
    |         ┆
.01 |    ●━━━━○
    |    ┆
  0 ●━━━━○----+----+----+-- t
    0   .01  .02  .03  .04

The recorded states are the closed circles, which should be interpreted at index n as the value seen by all other blocks on the interval (t[n], t[n+1]).

Input ports

None

Output ports

(0) The sampled time.

Parameters:

Name Type Description Default
dt

The sampling period of the clock.

required
start_time

The simulation time at which the clock starts. Defaults to 0.

0
Source code in collimator/library/primitives.py
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
class DiscreteClock(LeafSystem):
    """Source block that produces the time sampled at a fixed rate.

    The block maintains the most recently sampled time as a discrete state, provided
    to the output port during the following interval. Graphically, a discrete clock
    sampled at 100 Hz would have the following time series:

    ```
      x(t)                  ●━
        |                   ┆
    .03 |              ●━━━━○
        |              ┆
    .02 |         ●━━━━○
        |         ┆
    .01 |    ●━━━━○
        |    ┆
      0 ●━━━━○----+----+----+-- t
        0   .01  .02  .03  .04
    ```

    The recorded states are the closed circles, which should be interpreted at index
    `n` as the value seen by all other blocks on the interval `(t[n], t[n+1])`.

    Input ports:
        None

    Output ports:
        (0) The sampled time.

    Parameters:
        dt:
            The sampling period of the clock.
        dtype:
            Data type of the output. Defaults to `float`.
        start_time:
            The simulation time at which the clock starts. Defaults to 0.
    """

    def __init__(self, dt, dtype=None, start_time=0, **kwargs):
        super().__init__(**kwargs)
        self.dtype = dtype or float
        start_time = cnp.array(start_time, dtype=self.dtype)

        # Sampling begins at `start_time` and repeats every `dt`. Before the
        # first sample the output holds its default value, which is also
        # `start_time`.
        self.declare_output_port(
            self._output,
            period=dt,
            offset=start_time,
            requires_inputs=False,
            default_value=start_time,
            prerequisites_of_calc=[DependencyTicket.time],
        )

    def _output(self, time, _state, *_inputs, **_params):
        return cnp.array(time, dtype=self.dtype)

DiscreteInitializer

Bases: LeafSystem

Discrete Initializer.

Outputs True for the first discrete step, then False thereafter; or, with initial_state=False, outputs False for the first discrete step, then True thereafter. Useful when a signal must be fed by some initialization at first, and from elsewhere in the model afterwards.

Input ports

None

Output ports

(0) The boolean initializer flag.

Source code in collimator/library/primitives.py
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
class DiscreteInitializer(LeafSystem):
    """Discrete Initializer.

    Outputs True for the first discrete step and False thereafter, or, when
    `initial_state` is False, outputs False for the first discrete step and True
    thereafter. Useful when a signal must come from some initialization at the
    start of the simulation and from elsewhere in the model afterwards.

    Input ports:
        None

    Output ports:
        (0) Boolean signal equal to `initial_state` until the single discrete
            update at `t = dt`, and to its logical negation from then on.

    Parameters:
        dt:
            Time of the single discrete update that flips the output.
        initial_state:
            Boolean output value before that update. Defaults to True.
    """

    @parameters(dynamic=["initial_state"])
    def __init__(self, dt, initial_state=True, **kwargs):
        super().__init__(**kwargs)
        self.dt = dt
        self.declare_output_port(self._output)
        # The update is configured in `initialize`, once `initial_state` is known.
        self._periodic_update_idx = self.declare_periodic_update()

    def initialize(self, initial_state):
        self.declare_discrete_state(default_value=initial_state, dtype=cnp.bool_)
        # An infinite period means the update fires exactly once, at t = dt.
        self.configure_periodic_update(
            self._periodic_update_idx,
            self._update,
            period=cnp.inf,
            offset=self.dt,
        )

    def reset_default_values(self, initial_state):
        self.configure_discrete_state_default_value(default_value=initial_state)

    def _update(self, time, state, *_inputs, **_params):
        # Flip the stored flag (only ever called once, see `initialize`).
        return cnp.logical_not(state.discrete_state)

    def _output(self, _time, state, *_inputs, **_params):
        return state.discrete_state

DiscreteTimeLinearQuadraticRegulator

Bases: LeafSystem

Linear Quadratic Regulator (LQR) for a discrete-time system: x[k+1] = A x[k] + B u[k]. Computes the optimal control input: u[k] = -K x[k], where u minimises the cost function over [0, ∞): J = ∑(x[k].T Q x[k] + u[k].T R u[k]).

Input ports

(0) x[k]: state vector of the system.

Output ports

(0) u[k]: optimal control vector.

Parameters:

Name Type Description Default
A

Array State matrix of the system.

required
B

Array Input matrix of the system.

required
Q

Array State cost matrix.

required
R

Array Input cost matrix.

required
dt

float Sampling period of the system.

required
Source code in collimator/library/lqr.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class DiscreteTimeLinearQuadraticRegulator(LeafSystem):
    """
    Linear Quadratic Regulator (LQR) for a discrete-time system:
            x[k+1] = A x[k] + B u[k].
    Computes the optimal control input:
            u[k] = -K x[k],
    where u minimises the cost function over [0, ∞):
            J = ∑(x[k].T Q x[k] + u[k].T R u[k]).

    The gain K is computed once at construction with `control.dlqr`.

    Input ports:
        (0) x[k]: state vector of the system.

    Output ports:
        (0) u[k]: optimal control vector. Zeros until the first sample at t = 0.

    Parameters:
        A: Array
            State matrix of the system.
        B: Array
            Input matrix of the system.
        Q: Array
            State cost matrix.
        R: Array
            Input cost matrix.
        dt: float
            Sampling period of the system.
    """

    def __init__(self, A, B, Q, R, dt, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dlqr also returns the Riccati solution and the closed-loop
        # eigenvalues; only the feedback gain is needed here.
        self.K, _S, _E = control.dlqr(A, B, Q, R)

        self.declare_input_port()  # for state x

        # Number of control inputs = number of columns of B. `jnp.shape` also
        # accepts nested sequences, which `control.dlqr` accepts as well.
        n_inputs = jnp.shape(B)[1]

        self.declare_output_port(
            self._get_opt_u,
            requires_inputs=True,
            period=dt,
            offset=0.0,
            default_value=jnp.zeros(n_inputs),
        )

    def _get_opt_u(self, time, state, x, **params):
        # State-feedback law u = -K x.
        return jnp.matmul(-self.K, x)

DotProduct

Bases: ReduceBlock

Compute the dot product between the inputs.

This block dispatches to jax.numpy.dot, so the semantics, broadcasting rules, etc. are the same. See the JAX docs for details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.dot.html

Input ports

(0) The first input vector. (1) The second input vector.

Output ports

(0) The dot product of the inputs.

Source code in collimator/library/primitives.py
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
class DotProduct(ReduceBlock):
    """Dot product of two input signals.

    Delegates to `jax.numpy.dot` (via `cnp.dot`), so shape handling, broadcasting
    and contraction rules follow that function. Reference:
        https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.dot.html

    Input ports:
        (0) The first input vector.
        (1) The second input vector.

    Output ports:
        (0) The dot product of the inputs.
    """

    def __init__(self, **kwargs):
        n_operands = 2
        super().__init__(n_operands, self._compute_output, **kwargs)

    def _compute_output(self, inputs):
        lhs = inputs[0]
        rhs = inputs[1]
        return cnp.dot(lhs, rhs)

EdgeDetection

Bases: LeafSystem

Output is true only when the input signal changes in a specified way.

The block updates at a discrete rate, checking the boolean- or binary-valued input signal for changes. Available edge detection modes are:
- "rising": Output is true when the input changes from False (0) to True (1).
- "falling": Output is true when the input changes from True (1) to False (0).
- "either": Output is true when the input changes in either direction.

Input ports

(0) The input signal. Must be boolean or binary-valued.

Output ports

(0) The edge detection output signal. Boolean-valued.

Parameters:

Name Type Description Default
dt

The sampling period of the block.

required
edge_detection

One of "rising", "falling", or "either". Determines the type of edge detection performed by the block.

required
initial_state

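The value assumed for the previous input at the first update, so that an edge can be detected at the very first sample. The output itself always starts as False.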

False
Source code in collimator/library/primitives.py
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
class EdgeDetection(LeafSystem):
    """Output is true only when the input signal changes in a specified way.

    The block updates at a discrete rate, checking the boolean- or binary-valued input
    signal for changes.  Available edge detection modes are:
        - "rising": Output is true when the input changes from False (0) to True (1).
        - "falling": Output is true when the input changes from True (1) to False (0).
        - "either": Output is true when the input changes in either direction.

    Input ports:
        (0) The input signal. Must be boolean or binary-valued.

    Output ports:
        (0) The edge detection output signal. Boolean-valued. It is False before
            the first periodic update and afterwards holds the result computed at
            the most recent update.

    Parameters:
        dt:
            The sampling period of the block.
        edge_detection:
            One of "rising", "falling", or "either". Determines the type of edge
            detection performed by the block.
        initial_state:
            The value assumed for the previous input at the first update, so an
            edge can be detected at the very first sample. The output itself
            always starts as False.
    """

    class DiscreteStateType(NamedTuple):
        # Input value seen at the previous update.
        prev_input: Array
        # Edge-detection result from the most recent update.
        output: bool

    @parameters(dynamic=["initial_state"], static=["edge_detection"])
    def __init__(self, dt, edge_detection, initial_state=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dt = dt
        self.declare_input_port()

        # Declare the periodic update
        self._periodic_update_idx = self.declare_periodic_update()

        # Declare the output port
        self._output_port_idx = self.declare_output_port(
            self._output,
            prerequisites_of_calc=[DependencyTicket.xd, self.input_ports[0].ticket],
            requires_inputs=False,
        )

    def initialize(self, edge_detection, initial_state):
        # Determine the type of edge detection
        _detection_funcs = {
            "rising": self._detect_rising,
            "falling": self._detect_falling,
            "either": self._detect_either,
        }
        if edge_detection not in _detection_funcs:
            raise ValueError(
                f"EdgeDetection block {self.name} has invalid selection "
                f"{edge_detection} for 'edge_detection'"
            )
        self._detect_edge = _detection_funcs[edge_detection]

        # The discrete state holds the previous input value and the output.
        # `initial_state` seeds the previous input; the output starts False.
        self.declare_discrete_state(
            default_value=self.DiscreteStateType(
                prev_input=initial_state, output=False
            ),
            as_array=False,
        )
        self.configure_periodic_update(
            self._periodic_update_idx,
            self._update,
            period=self.dt,
            offset=0.0,
        )

        # Configure the output port
        self.configure_output_port(
            self._output_port_idx,
            self._output,
            prerequisites_of_calc=[DependencyTicket.xd, self.input_ports[0].ticket],
            requires_inputs=False,
        )

    def reset_default_values(self, initial_state):
        # The discrete state holds the previous input value and the output
        self.configure_discrete_state_default_value(
            default_value=self.DiscreteStateType(
                prev_input=initial_state, output=False
            ),
            as_array=False,
        )

    def _update(self, time, state, *inputs, **params):
        # Store the current input as the new "previous" input, and store the
        # result of the selected edge-detection function as the output.
        (e,) = inputs
        return self.DiscreteStateType(
            prev_input=e,
            output=self._detect_edge(time, state, e, **params),
        )

    def _output(self, _time, state, *_inputs, **_params):
        return state.discrete_state.output

    @staticmethod
    def _prev_and_current(state, inputs):
        """Return (previous input, current input), both converted with cnp.array."""
        (e,) = inputs
        e_prev = cnp.array(state.discrete_state.prev_input)
        return e_prev, cnp.array(e)

    def _detect_rising(self, _time, state, *inputs, **_params):
        # True when the previous input was False and the current one is True.
        e_prev, e = self._prev_and_current(state, inputs)
        return cnp.logical_and(cnp.logical_not(e_prev), e)

    def _detect_falling(self, _time, state, *inputs, **_params):
        # True when the previous input was True and the current one is False.
        e_prev, e = self._prev_and_current(state, inputs)
        return cnp.logical_and(e_prev, cnp.logical_not(e))

    def _detect_either(self, _time, state, *inputs, **_params):
        # True on either a rising or a falling edge.
        e_prev, e = self._prev_and_current(state, inputs)
        rising = cnp.logical_and(cnp.logical_not(e_prev), e)
        falling = cnp.logical_and(e_prev, cnp.logical_not(e))
        return cnp.logical_or(rising, falling)

Exponent

Bases: FeedthroughBlock

Compute the exponential of the input signal.

Input ports

(0) The input signal.

Output ports

(0) The exponential of the input signal.

Parameters:

Name Type Description Default
base

One of "exp" or "2". Determines the base of the exponential function.

required
Source code in collimator/library/primitives.py
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
class Exponent(FeedthroughBlock):
    """Compute the exponential of the input signal.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The exponential of the input signal: `cnp.exp(u)` when `base="exp"`,
            or `cnp.exp2(u)` when `base="2"`.

    Parameters:
        base:
            One of "exp" or "2". Determines the base of the exponential function.

    Raises:
        BlockParameterError: If `base` is not one of the supported selections.
    """

    @parameters(static=["base"])
    def __init__(self, base, **kwargs):
        # No operation yet: it is installed by `initialize` via `replace_op`.
        super().__init__(None, **kwargs)

    def initialize(self, base):
        func_lookup = {"exp": cnp.exp, "2": cnp.exp2}
        if base not in func_lookup:
            raise BlockParameterError(
                message=f"Exponent block {self.name} has invalid selection {base} for 'base'. Valid selections: "
                + ", ".join(func_lookup),
                parameter_name="base",
            )
        self.replace_op(func_lookup[base])

ExtendedKalmanFilter

Bases: KalmanFilterBase

Extended Kalman Filter (EKF) for the following system:

```
x[n+1] = f(x[n], u[n]) + G(t[n]) w[n]
y[n]   = g(x[n], u[n]) + v[n]

E(w[n]) = E(v[n]) = 0
E(w[n]w'[n]) = Q(t[n], x[n], u[n])
E(v[n]v'[n]) = R(t[n])
E(w[n]v'[n]) = N(t[n]) = 0
```

f and g are discrete-time functions of state x[n] and control u[n], while R and G are discrete-time functions of time t[n]. Q is a discrete-time function of t[n], x[n], u[n]. This last aspect is included for zero-order-hold discretization of a continuous-time system.

Input ports

(0) u[n] : control vector at timestep n (1) y[n] : measurement vector at timestep n

Output ports

(1) x_hat[n] : state vector estimate at timestep n

Parameters:

Name Type Description Default
dt

float Time step of the discrete-time system

required
forward

Callable A function with signature f(x[n], u[n]) -> x[n+1] that represents f in the above equations.

required
observation

Callable A function with signature g(x[n], u[n]) -> y[n] that represents g in the above equations.

required
G_func

Callable A function with signature G(t[n]) -> G[n] that represents G in the above equations.

required
Q_func

Callable A function with signature Q(t[n], x[n], u[n]) -> Q[n] that represents Q in the above equations.

required
R_func

Callable A function with signature R(t[n]) -> R[n] that represents R in the above equations.

required
x_hat_0

ndarray Initial state estimate

required
P_hat_0

ndarray Initial state covariance matrix estimate

required
Source code in collimator/library/state_estimators/extended_kalman_filter.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
class ExtendedKalmanFilter(KalmanFilterBase):
    """
    Extended Kalman Filter (EKF) for the following system:

        ```
        x[n+1] = f(x[n], u[n]) + G(t[n]) w[n]
        y[n]   = g(x[n], u[n]) + v[n]

        E(w[n]) = E(v[n]) = 0
        E(w[n]w'[n]) = Q(t[n], x[n], u[n])
        E(v[n]v'[n]) = R(t[n])
        E(w[n]v'[n]) = N(t[n]) = 0
        ```

    `f` and `g` are discrete-time functions of state `x[n]` and control `u[n]`,
    while `R` and `G` are discrete-time functions of time `t[n]`. `Q` is a
    discrete-time function of `t[n], x[n], u[n]`. This last aspect is included for
    zero-order-hold discretization of a continuous-time system.

    The state Jacobians of `f` and `g` are obtained with `jax.jacfwd`, so both
    functions must be differentiable by JAX.

    Input ports:
        (0) u[n] : control vector at timestep n
        (1) y[n] : measurement vector at timestep n

    Output ports:
        (1) x_hat[n] : state vector estimate at timestep n

    Parameters:
        dt: float
            Time step of the discrete-time system
        forward: Callable
            A function with signature f(x[n], u[n]) -> x[n+1] that represents `f` in
            the above equations.
        observation: Callable
            A function with signature g(x[n], u[n]) -> y[n] that represents `g` in
            the above equations.
        G_func: Callable
            A function with signature G(t[n]) -> G[n] that represents `G` in
            the above equations.
        Q_func: Callable
            A function with signature Q(t[n], x[n], u[n]) -> Q[n] that represents `Q`
            in the above equations.
        R_func: Callable
            A function with signature R(t[n]) -> R[n] that represents `R` in
            the above equations.
        x_hat_0: ndarray
            Initial state estimate
        P_hat_0: ndarray
            Initial state covariance matrix estimate
    """

    @parameters(
        static=[
            "dt",
            "forward",
            "observation",
            "G_func",
            "Q_func",
            "R_func",
            "x_hat_0",
            "P_hat_0",
        ],
    )
    def __init__(
        self,
        dt,
        forward,
        observation,
        G_func,
        Q_func,
        R_func,
        x_hat_0,
        P_hat_0,
        is_feedthrough=True,  # TODO: determine automatically?
        name=None,
        **kwargs,
    ):
        super().__init__(dt, x_hat_0, P_hat_0, is_feedthrough, name, **kwargs)

    def initialize(
        self,
        dt,
        forward,
        observation,
        G_func,
        Q_func,
        R_func,
        x_hat_0,
        P_hat_0,
    ):
        self.G_func = G_func
        self.Q_func = Q_func
        self.R_func = R_func

        # State dimension from the initial estimate; measurement dimension from
        # the measurement-noise covariance evaluated at t = 0.
        self.nx = x_hat_0.size
        self.ny = self.R_func(0.0).shape[0]

        self.forward = forward
        self.observation = observation

        # Jacobians with respect to the first argument (the state x).
        self.jac_forward = jax.jacfwd(forward)
        self.jac_observation = jax.jacfwd(observation)

        self.eye_x = jnp.eye(self.nx)

    def _correct(self, time, x_hat_minus, P_hat_minus, *inputs):
        """Measurement update: turn the n|n-1 estimate into the n|n estimate."""
        u, y = inputs
        # Promote `u` exactly as `_propagate` does, so the user-supplied
        # `forward` and `observation` operators always see the same control shape.
        u = jnp.atleast_1d(u)
        y = jnp.atleast_1d(y)

        C = self.jac_observation(x_hat_minus, u).reshape((self.ny, self.nx))

        R = self.R_func(time)

        # Kalman gain K = P C^T S^{-1}, with innovation covariance
        # S = C P C^T + R. Solve S^T K^T = (P C^T)^T instead of forming S^{-1}
        # explicitly, which is cheaper and numerically better conditioned.
        PCT = P_hat_minus @ C.T
        S = C @ PCT + R
        K = jnp.linalg.solve(S.T, PCT.T).T

        x_hat_plus = x_hat_minus + jnp.dot(
            K, y - self.observation(x_hat_minus, u)
        )  # n|n

        P_hat_plus = jnp.matmul(self.eye_x - jnp.matmul(K, C), P_hat_minus)  # n|n

        return x_hat_plus, P_hat_plus

    def _propagate(self, time, x_hat_plus, P_hat_plus, *inputs):
        """Time update: predict the n+1|n estimate from the n|n estimate."""
        # Predict -- x_hat_plus of current step is propagated to be the
        # x_hat_minus of the next step
        # k+1|k in current step is n|n-1 for next step

        u, _y = inputs  # the measurement is not needed for prediction
        u = jnp.atleast_1d(u)

        A = self.jac_forward(x_hat_plus, u).reshape((self.nx, self.nx))

        G = self.G_func(time)
        Q = self.Q_func(time, x_hat_plus, u)
        GQGT = G @ Q @ G.T

        x_hat_minus = self.forward(x_hat_plus, u)  # n+1|n
        P_hat_minus = A @ P_hat_plus @ A.T + GQGT  # n+1|n

        return x_hat_minus, P_hat_minus

    #######################################
    # Make filter for a continuous plant  #
    #######################################

    @staticmethod
    @with_resolved_parameters
    def for_continuous_plant(
        plant,
        dt,
        G_func,
        Q_func,
        R_func,
        x_hat_0,
        P_hat_0,
        discretization_method="euler",
        discretized_noise=False,
        name=None,
        ui_id=None,
    ):
        """
        Extended Kalman Filter system for a continuous-time plant.

        The input plant contains the deterministic forms of the forward and observation
        operators:

        ```
            dx/dt = f(x,u)
            y = g(x,u)
        ```

        Note: (i) Only plants with one vector-valued input and one vector-valued output
        are currently supported. Furthermore, the plant LeafSystem/Diagram should have
        only one vector-valued integrator; (ii) the user may pass a plant with
        disturbances (not recommended) as the input plant. In this case, the forward
        and observation evaluations will be corrupted by noise.

        A plant with disturbances of the following form is then considered:

        ```
            dx/dt = f(x,u) + G(t) w         -- (C1)
            y = g(x,u) +  v                 -- (C2)
        ```

        where:

            `w` represents the process noise,
            `v` represents the measurement noise,

        and

        ```
            E(w) = E(v) = 0
            E(ww') = Q(t)
            E(vv') = R(t)
            E(wv') = N(t) = 0
        ```

        This plant is discretized to obtain the following form:

        ```
            x[n+1] = fd(x[n], u[n]) + Gd w[n]  -- (D1)
            y[n]   = gd(x[n], u[n]) + v[n]     -- (D2)

            E(w[n]) = E(v[n]) = 0
            E(w[n]w'[n]) = Qd
            E(v[n]v'[n]) = Rd
            E(w[n]v'[n]) = Nd = 0
        ```

        The above discretization is performed either via the `euler` or the `zoh`
        method, and an Extended Kalman Filter estimator for the system of equations
        (D1) and (D2) is returned.

        Note: If `discretized_noise` is True, then it is assumed that the user is
        directly providing Gd, Qd and Rd. If False, then Qd and Rd are computed from
        continuous-time Q, R, and G, and Gd is set to an Identity matrix.

        The returned system will have:

        Input ports:
            (0) u[n] : control vector at timestep n
            (1) y[n] : measurement vector at timestep n

        Output ports:
            (1) x_hat[n] : state vector estimate at timestep n

        Parameters:
            plant : a `Plant` object which can be a LeafSystem or a Diagram.
            dt: float
                Time step for the discretization.
            G_func: Callable
                A function with signature G(t) -> G that represents `G` in
                the continuous-time equations (C1) and (C2).
            Q_func: Callable
                A function with signature Q(t) -> Q that represents `Q` in
                the continuous-time equations (C1) and (C2).
            R_func: Callable
                A function with signature R(t) -> R that represents `R` in
                the continuous-time equations (C1) and (C2).
            x_hat_0: ndarray
                Initial state estimate
            P_hat_0: ndarray
                Initial state covariance matrix estimate. If `None`, an Identity
                matrix is assumed.
            discretization_method: str ("euler" or "zoh")
                Method to discretize the continuous-time plant. Default is "euler".
            discretized_noise: bool
                Whether the user is directly providing Gd, Qd and Rd. Default is False.
                If True, `G_func`, `Q_func`, and `R_func` provide Gd(t), Qd(t), and
                Rd(t), respectively.
        """

        (
            forward,
            observation,
            Gd_func,
            Qd_func,
            Rd_func,
        ) = prepare_continuous_plant_for_nonlinear_kalman_filter(
            plant,
            dt,
            G_func,
            Q_func,
            R_func,
            x_hat_0,
            discretization_method,
            discretized_noise,
        )

        nx = x_hat_0.size
        if P_hat_0 is None:
            P_hat_0 = jnp.eye(nx)

        # TODO: If Gd_func is None, compute Gd automatically with u = u + w

        ekf = ExtendedKalmanFilter(
            dt,
            forward,
            observation,
            Gd_func,
            Qd_func,
            Rd_func,
            x_hat_0,
            P_hat_0,
            name=name,
            ui_id=ui_id,
        )

        return ekf

    ####################################################################################
    # Make filter from direct specification of forward/observation operators and noise #
    ####################################################################################

    @staticmethod
    @with_resolved_parameters
    def from_operators(
        dt,
        forward,
        observation,
        G_func,
        Q_func,
        R_func,
        x_hat_0,
        P_hat_0,
        name=None,
        ui_id=None,
    ):
        """
        Extended Kalman Filter (EKF) for the following system:

        ```
            x[n+1] = f(x[n], u[n]) + G(t[n]) w[n]
            y[n]   = g(x[n], u[n]) + v[n]

            E(w[n]) = E(v[n]) = 0
            E(w[n]w'[n]) = Q(t[n])
            E(v[n]v'[n]) = R(t[n])
            E(w[n]v'[n]) = N(t[n]) = 0
        ```

        `f` and `g` are discrete-time functions of state `x[n]` and control `u[n]`,
        while `Q` and `R` and `G` are discrete-time functions of time `t[n]`.

        Input ports:
            (0) u[n] : control vector at timestep n
            (1) y[n] : measurement vector at timestep n

        Output ports:
            (1) x_hat[n] : state vector estimate at timestep n

        Parameters:
            dt: float
                Time step of the discrete-time system
            forward: Callable
                A function with signature f(x[n], u[n]) -> x[n+1] that represents `f`
                in the above equations.
            observation: Callable
                A function with signature g(x[n], u[n]) -> y[n] that represents `g` in
                the above equations.
            G_func: Callable
                A function with signature G(t[n]) -> G[n] that represents `G` in
                the above equations.
            Q_func: Callable
                A function with signature Q(t[n]) -> Q[n] that represents
                `Q` in the above equations.
            R_func: Callable
                A function with signature R(t[n]) -> R[n] that represents `R` in
                the above equations.
            x_hat_0: ndarray
                Initial state estimate
            P_hat_0: ndarray
                Initial state covariance matrix estimate
        """

        # Adapt the time-only Q(t) to the Q(t, x, u) signature used by the filter.
        def Q_func_aug(t, x_k, u_k):
            return Q_func(t)

        ekf = ExtendedKalmanFilter(
            dt,
            forward,
            observation,
            G_func,
            Q_func_aug,
            R_func,
            x_hat_0,
            P_hat_0,
            name=name,
            ui_id=ui_id,
        )

        return ekf

for_continuous_plant(plant, dt, G_func, Q_func, R_func, x_hat_0, P_hat_0, discretization_method='euler', discretized_noise=False, name=None, ui_id=None) staticmethod

Extended Kalman Filter system for a continuous-time plant.

The input plant contains the deterministic forms of the forward and observation operators:

    dx/dt = f(x,u)
    y = g(x,u)

Note: (i) Only plants with one vector-valued input and one vector-valued output are currently supported. Furthermore, the plant LeafSystem/Diagram should have only one vector-valued integrator; (ii) the user may pass a plant with disturbances (not recommended) as the input plant. In this case, the forward and observation evaluations will be corrupted by noise.

A plant with disturbances of the following form is then considered:

    dx/dt = f(x,u) + G(t) w         -- (C1)
    y = g(x,u) +  v                 -- (C2)

where:

`w` represents the process noise,
`v` represents the measurement noise,

and

    E(w) = E(v) = 0
    E(ww') = Q(t)
    E(vv') = R(t)
    E(wv') = N(t) = 0

This plant is discretized to obtain the following form:

    x[n+1] = fd(x[n], u[n]) + Gd w[n]  -- (D1)
    y[n]   = gd(x[n], u[n]) + v[n]     -- (D2)

    E(w[n]) = E(v[n]) = 0
    E(w[n]w'[n]) = Qd
    E(v[n]v'[n]) = Rd
    E(w[n]v'[n]) = Nd = 0

The above discretization is performed either via the euler or the zoh method, and an Extended Kalman Filter estimator for the system of equations (D1) and (D2) is returned.

Note: If discretized_noise is True, then it is assumed that the user is directly providing Gd, Qd and Rd. If False, then Qd and Rd are computed from continuous-time Q, R, and G, and Gd is set to an Identity matrix.

The returned system will have:

Input ports

(0) u[n] : control vector at timestep n (1) y[n] : measurement vector at timestep n

Output ports

(1) x_hat[n] : state vector estimate at timestep n

Parameters:

Name Type Description Default
plant

a Plant object which can be a LeafSystem or a Diagram.

required
dt

float Time step for the discretization.

required
G_func

Callable A function with signature G(t) -> G that represents G in the continuous-time equations (C1) and (C2).

required
Q_func

Callable A function with signature Q(t) -> Q that represents Q in the continuous-time equations (C1) and (C2).

required
R_func

Callable A function with signature R(t) -> R that represents R in the continuous-time equations (C1) and (C2).

required
x_hat_0

ndarray Initial state estimate

required
P_hat_0

ndarray Initial state covariance matrix estimate. If None, an Identity matrix is assumed.

required
discretization_method

str ("euler" or "zoh") Method to discretize the continuous-time plant. Default is "euler".

'euler'
discretized_noise

bool Whether the user is directly providing Gd, Qd and Rd. Default is False. If True, G_func, Q_func, and R_func provide Gd(t), Qd(t), and Rd(t), respectively.

False
Source code in collimator/library/state_estimators/extended_kalman_filter.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
@staticmethod
@with_resolved_parameters
def for_continuous_plant(
    plant,
    dt,
    G_func,
    Q_func,
    R_func,
    x_hat_0,
    P_hat_0,
    discretization_method="euler",
    discretized_noise=False,
    name=None,
    ui_id=None,
):
    """
    Extended Kalman Filter system for a continuous-time plant.

    The input plant contains the deterministic forms of the forward and observation
    operators:

    ```
        dx/dt = f(x,u)
        y = g(x,u)
    ```

    Note: (i) Only plants with one vector-valued input and one vector-valued output
    are currently supported. Furthermore, the plant LeafSystem/Diagram should have
    only one vector-valued integrator; (ii) the user may pass a plant with
    disturbances (not recommended) as the input plant. In this case, the forward
    and observation evaluations will be corrupted by noise.

    A plant with disturbances of the following form is then considered:

    ```
        dx/dt = f(x,u) + G(t) w         -- (C1)
        y = g(x,u) +  v                 -- (C2)
    ```

    where:

        `w` represents the process noise,
        `v` represents the measurement noise,

    and

    ```
        E(w) = E(v) = 0
        E(ww') = Q(t)
        E(vv') = R(t)
        E(wv') = N(t) = 0
    ```

    This plant is discretized to obtain the following form:

    ```
        x[n+1] = fd(x[n], u[n]) + Gd w[n]  -- (D1)
        y[n]   = gd(x[n], u[n]) + v[n]     -- (D2)

        E(w[n]) = E(v[n]) = 0
        E(w[n]w'[n]) = Qd
        E(v[n]v'[n]) = Rd
        E(w[n]v'[n]) = Nd = 0
    ```

    The above discretization is performed either via the `euler` or the `zoh`
    method, and an Extended Kalman Filter estimator for the system of equations
    (D1) and (D2) is returned.

    Note: If `discretized_noise` is True, then it is assumed that the user is
    directly providing Gd, Qd and Rd. If False, then Qd and Rd are computed from
    continuous-time Q, R, and G, and Gd is set to an Identity matrix.

    The returned system will have:

    Input ports:
        (0) u[n] : control vector at timestep n
        (1) y[n] : measurement vector at timestep n

    Output ports:
        (1) x_hat[n] : state vector estimate at timestep n

    Parameters:
        plant : a `Plant` object which can be a LeafSystem or a Diagram.
        dt: float
            Time step for the discretization.
        G_func: Callable
            A function with signature G(t) -> G that represents `G` in
            the continuous-time equations (C1) and (C2).
        Q_func: Callable
            A function with signature Q(t) -> Q that represents `Q` in
            the continuous-time equations (C1) and (C2).
        R_func: Callable
            A function with signature R(t) -> R that represents `R` in
            the continuous-time equations (C1) and (C2).
        x_hat_0: ndarray
            Initial state estimate
        P_hat_0: ndarray or None
            Initial state covariance matrix estimate. If `None`, an identity
            matrix of size `x_hat_0.size` is used.
        discretization_method: str ("euler" or "zoh")
            Method to discretize the continuous-time plant. Default is "euler".
        discretized_noise: bool
            Whether the user is directly providing Gd, Qd and Rd. Default is False.
            If True, `G_func`, `Q_func`, and `R_func` provide Gd(t), Qd(t), and
            Rd(t), respectively.
        name: str, optional
            Forwarded unchanged to `ExtendedKalmanFilter`.
        ui_id: optional
            Forwarded unchanged to `ExtendedKalmanFilter`.

    Returns:
        ExtendedKalmanFilter: the estimator for the discretized system (D1)-(D2).
    """

    # Discretize the plant and (unless the user already supplied discrete-time
    # noise terms) convert G, Q, R to their discrete-time counterparts.
    (
        forward,
        observation,
        Gd_func,
        Qd_func,
        Rd_func,
    ) = prepare_continuous_plant_for_nonlinear_kalman_filter(
        plant,
        dt,
        G_func,
        Q_func,
        R_func,
        x_hat_0,
        discretization_method,
        discretized_noise,
    )

    if P_hat_0 is None:
        # Default prior covariance: identity sized to the state estimate.
        P_hat_0 = jnp.eye(x_hat_0.size)

    # TODO: If Gd_func is None, compute Gd automatically with u = u + w

    ekf = ExtendedKalmanFilter(
        dt,
        forward,
        observation,
        Gd_func,
        Qd_func,
        Rd_func,
        x_hat_0,
        P_hat_0,
        name=name,
        ui_id=ui_id,
    )

    return ekf

from_operators(dt, forward, observation, G_func, Q_func, R_func, x_hat_0, P_hat_0, name=None, ui_id=None) staticmethod

Extended Kalman Filter (EKF) for the following system:

    x[n+1] = f(x[n], u[n]) + G(t[n]) w[n]
    y[n]   = g(x[n], u[n]) + v[n]

    E(w[n]) = E(v[n]) = 0
    E(w[n]w'[n]) = Q(t[n])
    E(v[n]v'[n]) = R(t[n])
    E(w[n]v'[n]) = N(t[n]) = 0

f and g are discrete-time functions of state x[n] and control u[n], while Q and R and G are discrete-time functions of time t[n].

Input ports

(0) u[n] : control vector at timestep n (1) y[n] : measurement vector at timestep n

Output ports

(1) x_hat[n] : state vector estimate at timestep n

Parameters:

Name Type Description Default
dt

float Time step of the discrete-time system

required
forward

Callable A function with signature f(x[n], u[n]) -> x[n+1] that represents f in the above equations.

required
observation

Callable A function with signature g(x[n], u[n]) -> y[n] that represents g in the above equations.

required
G_func

Callable A function with signature G(t[n]) -> G[n] that represents G in the above equations.

required
Q_func

Callable A function with signature Q(t[n]) -> Q[n] that represents Q in the above equations.

required
R_func

Callable A function with signature R(t[n]) -> R[n] that represents R in the above equations.

required
x_hat_0

ndarray Initial state estimate

required
P_hat_0

ndarray Initial state covariance matrix estimate

required
Source code in collimator/library/state_estimators/extended_kalman_filter.py
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
@staticmethod
@with_resolved_parameters
def from_operators(
    dt,
    forward,
    observation,
    G_func,
    Q_func,
    R_func,
    x_hat_0,
    P_hat_0,
    name=None,
    ui_id=None,
):
    """
    Extended Kalman Filter (EKF) for the following system:

    ```
        x[n+1] = f(x[n], u[n]) + G(t[n]) w[n]
        y[n]   = g(x[n], u[n]) + v[n]

        E(w[n]) = E(v[n]) = 0
        E(w[n]w'[n]) = Q(t[n])
        E(v[n]v'[n]) = R(t[n])
        E(w[n]v'[n]) = N(t[n]) = 0
    ```

    `f` and `g` are discrete-time functions of state `x[n]` and control `u[n]`,
    while `Q` and `R` and `G` are discrete-time functions of time `t[n]`.

    Input ports:
        (0) u[n] : control vector at timestep n
        (1) y[n] : measurement vector at timestep n

    Output ports:
        (1) x_hat[n] : state vector estimate at timestep n

    Parameters:
        dt: float
            Time step of the discrete-time system
        forward: Callable
            A function with signature f(x[n], u[n]) -> x[n+1] that represents `f`
            in the above equations.
        observation: Callable
            A function with signature g(x[n], u[n]) -> y[n] that represents `g` in
            the above equations.
        G_func: Callable
            A function with signature G(t[n]) -> G[n] that represents `G` in
            the above equations.
        Q_func: Callable
            A function with signature Q(t[n]) -> Q[n] that represents
            `Q` in the above equations.
        R_func: Callable
            A function with signature R(t[n]) -> R[n] that represents `R` in
            the above equations.
        x_hat_0: ndarray
            Initial state estimate
        P_hat_0: ndarray
            Initial state covariance matrix estimate
        name: str, optional
            Forwarded unchanged to `ExtendedKalmanFilter`.
        ui_id: optional
            Forwarded unchanged to `ExtendedKalmanFilter`.

    Returns:
        ExtendedKalmanFilter built from the given operators. `Q_func` is wrapped
        so it can be called as Q(t, x, u); the state and control arguments are
        ignored.
    """

    def Q_func_aug(t, x_k, u_k):
        # Adapt the time-only Q(t) to the (t, x, u) signature passed to
        # ExtendedKalmanFilter; x_k and u_k are intentionally unused.
        return Q_func(t)

    ekf = ExtendedKalmanFilter(
        dt,
        forward,
        observation,
        G_func,
        Q_func_aug,
        R_func,
        x_hat_0,
        P_hat_0,
        name=name,
        ui_id=ui_id,
    )

    return ekf

FeedthroughBlock

Bases: LeafSystem

Simple feedthrough blocks with a function of a single input

Source code in collimator/library/generic.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class FeedthroughBlock(LeafSystem):
    """Simple feedthrough blocks with a function of a single input.

    The output is computed as ``func(*inputs, **parameters)`` from the value on
    the single input port. Because the output depends directly on the input, the
    input port is declared as a prerequisite of the output calculation.

    Args:
        func: Callable mapping the input value (plus any keyword parameters
            handed to the output callback) to the output value.
        parameters: Accepted for backward compatibility; not used by this
            class. Defaults to ``None`` (previously a shared mutable ``{}``).
        **kwargs: Forwarded to ``LeafSystem.__init__``.
    """

    def __init__(self, func, parameters=None, **kwargs):
        super().__init__(**kwargs)
        self.declare_input_port()
        self._output_port_idx = self.declare_output_port(
            None,
            prerequisites_of_calc=[self.input_ports[0].ticket],
            requires_inputs=True,
        )
        self.replace_op(func)

    def replace_op(self, func):
        """Replace the function used to compute the output.

        Reconfigures the output port so that it evaluates
        ``func(*inputs, **parameters)``; the time and state arguments of the
        port callback are ignored.
        """

        def _callback(time, state, *inputs, **parameters):
            return func(*inputs, **parameters)

        self.configure_output_port(
            self._output_port_idx,
            _callback,
            prerequisites_of_calc=[self.input_ports[0].ticket],
            requires_inputs=True,
        )

FilterDiscrete

Bases: LeafSystem

Finite Impulse Response (FIR) filter.

Similar to https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html Note: does not implement the IIR filter.

Input ports

(0) The input signal.

Output ports

(0) The filtered signal.

Parameters:

Name Type Description Default
b_coefficients

Array of filter coefficients.

required
Source code in collimator/library/primitives.py
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
class FilterDiscrete(LeafSystem):
    """Finite Impulse Response (FIR) filter.

    Similar to https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html
    Note: does not implement the IIR filter.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The filtered signal.

    Parameters:
        dt:
            Sample period; the delay line is updated and the output is
            recomputed every `dt`, starting at `t = dt`.
        b_coefficients:
            Array of filter coefficients. Must contain at least one element.
            If `b_coefficients[0] != 0` the block has direct feedthrough.
    """

    @parameters(static=["b_coefficients"])
    def __init__(
        self,
        dt,
        b_coefficients,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.dt = dt
        self.declare_input_port()
        self._periodic_update_idx = self.declare_periodic_update()
        self._output_port_idx = self.declare_output_port()

    def initialize(self, b_coefficients):
        if len(b_coefficients) == 0:
            raise ValueError("FilterDiscrete requires at least one b coefficient.")

        # Delay line holding the previous len(b) - 1 input samples, newest first.
        initial_state = cnp.zeros(len(b_coefficients) - 1)
        self.declare_discrete_state(default_value=initial_state)

        # Only b[0] multiplies the current input, so the output depends on the
        # input port directly only when that coefficient is non-zero.
        self.is_feedthrough = bool(b_coefficients[0] != 0)
        self.b_coefficients = b_coefficients
        prerequisites_of_calc = []
        if self.is_feedthrough:
            prerequisites_of_calc.append(self.input_ports[0].ticket)

        self.configure_periodic_update(
            self._periodic_update_idx,
            self._update,
            period=self.dt,
            offset=self.dt,
        )

        self.configure_output_port(
            self._output_port_idx,
            self._output,
            period=self.dt,
            offset=self.dt,
            requires_inputs=self.is_feedthrough,
            prerequisites_of_calc=prerequisites_of_calc,
        )

    def _update(self, _time, state, u, **_parameters):
        xd = state.discrete_state
        # Shift the delay line: prepend the newest sample, drop the oldest.
        # Concatenating first and truncating afterwards keeps the state length
        # fixed even when there are no delay taps (a single b coefficient),
        # where `concat([u], xd[:-1])` would grow the state from 0 to 1.
        return cnp.concatenate([cnp.atleast_1d(u), xd])[:-1]

    def _output(self, time, state, *inputs, **parameters):
        xd = state.discrete_state

        # Contribution of the delayed samples: sum_k b[k] * u[n-k], k >= 1.
        y = cnp.sum(cnp.dot(self.b_coefficients[1:], xd))

        if self.is_feedthrough:
            (u,) = inputs
            y += u * self.b_coefficients[0]

        return y

FiniteHorizonLinearQuadraticRegulator

Bases: LeafSystem

Finite Horizon Linear Quadratic Regulator (LQR) for a continuous-time system. Solves the Riccati Differential Equation (RDE) to compute the optimal control for the following finite horizon cost function over [t0, tf]:

Minimise cost J:

J = [x(tf) - xd(tf)].T Qf [x(tf) - xd(tf)]
    + ∫[x(t) - xd(t)].T Q [x(t) - xd(t)] dt
    + ∫[u(t) - ud(t)].T R [u(t) - ud(t)] dt
    + 2 ∫[x(t) - xd(t)].T N [u(t) - ud(t)] dt

subject to the constraints:

dx(t)/dt - dx0(t)/dt = A [x(t)-x0(t)] + B [u(t)-u0(t)] + c(t),

where, x(t) is the state vector, u(t) is the control vector, xd(t) is the desired state vector, ud(t) is the desired control vector, x0(t) is the nominal state vector, u0(t) is the nominal control vector, Q, R, and N are the state, input, and cross cost matrices, Qf is the final state cost matrix,

and A, B, and c are computed from linearisation of the plant dx/dt = f(x, u) around the nominal trajectory (x0(t), u0(t)).

A = df/dx(x0(t), u0(t), t)
B = df/du(x0(t), u0(t), t)
c = f(x0(t), u0(t), t) - dx0(t)/dt

The optimal control u obtained by the solution of the above problem is output.

See Section 8.5.1 of https://underactuated.csail.mit.edu/lqr.html#finite_horizon

Parameters:

Name Type Description Default
t0

float Initial time of the finite horizon.

required
tf

float Final time of the finite horizon.

required
plant

a Plant object which can be a LeafSystem or a Diagram. The plant to be controlled. This represents dx/dt = f(x, u).

required
Qf

Array Final state cost matrix.

required
func_Q

Callable A function that returns the state cost matrix Q at time t: func_Q(t)->Q

required
func_R

Callable A function that returns the input cost matrix R at time t: func_R(t)->R

required
func_N

Callable A function that returns the cross cost matrix N at time t. func_N(t)->N

required
func_x_0

Callable A function that returns the nominal state vector x0 at time t. func_x_0(t)->x0

required
func_u_0

Callable A function that returns the nominal control vector u0 at time t. func_u_0(t)->u0

required
func_x_d

Callable A function that returns the desired state vector xd at time t. func_x_d(t)->xd. If None, assumed to be the same as the nominal trajectory.

None
func_u_d

Callable A function that returns the desired control vector ud at time t. func_u_d(t)->ud. If None, assumed to be the same as the nominal trajectory.

None
Source code in collimator/library/lqr.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
class FiniteHorizonLinearQuadraticRegulator(LeafSystem):
    """
    Finite Horizon Linear Quadratic Regulator (LQR) for a continuous-time system.
    Solves the Riccati Differential Equation (RDE) to compute the optimal control
    for the following finite horizon cost function over [t0, tf]:

    Minimise cost J:

        J = [x(tf) - xd(tf)].T Qf [x(tf) - xd(tf)]
            + ∫[x(t) - xd(t)].T Q [x(t) - xd(t)] dt
            + ∫[u(t) - ud(t)].T R [u(t) - ud(t)] dt
            + 2 ∫[x(t) - xd(t)].T N [u(t) - ud(t)] dt

    subject to the constraints:

    dx(t)/dt - dx0(t)/dt = A [x(t)-x0(t)] + B [u(t)-u0(t)] + c(t),

    where,
        x(t) is the state vector,
        u(t) is the control vector,
        xd(t) is the desired state vector,
        ud(t) is the desired control vector,
        x0(t) is the nominal state vector,
        u0(t) is the nominal control vector,
        Q, R, and N are the state, input, and cross cost matrices,
        Qf is the final state cost matrix,

    and A, B, and c are computed from linearisation of the plant `dx/dt = f(x, u)`
    around the nominal trajectory (x0(t), u0(t)).

        A = df/dx(x0(t), u0(t), t)
        B = df/du(x0(t), u0(t), t)
        c = f(x0(t), u0(t), t) - dx0(t)/dt

    The optimal control `u` obtained by the solution of the above problem is output.

    The RDE is integrated backward from `tf` to `t0` once, at construction time,
    and its dense solution is evaluated whenever the output is computed. Outside
    [t0, tf] the Riccati solution is held at the nearest endpoint.

    See Section 8.5.1 of https://underactuated.csail.mit.edu/lqr.html#finite_horizon

    Input ports:
        (0) The current state vector `x`.

    Output ports:
        (0) The optimal control `u`.

    Parameters:
        t0 : float
            Initial time of the finite horizon.
        tf : float
            Final time of the finite horizon.
        plant : a `Plant` object which can be a LeafSystem or a Diagram.
            The plant to be controlled. This represents `dx/dt = f(x, u)`.
        Qf : Array
            Final state cost matrix.
        func_Q : Callable
            A function that returns the state cost matrix Q at time `t`: `func_Q(t)->Q`
        func_R : Callable
            A function that returns the input cost matrix R at time `t`: `func_R(t)->R`
        func_N : Callable
            A function that returns the cross cost matrix N at time `t`. `func_N(t)->N`
        func_x_0 : Callable
            A function that returns the nominal state vector `x0` at time `t`.
            func_x_0(t)->x0
        func_u_0 : Callable
            A function that returns the nominal control vector `u0` at time `t`.
            func_u_0(t)->u0
        func_x_d : Callable
            A function that returns the desired state vector `xd` at time `t`.
            func_x_d(t)->xd.  If None, assumed to be the same as the nominal trajectory.
        func_u_d : Callable
            A function that returns the desired control vector `ud` at time `t`.
            func_u_d(t)->ud.  If None, assumed to be the same as the nominal trajectory.
        name : str, optional
            Name of the block.
    """

    def __init__(
        self,
        t0,
        tf,
        plant,
        Qf,
        func_Q,
        func_R,
        func_N,
        func_x_0,
        func_u_0,
        func_x_d=None,
        func_u_d=None,
        name=None,
    ):
        super().__init__(name=name)

        self.t0 = t0
        self.tf = tf

        # Desired trajectories default to the nominal ones.
        if func_x_d is None:
            func_x_d = func_x_0

        if func_u_d is None:
            func_u_d = func_u_0

        self.func_R = func_R
        self.func_N = func_N
        self.func_x_0 = func_x_0
        self.func_u_0 = func_u_0
        self.func_x_d = func_x_d
        self.func_u_d = func_u_d

        # dx0/dt, obtained by differentiating the nominal trajectory in time.
        func_dot_x_0 = jax.jacfwd(func_x_0)
        nu = func_R(t0).shape[0]

        ode_rhs = make_ode_rhs(plant, nu)
        get_A = jax.jacfwd(ode_rhs, argnums=0)
        self.get_B = jax.jacfwd(ode_rhs, argnums=1)

        @jax.jit
        def rde(t, rde_state, args):
            # The solver runs in reversed time tau = -t; recover physical time.
            t = -t
            Sxx, sx = rde_state

            # Re-symmetrize Sxx to suppress numerical drift.
            Sxx = (Sxx + Sxx.T) / 2.0

            # Get nominal trajectories, desired trajectories, and cost matrices
            x_0 = func_x_0(t)
            u_0 = func_u_0(t)

            x_d = func_x_d(t)
            u_d = func_u_d(t)

            Q = func_Q(t)
            R = func_R(t)
            N = func_N(t)

            # Calculate dynamics mismatch due to nominal traj not satisfying dynamics
            dot_x_0 = func_dot_x_0(t)
            dot_x_0_eval = ode_rhs(x_0, u_0, t)
            c = dot_x_0_eval - dot_x_0

            #  Get linearisation around x_0, u_0
            A = get_A(x_0, u_0, t)
            B = self.get_B(x_0, u_0, t)

            #  Desired trajectories relative to nominal
            x_d_0 = x_d - x_0
            u_d_0 = u_d - u_0

            #  Compute RHS of RDE
            qx = -jnp.dot(Q, x_d_0) - jnp.dot(N, u_d_0)
            ru = -jnp.dot(R, u_d_0) - jnp.dot(N.T, x_d_0)

            N_plus_Sxx_B = N + jnp.matmul(Sxx, B)

            Rinv = jnp.linalg.inv(R)
            Sxx_A = jnp.matmul(Sxx, A)

            dot_Sxx = (
                Q
                - jnp.matmul(N_plus_Sxx_B, jnp.matmul(Rinv, N_plus_Sxx_B.T))
                + Sxx_A
                + Sxx_A.T
            )

            dot_sx = (
                qx
                - jnp.dot(N_plus_Sxx_B, jnp.dot(Rinv, ru + jnp.dot(B.T, sx)))
                + jnp.dot(A.T, sx)
                + jnp.dot(Sxx, c)
            )

            return (dot_Sxx, dot_sx)

        term = diffrax.ODETerm(rde)
        solver = diffrax.Tsit5()
        stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5, dtmax=0.1)
        saveat = diffrax.SaveAt(dense=True)

        # Integrate over tau in [-tf, -t0] (i.e. backward in t), starting from
        # the terminal conditions Sxx(tf) = Qf, sx(tf) = -Qf (xd(tf) - x0(tf)).
        # TODO: Use utilities in ../simulation/ for reduced reliance on diffrax
        self.sol_rde = diffrax.diffeqsolve(
            term,
            solver,
            -tf,
            -t0,
            y0=(Qf, -jnp.dot(Qf, func_x_d(tf) - func_x_0(tf))),
            dt0=0.0001,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
        )

        # Input: current state (x)
        self.declare_input_port()

        # Output port: Optimal finite horizon LQR control
        self.declare_output_port(self._eval_output, default_value=jnp.zeros(nu))

    def _eval_output(self, time, state, x, **params):
        # The Riccati solution only exists on [t0, tf] and is stored in
        # reversed time, so clamp and negate before evaluating it.
        rde_time = -jnp.clip(time, self.t0, self.tf)
        Sxx, sx = self.sol_rde.evaluate(rde_time)

        x_0 = self.func_x_0(time)
        u_0 = self.func_u_0(time)

        # Desired trajectories relative to nominal
        x_d_0 = self.func_x_d(time) - x_0
        u_d_0 = self.func_u_d(time) - u_0

        B = self.get_B(x_0, u_0, time)

        R = self.func_R(time)
        N = self.func_N(time)
        # Computed once and reused for both feedback and feedforward terms.
        Rinv = jnp.linalg.inv(R)

        ru = -jnp.dot(R, u_d_0) - jnp.dot(N.T, x_d_0)
        N_plus_Sxx_B = N + jnp.matmul(Sxx, B)

        u = (
            u_0
            - jnp.dot(Rinv, jnp.dot(N_plus_Sxx_B.T, (x - x_0)))
            - jnp.dot(Rinv, ru + jnp.dot(B.T, sx))
        )

        return u

Gain

Bases: FeedthroughBlock

Multiply the input signal by a constant value.

Input ports

(0) The input signal.

Output ports

(0) The input signal multiplied by the gain: y = gain * u.

Parameters:

Name Type Description Default
gain

The value to scale the input signal by.

required
Source code in collimator/library/primitives.py
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
class Gain(FeedthroughBlock):
    """Scale the input signal by a constant factor.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The input signal multiplied by the gain: `y = gain * u`.

    Parameters:
        gain:
            The value to scale the input signal by.
    """

    @parameters(dynamic=["gain"])
    def __init__(self, gain, *args, **kwargs):
        def _apply_gain(signal, gain):
            # FeedthroughBlock calls func(*inputs, **parameters), so the gain
            # arrives by keyword and this parameter must be named `gain`.
            return gain * signal

        super().__init__(_apply_gain, *args, **kwargs)

    def initialize(self, gain):
        """No extra setup: the gain is consumed directly by the output function."""

HermiteSimpsonNMPC

Bases: NonlinearMPCIpopt

Implementation of nonlinear MPC with Hermite-Simpson collocation and IPOPT as the NLP solver.

Input ports

(0) x_0 : current state vector. (1) x_ref : reference state trajectory for the nonlinear MPC. (2) u_ref : reference input trajectory for the nonlinear MPC.

Output ports

(1) u_opt : the optimal control input to be applied at the current time step as determined by the nonlinear MPC.

Parameters:

Name Type Description Default
plant

LeafSystem or Diagram The plant to be controlled.

required
Q

Array State weighting matrix in the cost function.

required
QN

Array Terminal state weighting matrix in the cost function.

required
R

Array Control input weighting matrix in the cost function.

required
N

int The prediction horizon, an integer specifying the number of steps to predict. Note: prediction and control horizons are identical for now.

required
dt

float Major time step, a scalar indicating the increment in time for each step in the prediction and control horizons.

required
lb_x

Array Lower bound on the state vector.

None
ub_x

Array Upper bound on the state vector.

None
lb_u

Array Lower bound on the control input vector.

None
ub_u

Array Upper bound on the control input vector.

None
include_terminal_x_as_constraint

bool If True, the terminal state is included as a constraint in the NLP.

False
include_terminal_u_as_constraint

bool If True, the terminal control input is included as a constraint in the NLP.

False
x_optvars_0

Array Initial guess for the state vector optimization variables in the NLP.

None
u_optvars_0

Array Initial guess for the control vector optimization variables in the NLP.

None
Source code in collimator/library/nmpc/hermite_simpson_ipopt_nmpc.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
class HermiteSimpsonNMPC(NonlinearMPCIpopt):
    """
    Implementation of nonlinear MPC with Hermite-Simpson collocation and IPOPT as the
    NLP solver.

    Input ports:
        (0) x_0 : current state vector.
        (1) x_ref : reference state trajectory for the nonlinear MPC.
        (2) u_ref : reference input trajectory for the nonlinear MPC.

    Output ports:
        (1) u_opt : the optimal control input to be applied at the current time step
                    as determined by the nonlinear MPC.

    Parameters:
        plant: LeafSystem or Diagram
            The plant to be controlled.

        Q: Array
            State weighting matrix in the cost function.

        QN: Array
            Terminal state weighting matrix in the cost function.

        R: Array
            Control input weighting matrix in the cost function.

        N: int
            The prediction horizon, an integer specifying the number of steps to
            predict. Note: prediction and control horizons are identical for now.

        dt: float
            Major time step, a scalar indicating the increment in time for each step in
            the prediction and control horizons.

        lb_x: Array
            Lower bound on the state vector. Defaults to -1e20 in every component
            (effectively unbounded).

        ub_x: Array
            Upper bound on the state vector. Defaults to 1e20 in every component.

        lb_u: Array
            Lower bound on the control input vector. Defaults to -1e20 in every
            component.

        ub_u: Array
            Upper bound on the control input vector. Defaults to 1e20 in every
            component.

        include_terminal_x_as_constraint: bool
            If True, the terminal state is included as a constraint in the NLP.

        include_terminal_u_as_constraint: bool
            If True, the terminal control input is included as a constraint in the NLP.

        x_optvars_0: Array
            Initial guess for the state vector optimization variables in the NLP.
            Currently stored but not used.

        u_optvars_0: Array
            Initial guess for the control vector optimization variables in the NLP.
            Currently stored but not used.

    The NLP decision vector holds all N+1 control vectors followed by all N+1
    state vectors, each flattened row-wise.
    """

    def __init__(
        self,
        plant,
        Q,
        QN,
        R,
        N,
        dt,
        lb_x=None,
        ub_x=None,
        lb_u=None,
        ub_u=None,
        include_terminal_x_as_constraint=False,
        include_terminal_u_as_constraint=False,
        x_optvars_0=None,
        u_optvars_0=None,
        name=None,
    ):
        self.Q = Q
        self.QN = QN
        self.R = R

        self.N = N
        self.dt = dt

        self.include_terminal_x_as_constraint = include_terminal_x_as_constraint
        self.include_terminal_u_as_constraint = include_terminal_u_as_constraint

        self.nx = Q.shape[0]
        self.nu = R.shape[0]

        # Missing bounds default to +/-1e20, i.e. effectively unbounded.
        self.lb_x = -1e20 * jnp.ones(self.nx) if lb_x is None else lb_x
        self.ub_x = 1e20 * jnp.ones(self.nx) if ub_x is None else ub_x
        self.lb_u = -1e20 * jnp.ones(self.nu) if lb_u is None else lb_u
        self.ub_u = 1e20 * jnp.ones(self.nu) if ub_u is None else ub_u

        # Initial guesses are stored but currently not taken into account.
        self.x_optvars_0 = x_optvars_0
        self.u_optvars_0 = u_optvars_0

        self.ode_rhs = make_ode_rhs(plant, self.nu)

        nlp_structure_ipopt = NMPCProblemStructure(
            self.num_optvars,
            self._objective,
            self._constraints,
        )

        super().__init__(
            dt,
            self.nu,
            self.num_optvars,
            nlp_structure_ipopt,
            name=name,
        )

    @property
    def num_optvars(self):
        return (self.N + 1) * (self.nx + self.nu)

    @property
    def num_constraints(self):
        # Initial condition (nx) + N collocation defects (N * nx) + terminal
        # state (nx) + terminal control (nu). Always the max size, regardless of
        # whether terminal constraints are enabled, so shapes are fixed for jit.
        return (self.N + 2) * self.nx + self.nu

    @property
    def bounds_optvars(self):
        lb = jnp.hstack(
            [jnp.tile(self.lb_u, self.N + 1), jnp.tile(self.lb_x, self.N + 1)]
        )
        ub = jnp.hstack(
            [jnp.tile(self.ub_u, self.N + 1), jnp.tile(self.ub_x, self.N + 1)]
        )
        return (lb, ub)

    @property
    def bounds_constraints(self):
        # All constraints are equalities: c(z) = 0.
        c_lb = jnp.zeros(self.num_constraints)
        c_ub = jnp.zeros(self.num_constraints)
        return (c_lb, c_ub)

    def _split_optvars(self, optvars):
        """Split the flat decision vector into (u, x) of shapes (N+1, nu), (N+1, nx)."""
        n_u_total = self.nu * (self.N + 1)
        u = optvars[:n_u_total].reshape((self.N + 1, self.nu))
        x = optvars[n_u_total:].reshape((self.N + 1, self.nx))
        return u, x

    @partial(jax.jit, static_argnames=("self",))
    def _objective(self, optvars, t0, x0, x_ref, u_ref):
        u, x = self._split_optvars(optvars)

        xdiff = x - x_ref
        udiff = u - u_ref

        # Stage state cost: sum of xdiff_k' Q xdiff_k for k = 0..N-1
        xQ = jnp.dot(xdiff[:-1], self.Q)
        qp_x_sum = jnp.sum(xdiff[:-1] * xQ, axis=None)

        # Terminal state cost: xdiff_N' QN xdiff_N
        xN = xdiff[-1]
        qp_x_N = jnp.dot(xN, jnp.dot(self.QN, xN))

        # Control cost over all N+1 control vectors u_0..u_N
        uR = jnp.dot(udiff, self.R)
        qp_u_sum = jnp.sum(udiff * uR, axis=None)

        return qp_x_sum + qp_x_N + qp_u_sum

    @partial(jax.jit, static_argnames=("self",))
    def _constraints(self, optvars, t0, x0, x_ref, u_ref):
        u, x = self._split_optvars(optvars)

        h = self.dt
        t = t0 + h * jnp.arange(self.N + 1)

        # Dynamics evaluated at every knot point.
        def knot_rhs(idx, dot_x):
            return dot_x.at[idx].set(self.ode_rhs(x[idx], u[idx], t[idx]))

        dot_x = jax.lax.fori_loop(
            0, self.N + 1, knot_rhs, jnp.zeros((self.N + 1, self.nx))
        )

        # Hermite-Simpson interpolation at interval midpoints.
        t_c = 0.5 * (t[:-1] + t[1:])
        u_c = 0.5 * (u[:-1] + u[1:])
        x_c = 0.5 * (x[:-1] + x[1:]) + (h / 8.0) * (dot_x[:-1] - dot_x[1:])

        dot_x_c = (-3.0 / 2.0 / h) * (x[:-1] - x[1:]) - (1.0 / 4.0) * (
            dot_x[:-1] + dot_x[1:]
        )

        # Initial-condition constraint.
        c0 = x0 - x[0]

        # Collocation defects: plant dynamics at the midpoint must match the
        # slope of the interpolating cubic.
        def colloc_defect(idx, c_colloc):
            defect = self.ode_rhs(x_c[idx], u_c[idx], t_c[idx]) - dot_x_c[idx]
            return c_colloc.at[idx].set(defect)

        c_colloc = jax.lax.fori_loop(
            0, self.N, colloc_defect, jnp.zeros((self.N, self.nx))
        )
        c_all = jnp.hstack([c0.ravel(), c_colloc.ravel()])

        # Terminal constraints. The flags are plain Python attributes of the
        # static `self`, so an ordinary `if` is resolved at trace time. Disabled
        # constraints are zero-padded to keep the vector at num_constraints.
        if self.include_terminal_x_as_constraint:
            c_terminal_x = (x_ref[self.N] - x[self.N]).ravel()
        else:
            c_terminal_x = jnp.zeros(self.nx)

        if self.include_terminal_u_as_constraint:
            c_terminal_u = (u_ref[self.N] - u[self.N]).ravel()
        else:
            c_terminal_u = jnp.zeros(self.nu)

        return jnp.hstack([c_all, c_terminal_x, c_terminal_u])

IOPort

Bases: FeedthroughBlock

Simple class for organizing input/output ports for groups/submodels.

Since these are treated as standalone blocks in the UI rather than specific input/output ports exported to the parent model, it is more straightforward to represent them that way here as well.

This class represents a simple one-input, one-output feedthrough block where the feedthrough function is an identity. The input (resp. output) port can then be exported to the parent model to create an Inport (resp. Outport).

Source code in collimator/library/primitives.py
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
class IOPort(FeedthroughBlock):
    """Identity feedthrough block standing in for a group/submodel port.

    The UI draws the inports and outports of groups and submodels as
    standalone blocks, not as ports exported to the parent model. Modeling
    them the same way here keeps the two representations aligned.

    The block has a single input and a single output, and the output is the
    input passed through unchanged. Exporting the input port to the parent
    model yields an Inport; exporting the output port yields an Outport.
    """

    def __init__(self, *args, **kwargs):
        def _passthrough(value):
            # Identity: forward the input signal untouched.
            return value

        super().__init__(_passthrough, *args, **kwargs)

IfThenElse

Bases: LeafSystem

Applies a conditional expression to the input signals.

Given inputs pred, true_val, and false_val, the block computes:

y = true_val if pred else false_val

The true and false values may be any arrays, but must have the same shape and dtype.

Input ports

(0) The boolean predicate. (1) The true value. (2) The false value.

Output ports

(0) The result of the conditional expression. Shape and dtype will match the true and false values.

Events

An event is triggered when the predicate changes from true to false or vice versa.

Source code in collimator/library/primitives.py
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
class IfThenElse(LeafSystem):
    """Select one of two input signals according to a boolean predicate.

    For inputs `pred`, `true_val`, and `false_val` the block outputs
    ```
    y = true_val if pred else false_val
    ```
    which is evaluated with `where`, so array-valued predicates select
    elementwise.

    `true_val` and `false_val` may be arbitrary arrays, but they must share
    the same shape and dtype.

    Input ports:
        (0) The boolean predicate.
        (1) The value selected where the predicate is true.
        (2) The value selected where the predicate is false.

    Output ports:
        (0) The selected value. Its shape and dtype match the true and false
            values.

    Events:
        If the output is a discontinuity for a downstream ODE (as reported by
        `is_discontinuity`), a zero-crossing event fires each time the
        predicate switches between true and false.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Declaration order defines the port indices:
        # (0) pred, (1) true_val, (2) false_val.
        for _ in range(3):
            self.declare_input_port()

        def _select(_time, _state, *inputs, **_params):
            pred, when_true, when_false = inputs[0], inputs[1], inputs[2]
            return cnp.where(pred, when_true, when_false)

        input_tickets = [port.ticket for port in self.input_ports]
        self.declare_output_port(_select, prerequisites_of_calc=input_tickets)

    def _edge_detection(self, _time, _state, *inputs, **_params):
        # Map the predicate to +1/-1 so a true<->false flip is a sign change.
        predicate = inputs[0]
        return cnp.where(predicate, 1.0, -1.0)

    def initialize_static_data(self, context):
        # Register the zero-crossing only once, and only when the output is a
        # discontinuity for an ODE; this keeps solvers from integrating across
        # the switch without paying the cost elsewhere.
        wants_event = not self.has_zero_crossing_events and is_discontinuity(
            self.output_ports[0]
        )
        if wants_event:
            self.declare_zero_crossing(self._edge_detection, direction="crosses_zero")

        return super().initialize_static_data(context)

InfiniteHorizonKalmanFilter

Bases: KalmanFilterBase

Infinite Horizon Kalman Filter for the following system:

x[n+1] = A x[n] + B u[n] + G w[n]
y[n]   = C x[n] + D u[n] + v[n]

E(w[n]) = E(v[n]) = 0
E(w[n]w'[n]) = Q
E(v[n]v'[n]) = R
E(w[n]v'[n]) = N = 0
Input ports

(0) u[n] : control vector at timestep n
(1) y[n] : measurement vector at timestep n

Output ports

(0) x_hat[n] : state vector estimate at timestep n

Parameters:

Name Type Description Default
dt

float Time step of the discrete-time system

required
A

ndarray State transition matrix

required
B

ndarray Input matrix

required
C

ndarray Output matrix. If None, full state output is assumed.

None
D

ndarray Feedthrough matrix. If None, no feedthrough is assumed.

None
G

ndarray Process noise matrix. If None, G=B is assumed.

None
Q

ndarray Process noise covariance matrix. If None, Identity matrix of size compatible with G and A is assumed.

None
R

ndarray Measurement noise covariance matrix. If None, Identity matrix of size compatible with C and A is assumed.

None
x_hat_0

ndarray Initial state estimate. If None, an array of zeros is assumed.

None
Source code in collimator/library/state_estimators/infinite_horizon_kalman_filter.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
class InfiniteHorizonKalmanFilter(KalmanFilterBase):
    """
    Infinite Horizon Kalman Filter for the following system:

    ```
    x[n+1] = A x[n] + B u[n] + G w[n]
    y[n]   = C x[n] + D u[n] + v[n]

    E(w[n]) = E(v[n]) = 0
    E(w[n]w'[n]) = Q
    E(v[n]v'[n]) = R
    E(w[n]v'[n]) = N = 0
    ```

    The steady-state estimator gain is computed once in `initialize` with
    `control.dlqe`. No error covariance is tracked while the filter runs.
    The correction gain is recovered as `K = A^{-1} L`, so the state
    transition matrix `A` must be nonsingular.

    Input ports:
        (0) u[n] : control vector at timestep n
        (1) y[n] : measurement vector at timestep n

    Output ports:
        (0) x_hat[n] : state vector estimate at timestep n

    Parameters:
        dt: float
            Time step of the discrete-time system
        A: ndarray
            State transition matrix
        B: ndarray
            Input matrix
        C: ndarray
            Output matrix. If `None`, full state output is assumed.
        D: ndarray
            Feedthrough matrix. If `None`, no feedthrough is assumed.
        G: ndarray
            Process noise matrix. If `None`, `G=B` is assumed.
        Q: ndarray
            Process noise covariance matrix. If `None`, Identity matrix of size
            compatible with `G` and `A` is assumed.
        R: ndarray
            Measurement noise covariance matrix. If `None`, Identity matrix of size
            compatible with `C` and `A` is assumed.
        x_hat_0: ndarray
            Initial state estimate. If `None`, an array of zeros is assumed.
    """

    @parameters(
        static=["dt", "A", "B", "C", "D", "G", "Q", "R", "x_hat_0"],
    )
    def __init__(
        self,
        dt,
        A,
        B,
        C=None,
        D=None,
        G=None,
        Q=None,
        R=None,
        x_hat_0=None,
        name=None,
        **kwargs,
    ):
        # Placeholders; the actual values are filled in by `initialize`.
        self.nx = 0
        self.nu = 0
        self.ny = 0
        self.nd = 0
        self.A = None
        self.B = None
        self.C = None
        self.D = None
        self.G = None
        self.Q = None
        self.R = None
        self.K = None
        self.A_minus_LC = None
        self.B_minus_LD = None
        self.L = None

        # Note: This class inherits from KalmanFilterBase. Since the infinite horizon
        # kalman filter does not need P_hat and track it, a dummy_P_hat_0 is set as
        # Identity matrix of size 1, and used wherever KalmanFilterBase demands
        # P_hat-like matrices
        self.dummy_P_hat_0 = cnp.eye(1)
        is_feedthrough = False if D is None else bool(not cnp.allclose(D, 0.0))
        super().__init__(
            dt, x_hat_0, self.dummy_P_hat_0, is_feedthrough, name, **kwargs
        )

    def initialize(
        self, dt, A, B, C=None, D=None, G=None, Q=None, R=None, x_hat_0=None
    ):
        """Fill in defaults, validate shapes, and precompute the steady-state gains."""
        self.nx, self.nu = B.shape

        if C is None:
            C = cnp.eye(self.nx)
            self.ny = self.nx
        else:
            self.ny = C.shape[0]

        if D is None:
            D = cnp.zeros((self.ny, self.nu))
        self.is_feedthrough = bool(not cnp.allclose(D, 0.0))

        if G is None:
            G = B

        _, self.nd = G.shape

        if Q is None:
            Q = cnp.eye(self.nd)

        if R is None:
            R = cnp.eye(self.ny)

        if x_hat_0 is None:
            x_hat_0 = cnp.zeros(self.nx)

        check_shape_compatibilities(A, B, C, D, G, Q, R)

        self.A = A
        self.B = B
        self.C = C
        self.D = D
        self.G = G
        self.Q = Q
        self.R = R

        # Only the gain is needed; the Riccati solution and eigenvalues that
        # dlqe also returns are discarded.
        L, _, _ = control.dlqe(A, G, C, Q, R)

        # `L` is used below as the predictor gain (A - LC, B - LD); the
        # measurement-correction gain is K = A^{-1} L (requires A invertible).
        self.K = np.linalg.solve(A, L)

        self.A_minus_LC = A - np.matmul(L, C)
        self.B_minus_LD = B - np.matmul(L, D)
        self.L = L

    def _correct(self, time, x_hat_minus, P_hat_minus, *inputs):
        """Measurement update: x[n|n] = x[n|n-1] + K (y - C x[n|n-1] - D u)."""
        u, y = inputs
        y = cnp.atleast_1d(y)

        C, D = self.C, self.D

        x_hat_plus = x_hat_minus + cnp.dot(self.K, y - cnp.dot(C, x_hat_minus))  # n|n

        if self.is_feedthrough:
            u = cnp.atleast_1d(u)
            x_hat_plus = x_hat_plus - cnp.dot(self.K, cnp.dot(D, u))

        return x_hat_plus, self.dummy_P_hat_0

    def _propagate(self, time, x_hat_plus, P_hat_plus, *inputs):
        """Time update to x[n+1|n] using the precomputed steady-state gains."""
        u, y = inputs
        u = cnp.atleast_1d(u)
        # Promote scalar measurements (ny == 1) to shape (1,) as `_correct`
        # does. Otherwise dot(L, y) with L of shape (nx, 1) returns an (nx, 1)
        # array that broadcasts the estimate to (nx, nx).
        y = cnp.atleast_1d(y)

        # NOTE(review): this predictor form, (A - LC) x + (B - LD) u + L y,
        # equals A x[n|n] + B u[n] only when `x` is the *prior* x[n|n-1].
        # The argument is named `x_hat_plus` (the output of `_correct`);
        # confirm against KalmanFilterBase which estimate is passed here, or
        # the measurement may be applied twice.
        x_hat_minus = (
            cnp.dot(self.A_minus_LC, x_hat_plus)
            + cnp.dot(self.B_minus_LD, u)
            + cnp.dot(self.L, y)
        )  # n+1|n

        return x_hat_minus, self.dummy_P_hat_0

    #######################################
    # Make filter for a continuous plant  #
    #######################################

    @staticmethod
    @with_resolved_parameters
    def for_continuous_plant(
        plant,
        x_eq,
        u_eq,
        dt,
        Q=None,
        R=None,
        G=None,
        x_hat_bar_0=None,
        discretization_method="zoh",
        discretized_noise=False,
        name=None,
        ui_id=None,
    ):
        """
        Obtain an Infinite Horizon Kalman Filter system for a continuous-time plant
        after linearization at equilibrium point (x_eq, u_eq)

        The input plant contains the deterministic forms of the forward and observation
        operators:

        ```
            dx/dt = f(x,u)
            y = g(x,u)
        ```

        Note: (i) Only plants with one vector-valued input and one vector-valued output
        are currently supported. Furthermore, the plant LeafSystem/Diagram should have
        only one vector-valued integrator. (ii) the user may pass a plant with
        disturbances as the input plant. However, computation of `y_eq` will be fraught
        with disturbances.

        A plant with disturbances of the following form is then considered:

        ```
            dx/dt = f(x,u) + G w                        --- (C1)
            y = g(x,u) +  v                             --- (C2)
        ```

        where:

            `w` represents the process noise,
            `v` represents the measurement noise,

        and

        ```
            E(w) = E(v) = 0
            E(ww') = Q
            E(vv') = R
            E(wv') = N = 0
        ```

        This plant with disturbances is linearized (only `f` and `g`) around the
        equilibrium point to obtain:

        ```
            d/dt (x_bar) = A x_bar + B u_bar + G w
            y_bar = C x_bar + D u_bar + v
        ```

        where,

        ```
            x_bar = x - x_eq
            u_bar = u - u_eq
            y_bar = y - y_eq
            y_eq = g(x_eq, u_eq)
        ```

        The linearized plant is then discretized via `euler` or `zoh` method to obtain:

        ```
            x_bar[n+1] = Ad x_bar[n] + Bd u_bar[n] + Gd w[n]         --- (L1)
            y_bar[n]   = Cd x_bar[n] + Dd u_bar[n] + v[n]            --- (L2)

            E(w[n]) = E(v[n]) = 0
            E(w[n]w'[n]) = Qd
            E(v[n]v'[n]) = Rd
            E(w[n]v'[n]) = Nd = 0
        ```

        Note: If `discretized_noise` is True, then it is assumed that the user is
        providing Gd, Qd and Rd. If False, then Qd and Rd are computed from
        continuous-time Q, R, and G, and Gd is set to Identity matrix.

        An Infinite Horizon Kalman Filter estimator for the system of equations (L1)
        and (L2) is returned. This filter is in the `x_bar`, `u_bar`, and `y_bar`
        states.

        This returned system will have

        Input ports:
            (0) u_bar[n] : control vector at timestep n, relative to equilibrium
            (1) y_bar[n] : measurement vector at timestep n, relative to equilibrium

        Output ports:
            (0) x_hat_bar[n] : state vector estimate at timestep n, relative to
                               equilibrium

        Parameters:
            plant : a `Plant` object which can be a LeafSystem or a Diagram.
            x_eq: ndarray
                Equilibrium state vector for discretization
            u_eq: ndarray
                Equilibrium control vector for discretization
            dt: float
                Time step for the discretization.
            Q: ndarray
                Process noise covariance matrix. If `None`, Identity matrix of size
                compatible with `G` and linearized system's `A` is assumed.
            R: ndarray
                Measurement noise covariance matrix. If `None`, Identity matrix of size
                compatible with linearized system's `C` and `A` is assumed.
            G: ndarray
                Process noise matrix. If `None`, `G=B` is assumed making disturbances
                additive to control vector `u`, i.e. `u_disturbed = u_orig + w`.
            x_hat_bar_0: ndarray
                Initial state estimate relative to equilibrium.
                If None, an array of zeros is assumed.
            discretization_method: str ("euler" or "zoh")
                Method to discretize the continuous-time plant. Default is "zoh".
            discretized_noise: bool
                Whether the user is directly providing Gd, Qd and Rd. Default is False.
                If True, `G`, `Q`, and `R` are assumed to be Gd, Qd, and Rd,
                respectively.

        Returns:
            A tuple `(y_eq, kf)` of the equilibrium output and the local filter.
        """
        (
            y_eq,
            Ad,
            Bd,
            Cd,
            Dd,
            Gd,
            Qd,
            Rd,
        ) = linearize_and_discretize_continuous_plant(
            plant, x_eq, u_eq, dt, Q, R, G, discretization_method, discretized_noise
        )

        check_shape_compatibilities(Ad, Bd, Cd, Dd, Gd, Qd, Rd)

        nx = x_eq.size

        if x_hat_bar_0 is None:
            x_hat_bar_0 = cnp.zeros(nx)

        # Instantiate an Infinite Horizon Kalman Filter for the linearized plant
        kf = InfiniteHorizonKalmanFilter(
            dt,
            Ad,
            Bd,
            Cd,
            Dd,
            Gd,
            Qd,
            Rd,
            x_hat_bar_0,
            name=name,
            ui_id=ui_id,
        )

        return y_eq, kf

    ##############################################
    # Make global filter for a continuous plant  #
    ##############################################

    @staticmethod
    @with_resolved_parameters
    def global_filter_for_continuous_plant(
        plant,
        x_eq,
        u_eq,
        dt,
        Q=None,
        R=None,
        G=None,
        x_hat_0=None,
        discretization_method="euler",
        discretized_noise=False,
        name=None,
        ui_id=None,
    ):
        """
        See docs for `for_continuous_plant`, which returns the local infinite horizon
        Kalman Filter. This method additionally converts the local Kalman Filter to a
        global estimator. See docs for `make_global_estimator_from_local` for details.

        Unlike `for_continuous_plant`, `x_hat_0` here is the initial estimate in
        absolute (not equilibrium-relative) coordinates, and the default
        `discretization_method` is "euler".
        """
        (
            y_eq,
            Ad,
            Bd,
            Cd,
            Dd,
            Gd,
            Qd,
            Rd,
        ) = linearize_and_discretize_continuous_plant(
            plant, x_eq, u_eq, dt, Q, R, G, discretization_method, discretized_noise
        )

        check_shape_compatibilities(Ad, Bd, Cd, Dd, Gd, Qd, Rd)

        nx = x_eq.size

        # The local filter works on deviations from equilibrium.
        if x_hat_0 is None:
            x_hat_bar_0 = cnp.zeros(nx)
        else:
            x_hat_bar_0 = x_hat_0 - x_eq

        # Instantiate an Infinite Horizon Kalman Filter for the linearized plant
        local_kf = InfiniteHorizonKalmanFilter(
            dt,
            Ad,
            Bd,
            Cd,
            Dd,
            Gd,
            Qd,
            Rd,
            x_hat_bar_0,
            name=name,
            ui_id=ui_id,
        )

        global_kf = make_global_estimator_from_local(
            local_kf,
            x_eq,
            u_eq,
            y_eq,
            name=name,
            ui_id=ui_id,
        )

        return global_kf

for_continuous_plant(plant, x_eq, u_eq, dt, Q=None, R=None, G=None, x_hat_bar_0=None, discretization_method='zoh', discretized_noise=False, name=None, ui_id=None) staticmethod

Obtain an Infinite Horizon Kalman Filter system for a continuous-time plant after linearization at equilibrium point (x_eq, u_eq)

The input plant contains the deterministic forms of the forward and observation operators:

    dx/dt = f(x,u)
    y = g(x,u)

Note: (i) Only plants with one vector-valued input and one vector-valued output are currently supported. Furthermore, the plant LeafSystem/Diagram should have only one vector-valued integrator. (ii) the user may pass a plant with disturbances as the input plant. However, computation of y_eq will be fraught with disturbances.

A plant with disturbances of the following form is then considered:

    dx/dt = f(x,u) + G w                        --- (C1)
    y = g(x,u) +  v                             --- (C2)

where:

`w` represents the process noise,
`v` represents the measurement noise,

and

    E(w) = E(v) = 0
    E(ww') = Q
    E(vv') = R
    E(wv') = N = 0

This plant with disturbances is linearized (only f and g) around the equilibrium point to obtain:

    d/dt (x_bar) = A x_bar + B u_bar + G w
    y_bar = C x_bar + D u_bar + v

where,

    x_bar = x - x_eq
    u_bar = u - u_eq
    y_bar = y - y_eq
    y_eq = g(x_eq, u_eq)

The linearized plant is then discretized via euler or zoh method to obtain:

    x_bar[n+1] = Ad x_bar[n] + Bd u_bar[n] + Gd w[n]         --- (L1)
    y_bar[n]   = Cd x_bar[n] + Dd u_bar[n] + v[n]            --- (L2)

    E(w[n]) = E(v[n]) = 0
    E(w[n]w'[n]) = Qd
    E(v[n]v'[n]) = Rd
    E(w[n]v'[n]) = Nd = 0

Note: If discretized_noise is True, then it is assumed that the user is providing Gd, Qd and Rd. If False, then Qd and Rd are computed from continuous-time Q, R, and G, and Gd is set to Identity matrix.

An Infinite Horizon Kalman Filter estimator for the system of equations (L1) and (L2) is returned. This filter is in the x_bar, u_bar, and y_bar states.

This returned system will have

Input ports

(0) u_bar[n] : control vector at timestep n, relative to equilibrium
(1) y_bar[n] : measurement vector at timestep n, relative to equilibrium

Output ports

(0) x_hat_bar[n] : state vector estimate at timestep n, relative to equilibrium

Parameters:

Name Type Description Default
plant

a Plant object which can be a LeafSystem or a Diagram.

required
x_eq

ndarray Equilibrium state vector for discretization

required
u_eq

ndarray Equilibrium control vector for discretization

required
dt

float Time step for the discretization.

required
Q

ndarray Process noise covariance matrix. If None, Identity matrix of size compatible with G and linearized system's A is assumed.

None
R

ndarray Measurement noise covariance matrix. If None, Identity matrix of size compatible with linearized system's C and A is assumed.

None
G

ndarray Process noise matrix. If None, G=B is assumed making disturbances additive to control vector u, i.e. u_disturbed = u_orig + w.

None
x_hat_bar_0

ndarray Initial state estimate relative to equilibrium. If None, an array of zeros is assumed.

None
discretization_method

str ("euler" or "zoh") Method to discretize the continuous-time plant. Default is "zoh".

'zoh'
discretized_noise

bool Whether the user is directly providing Gd, Qd and Rd. Default is False. If True, G, Q, and R are assumed to be Gd, Qd, and Rd, respectively.

False
Source code in collimator/library/state_estimators/infinite_horizon_kalman_filter.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
@staticmethod
@with_resolved_parameters
def for_continuous_plant(
    plant,
    x_eq,
    u_eq,
    dt,
    Q=None,
    R=None,
    G=None,
    x_hat_bar_0=None,
    discretization_method="zoh",
    discretized_noise=False,
    name=None,
    ui_id=None,
):
    """
    Obtain an Infinite Horizon Kalman Filter system for a continuous-time plant
    after linearization at equilibrium point (x_eq, u_eq)

    The input plant contains the deterministic forms of the forward and observation
    operators:

    ```
        dx/dt = f(x,u)
        y = g(x,u)
    ```

    Note: (i) Only plants with one vector-valued input and one vector-valued output
    are currently supported. Furthermore, the plant LeafSystem/Diagram should have
    only one vector-valued integrator. (ii) the user may pass a plant with
    disturbances as the input plant. However, computation of `y_eq` will be fraught
    with disturbances.

    A plant with disturbances of the following form is then considered:

    ```
        dx/dt = f(x,u) + G w                        --- (C1)
        y = g(x,u) +  v                             --- (C2)
    ```

    where:

        `w` represents the process noise,
        `v` represents the measurement noise,

    and

    ```
        E(w) = E(v) = 0
        E(ww') = Q
        E(vv') = R
        E(wv') = N = 0
    ```

    This plant with disturbances is linearized (only `f` and `g`) around the
    equilibrium point to obtain:

    ```
        d/dt (x_bar) = A x_bar + B u_bar + G w
        y_bar = C x_bar + D u_bar + v
    ```

    where,

    ```
        x_bar = x - x_eq
        u_bar = u - u_eq
        y_bar = y - y_eq
        y_eq = g(x_eq, u_eq)
    ```

    The linearized plant is then discretized via `euler` or `zoh` method to obtain:

    ```
        x_bar[n+1] = Ad x_bar[n] + Bd u_bar[n] + Gd w[n]         --- (L1)
        y_bar[n]   = Cd x_bar[n] + Dd u_bar[n] + v[n]            --- (L2)

        E(w[n]) = E(v[n]) = 0
        E(w[n]w'[n]) = Qd
        E(v[n]v'[n]) = Rd
        E(w[n]v'[n]) = Nd = 0
    ```

    Note: If `discretized_noise` is True, then it is assumed that the user is
    providing Gd, Qd and Rd. If False, then Qd and Rd are computed from
    continuous-time Q, R, and G, and Gd is set to Identity matrix.

    An Infinite Horizon Kalman Filter estimator for the system of equations (L1)
    and (L2) is returned. This filter is in the `x_bar`, `u_bar`, and `y_bar`
    states.

    This returned system will have

    Input ports:
        (0) u_bar[n] : control vector at timestep n, relative to equilibrium
        (1) y_bar[n] : measurement vector at timestep n, relative to equilibrium

    Output ports:
        (0) x_hat_bar[n] : state vector estimate at timestep n, relative to
                           equilibrium

    Parameters:
        plant : a `Plant` object which can be a LeafSystem or a Diagram.
        x_eq: ndarray
            Equilibrium state vector for discretization
        u_eq: ndarray
            Equilibrium control vector for discretization
        dt: float
            Time step for the discretization.
        Q: ndarray
            Process noise covariance matrix. If `None`, Identity matrix of size
            compatible with `G` and linearized system's `A` is assumed.
        R: ndarray
            Measurement noise covariance matrix. If `None`, Identity matrix of size
            compatible with linearized system's `C` and `A` is assumed.
        G: ndarray
            Process noise matrix. If `None`, `G=B` is assumed making disturbances
            additive to control vector `u`, i.e. `u_disturbed = u_orig + w`.
        x_hat_bar_0: ndarray
            Initial state estimate relative to equilibrium.
            If None, an array of zeros is assumed.
        discretization_method: str ("euler" or "zoh")
            Method to discretize the continuous-time plant. Default is "zoh".
        discretized_noise: bool
            Whether the user is directly providing Gd, Qd and Rd. Default is False.
            If True, `G`, `Q`, and `R` are assumed to be Gd, Qd, and Rd,
            respectively.

    Returns:
        A tuple `(y_eq, kf)` of the equilibrium output and the local filter.
    """
    (
        y_eq,
        Ad,
        Bd,
        Cd,
        Dd,
        Gd,
        Qd,
        Rd,
    ) = linearize_and_discretize_continuous_plant(
        plant, x_eq, u_eq, dt, Q, R, G, discretization_method, discretized_noise
    )

    check_shape_compatibilities(Ad, Bd, Cd, Dd, Gd, Qd, Rd)

    nx = x_eq.size

    if x_hat_bar_0 is None:
        x_hat_bar_0 = cnp.zeros(nx)

    # Instantiate an Infinite Horizon Kalman Filter for the linearized plant
    kf = InfiniteHorizonKalmanFilter(
        dt,
        Ad,
        Bd,
        Cd,
        Dd,
        Gd,
        Qd,
        Rd,
        x_hat_bar_0,
        name=name,
        ui_id=ui_id,
    )

    return y_eq, kf

global_filter_for_continuous_plant(plant, x_eq, u_eq, dt, Q=None, R=None, G=None, x_hat_0=None, discretization_method='euler', discretized_noise=False, name=None, ui_id=None) staticmethod

See docs for for_continuous_plant, which returns the local infinite horizon Kalman Filter. This method additionally converts the local Kalman Filter to a global estimator. See docs for make_global_estimator_from_local for details.

Source code in collimator/library/state_estimators/infinite_horizon_kalman_filter.py
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
@staticmethod
@with_resolved_parameters
def global_filter_for_continuous_plant(
    plant,
    x_eq,
    u_eq,
    dt,
    Q=None,
    R=None,
    G=None,
    x_hat_0=None,
    discretization_method="euler",
    discretized_noise=False,
    name=None,
    ui_id=None,
):
    """
    Build a global Infinite Horizon Kalman Filter for a continuous-time plant.

    The plant is linearized and discretized exactly as in `for_continuous_plant`,
    which yields the local (equilibrium-relative) filter. That local filter is
    then wrapped by `make_global_estimator_from_local` so that the result takes
    and produces absolute signals. See both of those for the full details.

    Here `x_hat_0` is the initial estimate in absolute coordinates; it is
    shifted by `x_eq` before being handed to the local filter. Note that the
    default `discretization_method` here is "euler", whereas
    `for_continuous_plant` defaults to "zoh".
    """
    y_eq, Ad, Bd, Cd, Dd, Gd, Qd, Rd = linearize_and_discretize_continuous_plant(
        plant, x_eq, u_eq, dt, Q, R, G, discretization_method, discretized_noise
    )
    check_shape_compatibilities(Ad, Bd, Cd, Dd, Gd, Qd, Rd)

    state_dim = x_eq.size
    # Convert the absolute initial guess into a deviation from equilibrium.
    x_hat_bar_0 = cnp.zeros(state_dim) if x_hat_0 is None else x_hat_0 - x_eq

    # Local filter operating on (x_bar, u_bar, y_bar).
    local_kf = InfiniteHorizonKalmanFilter(
        dt, Ad, Bd, Cd, Dd, Gd, Qd, Rd, x_hat_bar_0, name=name, ui_id=ui_id
    )

    return make_global_estimator_from_local(
        local_kf, x_eq, u_eq, y_eq, name=name, ui_id=ui_id
    )

Integrator

Bases: LeafSystem

Integrate the input signal in time.

The Integrator block is the main primitive for building continuous-time models. It is a first-order integrator, implementing the following linear time-invariant ordinary differential equation for input values u and output values y:

    ẋ = u
    y = x

where x is the state of the integrator. The integrator is initialized with the value of the initial_state parameter.

Options

Reset: the integrator can be configured to reset its state on an input trigger. The reset value can be either the initial state of the integrator or an external value provided by an input port. Limits: the integrator can be configured such that the output and state are constrained by upper and lower limits. Hold: the integrator can be configured to hold integration based on an input trigger.

The Integrator block is also designed to detect "Zeno" behavior, where the reset events happen asymptotically closer together. This is a pathological case that can cause numerical issues in the simulation and should typically be avoided by introducing some physically realistic hysteresis into the model. However, in the event that Zeno behavior is unavoidable, the integrator will enter a "Zeno" state where the output is held constant until the trigger changes value to False. See the "bouncing ball" demo for a Zeno example.

Input ports

(0) The input signal. Must match the shape and dtype of the initial continuous state. (1) The reset trigger. Optional, only if enable_reset is True. (2) The reset value. Optional, only if enable_external_reset is True. (3) The hold trigger. Optional, only if 'enable_hold' is True.

Output ports

(0) The continuous state of the integrator.

Parameters:

Name Type Description Default
initial_state

The initial value of the integrator state. Can be any array, or even a nested structure of arrays, but the data type should be floating-point.

required
enable_reset

If True, the integrator will reset its state to the initial value when the reset trigger is True. Adds an additional input port for the reset trigger. This signal should be boolean- or binary-valued.

False
enable_external_reset

If True, the integrator will reset its state to the value provided by the reset value input port when the reset trigger is True. Otherwise, the integrator will reset to the initial value. Adds an additional input port for the reset value. This signal should match the shape and dtype of the initial continuous state.

False
enable_limits

If True, the integrator will constrain its state and output to within the upper and lower limits. Either limit may be disabled by setting its value to None.

False
enable_hold

If True, the integrator will hold integration when the hold trigger is True.

False
reset_on_enter_zeno

If True, the integrator will reset its state to the initial value when the integrator enters the Zeno state. This option is ignored unless enable_reset is True.

False
zeno_tolerance

The tolerance used to determine if the integrator is in the Zeno state. If the time between events is less than this tolerance, then the integrator is in the Zeno state. This option is ignored unless enable_reset is True.

1e-06
Events

An event is triggered when the "reset" port changes.

An event is triggered when the state hits one of the limits.

An event is triggered when the "hold" port changes.

Another guard is conditionally active when the integrator is in the Zeno state, and is triggered when the "reset" port changes from True to False. This event is used to exit the Zeno state and resume normal integration.

Source code in collimator/library/primitives.py
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
class Integrator(LeafSystem):
    """Integrate the input signal in time.

    The Integrator block is the main primitive for building continuous-time
    models.  It is a first-order integrator, implementing the following linear
    time-invariant ordinary differential equation for input values `u` and output
    values `y`:
    ```
        ẋ = u
        y = x
    ```
    where `x` is the state of the integrator.  The integrator is initialized
    with the value of the `initial_state` parameter.

    Options:
        Reset: the integrator can be configured to reset its state on an input
            trigger.  The reset value can be either the initial state of the
            integrator or an external value provided by an input port.
        Limits: the integrator can be configured such that the output and state
            are constrained by upper and lower limits.
        Hold: the integrator can be configured to hold integration based on an
            input trigger.

    The Integrator block is also designed to detect "Zeno" behavior, where the
    reset events happen asymptotically closer together.  This is a pathological
    case that can cause numerical issues in the simulation and should typically be
    avoided by introducing some physically realistic hysteresis into the model.
    However, in the event that Zeno behavior is unavoidable, the integrator will
    enter a "Zeno" state where the output is held constant until the trigger
    changes value to False.  See the "bouncing ball" demo for a Zeno example.

    Input ports:
        (0) The input signal.  Must match the shape and dtype of the initial
            continuous state.
        (1) The reset trigger.  Optional, only if `enable_reset` is True.
        (2) The reset value.  Optional, only if `enable_external_reset` is True.
        (3) The hold trigger. Optional, only if `enable_hold` is True.

    Output ports:
        (0) The continuous state of the integrator.

    Parameters:
        initial_state:
            The initial value of the integrator state.  Can be any array, or even
            a nested structure of arrays, but the data type should be floating-point.
        enable_reset:
            If True, the integrator will reset its state to the initial value
            when the reset trigger is True.  Adds an additional input port for
            the reset trigger.  This signal should be boolean- or binary-valued.
        enable_external_reset:
            If True, the integrator will reset its state to the value provided
            by the reset value input port when the reset trigger is True. Otherwise,
            the integrator will reset to the initial value.  Adds an additional
            input port for the reset value.  This signal should match the shape
            and dtype of the initial continuous state.
        enable_limits:
            If True, the integrator will constrain its state and output to within
            the upper and lower limits. Either limit may be disabled by setting its
            value to None.
        lower_limit:
            Lower bound on the state and output, used only when `enable_limits`
            is True.  None means no lower bound.
        upper_limit:
            Upper bound on the state and output, used only when `enable_limits`
            is True.  None means no upper bound.
        enable_hold:
            If True, the integrator will hold integration when the hold trigger is
            True.
        reset_on_enter_zeno:
            If True, the integrator will still apply the reset value (the initial
            state, or the external reset value if enabled) when the integrator
            enters the Zeno state.  This option is ignored unless `enable_reset`
            is True.
        zeno_tolerance:
            The tolerance used to determine if the integrator is in the Zeno state.
            If the time since the previous reset event is less than or equal to
            this tolerance, then the integrator enters the Zeno state.  This option
            is ignored unless `enable_reset` is True.
        dtype:
            Optional dtype of the integrator state.  If None, it is inferred from
            `initial_state`.


    Events:
        An event is triggered when the "reset" port changes.

        An event is triggered when the state hits one of the limits.

        An event is triggered when the "hold" port changes.

        Another guard is conditionally active when the integrator is in the Zeno
        state, and is triggered when the "reset" port changes from True to False.
        This event is used to exit the Zeno state and resume normal integration.
    """

    @parameters(
        static=[
            "enable_reset",
            "enable_external_reset",
            "enable_limits",
            "enable_hold",
            "reset_on_enter_zeno",
        ],
        dynamic=["zeno_tolerance", "lower_limit", "upper_limit", "initial_state"],
    )
    def __init__(
        self,
        initial_state,
        enable_reset=False,
        enable_limits=False,
        lower_limit=None,
        upper_limit=None,
        enable_hold=False,
        enable_external_reset=False,
        zeno_tolerance=1e-6,
        reset_on_enter_zeno=False,
        dtype=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dtype = dtype
        self.enable_reset = enable_reset
        self.enable_external_reset = enable_external_reset
        self.enable_hold = enable_hold
        self.discrete_state_type = namedtuple(
            "IntegratorDiscreteState", ["zeno", "counter", "tprev"]
        )

        self.xdot_index = self.declare_input_port(name="in_0")

        x0 = cnp.array(initial_state, dtype=self.dtype)
        self.dtype = self.dtype if self.dtype is not None else x0.dtype
        self._continuous_state_idx = self.declare_continuous_state(
            default_value=x0,
            ode=self._ode,
            prerequisites_of_calc=[self.input_ports[self.xdot_index].ticket],
        )

        if enable_reset:
            # Boolean input for triggering reset
            self.reset_trigger_index = self.declare_input_port(name="reset_trigger")

            # Declare a custom discrete state to track Zeno behavior
            self.declare_discrete_state(
                default_value=self.discrete_state_type(
                    zeno=False, counter=0, tprev=0.0
                ),
                as_array=False,
            )

            #
            # Declare reset events.  NOTE: `determine_active_guards` relies on
            # the declaration order below: "reset_on" (0), "reset_off" (1),
            # "exit_zeno" (2).
            #
            # when reset is triggered, execute the reset map.
            self.declare_zero_crossing(
                guard=self._reset_guard,
                reset_map=self._reset,
                name="reset_on",
                direction="negative_then_non_negative",
            )
            # when reset is deasserted, do not change the state.
            self.declare_zero_crossing(
                guard=self._reset_guard,
                name="reset_off",
                direction="positive_then_non_positive",
            )

            self.declare_zero_crossing(
                guard=self._exit_zeno_guard,
                reset_map=self._exit_zeno,
                name="exit_zeno",
                direction="positive_then_non_positive",
            )

            # Optional: reset value defined by external signal
            if enable_external_reset:
                self.reset_value_index = self.declare_input_port(name="reset_value")

        if enable_hold:
            # Boolean input for triggering hold assert/deassert
            self.hold_trigger_index = self.declare_input_port(name="hold_trigger")

            def _hold_guard(_time, _state, *inputs, **_params):
                trigger = inputs[self.hold_trigger_index]
                return cnp.where(trigger, 1.0, -1.0)

            self.declare_zero_crossing(
                guard=_hold_guard,
                name="hold",
                direction="crosses_zero",
            )

        self._output_port_idx = self.declare_output_port(name="out_0")

    def initialize(
        self,
        initial_state,
        enable_reset=False,
        enable_limits=False,
        lower_limit=None,
        upper_limit=None,
        enable_hold=False,
        enable_external_reset=False,
        zeno_tolerance=1e-6,
        reset_on_enter_zeno=False,
    ):
        # These flags determine the port layout created in __init__, so they
        # cannot change afterwards.
        if self.enable_reset != enable_reset:
            raise ValueError("enable_reset cannot be changed after initialization")
        if self.enable_external_reset != enable_external_reset:
            raise ValueError(
                "enable_external_reset cannot be changed after initialization"
            )
        if self.enable_hold != enable_hold:
            raise ValueError("enable_hold cannot be changed after initialization")

        # Default initial condition unless modified in context
        x0 = cnp.array(initial_state, dtype=self.dtype)
        self.dtype = self.dtype if self.dtype is not None else x0.dtype

        self.configure_continuous_state(
            self._continuous_state_idx,
            default_value=x0,
            ode=self._ode,
            prerequisites_of_calc=[self.input_ports[self.xdot_index].ticket],
        )

        self.reset_on_enter_zeno = reset_on_enter_zeno

        self.enable_limits = enable_limits
        self.has_lower_limit = lower_limit is not None
        self.has_upper_limit = upper_limit is not None

        self.configure_output_port(
            self._output_port_idx,
            self._output,
            prerequisites_of_calc=[DependencyTicket.xc],
            requires_inputs=False,
        )

        if enable_limits:
            if lower_limit is not None:

                def _lower_limit_guard(_time, state, *_inputs, **params):
                    return state.continuous_state - params["lower_limit"]

                self.declare_zero_crossing(
                    guard=_lower_limit_guard,
                    name="lower_limit",
                    direction="positive_then_non_positive",
                )

            if upper_limit is not None:

                def _upper_limit_guard(_time, state, *_inputs, **params):
                    return state.continuous_state - params["upper_limit"]

                self.declare_zero_crossing(
                    guard=_upper_limit_guard,
                    name="upper_limit",
                    direction="negative_then_non_negative",
                )

    def reset_default_values(self, **dynamic_parameters):
        x0 = cnp.array(dynamic_parameters["initial_state"], dtype=self.dtype)
        self.configure_continuous_state_default_value(
            self._continuous_state_idx,
            default_value=x0,
        )

    def _ode(self, _time, state, *inputs, **params):
        # Normally, just integrate the input signal
        xdot = inputs[self.xdot_index]

        # However, if the reset trigger is high or the integrator is in the Zeno state,
        # then the integrator should hold
        if self.enable_reset:
            trigger = inputs[self.reset_trigger_index]
            in_zeno_state = state.discrete_state.zeno
            xdot = cnp.where((trigger | in_zeno_state), cnp.zeros_like(xdot), xdot)

        # Additionally, if the limits are enabled, the derivative is set to zero if
        # either limit is presently violated.
        if self.enable_limits:
            xc = state.continuous_state

            if self.has_lower_limit:
                llim_violation = cnp.logical_and(
                    xdot < 0.0, xc <= params["lower_limit"]
                )
            else:
                llim_violation = False

            if self.has_upper_limit:
                ulim_violation = cnp.logical_and(
                    xdot > 0.0, xc >= params["upper_limit"]
                )
            else:
                ulim_violation = False

            xdot = cnp.where(
                (llim_violation | ulim_violation), cnp.zeros_like(xdot), xdot
            )

        if self.enable_hold:
            hold = inputs[self.hold_trigger_index]
            xdot = cnp.where(hold, cnp.zeros_like(xdot), xdot)

        return xdot

    def _output(self, _time, state, *_inputs, **params):
        xc = state.continuous_state
        if self.enable_limits:
            # A disabled limit is represented by an infinite bound.
            lower_limit = params["lower_limit"] if self.has_lower_limit else -cnp.inf
            upper_limit = params["upper_limit"] if self.has_upper_limit else cnp.inf
            return cnp.clip(xc, lower_limit, upper_limit)

        return xc

    def _reset_guard(self, _time, _state, *inputs, **_params):
        trigger = inputs[self.reset_trigger_index]
        return cnp.where(trigger, 1.0, -1.0)

    def _reset(self, time, state, *inputs, **params):
        # If the distance between events is less than the tolerance, then enter the Zeno state.
        dt = time - state.discrete_state.tprev
        zeno = (dt - params["zeno_tolerance"]) <= 0
        tprev = time

        # Handle the reset event as usual
        if self.enable_external_reset:
            xc = inputs[self.reset_value_index]
        else:
            xc = cnp.array(params["initial_state"], dtype=self.dtype)

        # Don't reset if entering Zeno state
        new_continuous_state = cnp.where(
            zeno & (not self.reset_on_enter_zeno),
            state.continuous_state,
            xc,
        )
        state = state.with_continuous_state(new_continuous_state)

        # Count number of resets (for debugging)
        counter = state.discrete_state.counter + 1

        # Update the discrete state
        xd_plus = self.discrete_state_type(zeno=zeno, counter=counter, tprev=tprev)
        state = state.with_discrete_state(xd_plus)

        logger.debug("Resetting to %s", state)
        return state

    def _exit_zeno_guard(self, _time, _state, *inputs, **_params):
        # This will only be active when in the Zeno state.  It monitors the boolean trigger input
        # and will go from 1.0 (when trigger=True) to 0.0 (when trigger=False)
        trigger = inputs[self.reset_trigger_index]
        return cnp.array(trigger, dtype=self.dtype)

    def _exit_zeno(self, _time, state, *_inputs, **_params):
        xd = state.discrete_state._replace(zeno=False)
        return state.with_discrete_state(xd)

    def determine_active_guards(self, root_context):
        # TODO: Update this to use the new zero crossing event system
        # defined in LeafSystem.
        zero_crossing_events = self.zero_crossing_events.mark_all_active()

        if not self.enable_reset:
            return zero_crossing_events

        # The reset-related zero crossings are declared first in __init__, in
        # the order "reset_on" (0), "reset_off" (1), "exit_zeno" (2); hold and
        # limit crossings come after them.  This assumes no zero crossings are
        # declared before these (the index-0 lookup below relies on the same).
        def _get_reset(events: LeafEventCollection):
            return events.events[0]

        context = root_context[self.system_id]
        in_zeno_state = context.discrete_state.zeno

        # While in the Zeno state, disable the "reset_on" guard...
        reset = cond(
            in_zeno_state,
            lambda e: e.mark_inactive(),
            lambda e: e.mark_active(),
            _get_reset(zero_crossing_events),
        )

        def _get_exit_zeno(events: LeafEventCollection):
            # Index 2 is "exit_zeno"; index 1 is "reset_off", which has no
            # reset map and must not be the one gated by the Zeno state.
            return events.events[2]

        # ...and enable the "exit_zeno" guard only while in the Zeno state.
        exit_zeno: ZeroCrossingEvent = cond(
            in_zeno_state,
            lambda e: e.mark_active(),
            lambda e: e.mark_inactive(),
            _get_exit_zeno(zero_crossing_events),
        )

        zero_crossing_events = eqx.tree_at(_get_reset, zero_crossing_events, reset)
        zero_crossing_events = eqx.tree_at(
            _get_exit_zeno, zero_crossing_events, exit_zeno
        )

        return zero_crossing_events

    def check_types(
        self,
        context,
        error_collector: ErrorCollector = None,
    ):
        u = self.eval_input(context)
        xc = context[self.system_id].continuous_state
        check_state_type(
            self,
            inp_data=u,
            state_data=xc,
            error_collector=error_collector,
        )

IntegratorDiscrete

Bases: LeafSystem

Discrete first-order integrator.

This block is a discrete-time approximation to the behavior of the Integrator block. It implements the following linear time-invariant difference equation for input values u and output values y:

    x[k+1] = x[k] + dt * u[k]
    y[k] = x[k]

where x is the state of the integrator. The integrator is initialized with the value of the initial_state parameter.

Options

Reset: the integrator can be configured to reset its state on an input trigger. The reset value can be either the initial state of the integrator or an external value provided by an input port. Limits: the integrator can be configured such that the output and state are constrained by upper and lower limits. Hold: the integrator can be configured to hold integration based on an input trigger.

Unlike the continuous-time integrator, the discrete integrator does not detect Zeno behavior, since this is not a concern in discrete-time systems.

Input ports

(0) The input signal. Must match the shape and dtype of the initial state. (1) The reset trigger. Optional, only if enable_reset is True. (2) The reset value. Optional, only if enable_external_reset is True. (3) The hold trigger. Optional, only if 'enable_hold' is True.

Output ports

(0) The current state of the integrator.

Parameters:

Name Type Description Default
initial_state

The initial value of the integrator state. Can be any array, or even a nested structure of arrays, but the data type should be floating-point.

required
enable_reset

If True, the integrator will reset its state to the initial value when the reset trigger is True. Adds an additional input port for the reset trigger. This signal should be boolean- or binary-valued.

False
enable_external_reset

If True, the integrator will reset its state to the value provided by the reset value input port when the reset trigger is True. Otherwise, the integrator will reset to the initial value. Adds an additional input port for the reset value. This signal should match the shape and dtype of the initial continuous state.

False
enable_limits

If True, the integrator will constrain its state and output to within the upper and lower limits. Either limit may be disabled by setting its value to None.

False
enable_hold

If True, the integrator will hold integration when the hold trigger is True.

False
Source code in collimator/library/primitives.py
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
class IntegratorDiscrete(LeafSystem):
    """Discrete first-order integrator.

    This block is a discrete-time approximation to the behavior of the Integrator
    block.  It implements the following linear time-invariant difference equation
    for input values `u` and output values `y`:
    ```
        x[k+1] = x[k] + dt * u[k]
        y[k] = x[k]
    ```
    where `x` is the state of the integrator.  The integrator is initialized with
    the value of the `initial_state` parameter.

    Options:
        Reset: the integrator can be configured to reset its state on an input
            trigger.  The reset value can be either the initial state of the
            integrator or an external value provided by an input port.
        Limits: the integrator can be configured such that the output and state
            are constrained by upper and lower limits.
        Hold: the integrator can be configured to hold integration based on an
            input trigger.

    Reset and limits are applied to both the state update and the output, so the
    block is feedthrough with respect to the reset ports when reset is enabled.
    Hold is applied to the state update only.

    Unlike the continuous-time integrator, the discrete integrator does not detect
    Zeno behavior, since this is not a concern in discrete-time systems.

    Input ports:
        (0) The input signal.  Must match the shape and dtype of the initial
            state.
        (1) The reset trigger.  Optional, only if `enable_reset` is True.
        (2) The reset value.  Optional, only if `enable_external_reset` is True.
        (3) The hold trigger. Optional, only if `enable_hold` is True.

    Output ports:
        (0) The current state of the integrator.

    Parameters:
        dt:
            The sample period of the integrator.  It is used both as the period
            of the state update and as the step size in `x + dt * u`.
        initial_state:
            The initial value of the integrator state.  Can be any array, or even
            a nested structure of arrays, but the data type should be floating-point.
        enable_reset:
            If True, the integrator will reset its state to the initial value
            when the reset trigger is True.  Adds an additional input port for
            the reset trigger.  This signal should be boolean- or binary-valued.
        enable_external_reset:
            If True, the integrator will reset its state to the value provided
            by the reset value input port when the reset trigger is True. Otherwise,
            the integrator will reset to the initial value.  Adds an additional
            input port for the reset value.  This signal should match the shape
            and dtype of the initial continuous state.
        enable_limits:
            If True, the integrator will constrain its state and output to within
            the upper and lower limits. Either limit may be disabled by setting its
            value to None.
        lower_limit:
            Lower bound used when `enable_limits` is True; None means unbounded.
        upper_limit:
            Upper bound used when `enable_limits` is True; None means unbounded.
        enable_hold:
            If True, the integrator will hold integration when the hold trigger is
            True.
        dtype:
            Optional dtype of the state.  If None, it is inferred from
            `initial_state`.
    """

    @parameters(
        static=[
            "enable_reset",
            "enable_external_reset",
            "enable_limits",
            "enable_hold",
        ],
        dynamic=["lower_limit", "upper_limit", "initial_state"],
    )
    def __init__(
        self,
        dt,
        initial_state,
        enable_reset=False,
        enable_hold=False,
        enable_limits=False,
        lower_limit=None,
        upper_limit=None,
        enable_external_reset=False,
        dtype=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dt = dt
        self.dtype = dtype
        self.enable_reset = enable_reset
        self.enable_external_reset = enable_external_reset
        self.enable_hold = enable_hold

        # Declaration order fixes the input port indices:
        # in_0, then optionally reset_trigger / reset_value, then hold_trigger.
        self.xdot_index = self.declare_input_port(name="in_0")
        self._periodic_update_idx = self.declare_periodic_update()

        if enable_reset:
            self.reset_trigger_index = self.declare_input_port(name="reset_trigger")
            if enable_external_reset:
                self.reset_value_index = self.declare_input_port(name="reset_value")

        if enable_hold:
            self.hold_trigger_index = self.declare_input_port(name="hold_trigger")

        self.state_output_index = self.declare_output_port(name="out_0")

    def initialize(
        self,
        initial_state,
        enable_reset=False,
        enable_hold=False,
        enable_limits=False,
        lower_limit=None,
        upper_limit=None,
        enable_external_reset=False,
    ):
        # The port layout was created in __init__, so these flags are frozen.
        frozen_flags = (
            (
                self.enable_reset,
                enable_reset,
                "enable_reset cannot be changed after initialization",
            ),
            (
                self.enable_external_reset,
                enable_external_reset,
                "enable_external_reset cannot be changed after initialization",
            ),
            (
                self.enable_hold,
                enable_hold,
                "enable_hold cannot be changed after initialization",
            ),
        )
        for current, requested, message in frozen_flags:
            if current != requested:
                raise ValueError(message)

        # Default initial condition unless modified in context
        x0 = cnp.array(initial_state, dtype=self.dtype)
        if self.dtype is None:
            self.dtype = x0.dtype
        self.declare_discrete_state(default_value=x0)
        self.configure_periodic_update(
            self._periodic_update_idx, self._update, period=self.dt, offset=0.0
        )

        # The output applies the reset directly, which makes the block
        # feedthrough with respect to the reset input ports.
        self.is_feedthrough = enable_reset

        self.enable_limits = enable_limits
        self.has_lower_limit = lower_limit is not None
        self.has_upper_limit = upper_limit is not None

        output_prereqs = [DependencyTicket.xd]
        if enable_reset:
            reset_ports = [self.reset_trigger_index]
            if enable_external_reset:
                reset_ports.append(self.reset_value_index)
            output_prereqs.extend(self.input_ports[i].ticket for i in reset_ports)

        self.configure_output_port(
            self.state_output_index,
            self._output,
            period=self.dt,
            offset=0.0,
            default_value=x0,
            prerequisites_of_calc=output_prereqs,
        )

    def reset_default_values(self, **dynamic_parameters):
        default_state = cnp.array(
            dynamic_parameters["initial_state"], dtype=self.dtype
        )
        self.configure_discrete_state_default_value(default_value=default_state)
        self.configure_output_port_default_value(
            self.state_output_index, default_state
        )

    def _reset(self, *inputs, **params):
        # Reset target: the initial state, unless an external value is wired in.
        if not self.enable_external_reset:
            return cnp.array(params["initial_state"], dtype=self.dtype)
        return inputs[self.reset_value_index]

    def _apply_reset_and_limits(self, x_new, *inputs, **params):
        # Shared by the update and the output so both respond to resets and
        # limits at the same instant.
        value = x_new

        if self.enable_reset:
            reset_active = inputs[self.reset_trigger_index]
            value = cnp.where(reset_active, self._reset(*inputs, **params), value)

        if self.enable_limits:
            lo = params["lower_limit"] if self.has_lower_limit else -cnp.inf
            hi = params["upper_limit"] if self.has_upper_limit else cnp.inf
            value = cnp.clip(value, lo, hi)

        return value

    def _apply_hold(self, x, x_new, *inputs, **_params):
        # Hold affects only the state update, never the output.
        if not self.enable_hold:
            return x_new
        # While the hold trigger is high, keep the previous state `x`.
        return cnp.where(inputs[self.hold_trigger_index], x, x_new)

    def _update(self, _time, state, *inputs, **params):
        x_prev = state.discrete_state
        candidate = x_prev + self.dt * inputs[self.xdot_index]
        candidate = self._apply_hold(x_prev, candidate, *inputs, **params)
        candidate = self._apply_reset_and_limits(candidate, *inputs, **params)
        return candidate.astype(x_prev.dtype)

    def _output(self, _time, state, *inputs, **params):
        # Applying reset/limits here too keeps output discontinuities aligned
        # with the input signal (hence the feedthrough behavior).
        return self._apply_reset_and_limits(state.discrete_state, *inputs, **params)

    def check_types(
        self,
        context,
        error_collector: ErrorCollector = None,
    ):
        check_state_type(
            self,
            inp_data=self.eval_input(context),
            state_data=context[self.system_id].discrete_state,
            error_collector=error_collector,
        )

KalmanFilter

Bases: KalmanFilterBase

Kalman Filter for the following system:

x[n+1] = A x[n] + B u[n] + G w[n]
y[n]   = C x[n] + D u[n] + v[n]

E(w[n]) = E(v[n]) = 0
E(w[n]w'[n]) = Q
E(v[n]v'[n]) = R
E(w[n]v'[n]) = N = 0
Input ports

(0) u[n] : control vector at timestep n (1) y[n] : measurement vector at timestep n

Output ports

(1) x_hat[n] : state vector estimate at timestep n

Parameters:

Name Type Description Default
dt

float Time step of the discrete-time system

required
A

ndarray State transition matrix

required
B

ndarray Input matrix

required
C

ndarray Output matrix. If None, full state output is assumed.

None
D

ndarray Feedthrough matrix. If None, no feedthrough is assumed.

None
G

ndarray Process noise matrix. If None, G=B is assumed.

None
Q

ndarray Process noise covariance matrix. If None, Identity matrix of size compatible with G and A is assumed.

None
R

ndarray Measurement noise covariance matrix. If None, Identity matrix of size compatible with C and A is assumed.

None
x_hat_0

ndarray Initial state estimate. If None, an array of zeros is assumed.

None
P_hat_0

ndarray Initial state covariance matrix estimate. If None, Identity matrix of size identical to A is assumed.

None
Source code in collimator/library/state_estimators/kalman_filter.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
class KalmanFilter(KalmanFilterBase):
    """
    Kalman Filter for the following system:

    ```
    x[n+1] = A x[n] + B u[n] + G w[n]
    y[n]   = C x[n] + D u[n] + v[n]

    E(w[n]) = E(v[n]) = 0
    E(w[n]w'[n]) = Q
    E(v[n]v'[n]) = R
    E(w[n]v'[n]) = N = 0
    ```

    Input ports:
        (0) u[n] : control vector at timestep n
        (1) y[n] : measurement vector at timestep n

    Output ports:
        (1) x_hat[n] : state vector estimate at timestep n

    Parameters:
        dt: float
            Time step of the discrete-time system
        A: ndarray
            State transition matrix
        B: ndarray
            Input matrix
        C: ndarray
            Output matrix. If `None`, full state output is assumed.
        D: ndarray
            Feedthrough matrix. If `None`, no feedthrough is assumed.
        G: ndarray
            Process noise matrix. If `None`, `G=B` is assumed.
        Q: ndarray
            Process noise covariance matrix. If `None`, Identity matrix of size
            compatible with `G` and `A` is assumed.
        R: ndarray
            Measurement noise covariance matrix. If `None`, Identity matrix of size
            compatible with `C` and `A` is assumed.
        x_hat_0: ndarray
            Initial state estimate. If `None`, an array of zeros is assumed.
        P_hat_0: ndarray
            Initial state covariance matrix estimate. If `None`, Identity matrix of size
            identical to `A` is assumed.
    """

    @parameters(
        static=["dt", "A", "B", "C", "D", "G", "Q", "R", "x_hat_0", "P_hat_0"],
    )
    def __init__(
        self,
        dt,
        A,
        B,
        C=None,
        D=None,
        G=None,
        Q=None,
        R=None,
        x_hat_0=None,
        P_hat_0=None,
        name=None,
        **kwargs,
    ):
        # The measurement depends directly on u only when D has nonzero entries.
        is_feedthrough = False if D is None else bool(not cnp.allclose(D, 0.0))
        super().__init__(dt, x_hat_0, P_hat_0, is_feedthrough, name, **kwargs)

    def initialize(
        self,
        dt,
        A,
        B,
        C=None,
        D=None,
        G=None,
        Q=None,
        R=None,
        x_hat_0=None,
        P_hat_0=None,
    ):
        """Resolve default matrices, validate shapes, and cache the system model.

        Dimensions are taken from `B` (nx x nu), `C` (ny x nx), and `G`
        (nx x nd). For the defaults applied to `None` arguments, see the class
        docstring.
        """
        self.nx, self.nu = B.shape

        if C is None:
            C = jnp.eye(self.nx)
            self.ny = self.nx
        else:
            self.ny = C.shape[0]

        if D is None:
            D = jnp.zeros((self.ny, self.nu))
        self.is_feedthrough = bool(not cnp.allclose(D, 0.0))

        if G is None:
            G = B

        _, self.nd = G.shape

        if Q is None:
            Q = jnp.eye(self.nd)

        if R is None:
            R = jnp.eye(self.ny)

        if x_hat_0 is None:
            x_hat_0 = jnp.zeros(self.nx)

        if P_hat_0 is None:
            P_hat_0 = jnp.eye(self.nx)

        check_shape_compatibilities(A, B, C, D, G, Q, R)

        self.A = A
        self.B = B
        self.C = C
        self.D = D
        self.G = G
        self.Q = Q
        self.R = R

        self.eye_x = jnp.eye(self.nx)
        # Process-noise contribution to the covariance prediction. It is
        # constant, so it is computed once here rather than at every step.
        self.GQGT = G @ Q @ G.T

    def _correct(self, time, x_hat_minus, P_hat_minus, *inputs):
        """Measurement update: incorporate y[n] into the prior estimate.

        Args:
            time: Current time (unused).
            x_hat_minus: Prior state estimate x_hat[n|n-1].
            P_hat_minus: Prior state covariance P_hat[n|n-1].
            *inputs: Input port values `(u, y)`.

        Returns:
            Tuple `(x_hat_plus, P_hat_plus)`: the posterior estimate x_hat[n|n]
            and covariance P_hat[n|n].
        """
        u, y = inputs
        y = jnp.atleast_1d(y)

        C, D = self.C, self.D

        # Kalman gain K = P C^T S^{-1}, where S = C P C^T + R is the innovation
        # covariance. K is obtained by solving S^T K^T = (P C^T)^T rather than
        # forming S^{-1} explicitly, which is cheaper and better conditioned.
        PCT = P_hat_minus @ C.T
        S = C @ PCT + self.R
        K = jnp.linalg.solve(S.T, PCT.T).T

        x_hat_plus = x_hat_minus + jnp.dot(K, y - jnp.dot(C, x_hat_minus))  # n|n

        if self.is_feedthrough:
            # The predicted measurement also includes D u, so remove its
            # contribution from the correction term.
            u = cnp.atleast_1d(u)
            x_hat_plus = x_hat_plus - cnp.dot(K, cnp.dot(D, u))

        P_hat_plus = jnp.matmul(self.eye_x - jnp.matmul(K, C), P_hat_minus)  # n|n

        return x_hat_plus, P_hat_plus

    def _propagate(self, time, x_hat_plus, P_hat_plus, *inputs):
        """Time update: predict the estimate and covariance for the next step.

        Args:
            time: Current time (unused).
            x_hat_plus: Posterior state estimate x_hat[n|n].
            P_hat_plus: Posterior state covariance P_hat[n|n].
            *inputs: Input port values `(u, y)`. Only `u` is used.

        Returns:
            Tuple `(x_hat_minus, P_hat_minus)` for timestep n+1, i.e.
            x_hat[n+1|n] and P_hat[n+1|n].
        """
        # Predict -- x_hat_plus of current step is propagated to be the
        # x_hat_minus of the next step
        # n+1|n in current step is n|n-1 for next step

        u, y = inputs
        u = jnp.atleast_1d(u)

        A, B = self.A, self.B

        x_hat_minus = jnp.dot(A, x_hat_plus) + jnp.dot(B, u)  # n+1|n
        P_hat_minus = A @ P_hat_plus @ A.T + self.GQGT  # n+1|n

        return x_hat_minus, P_hat_minus

    #######################################
    # Make filter for a continuous plant  #
    #######################################

    @staticmethod
    @with_resolved_parameters
    def for_continuous_plant(
        plant,
        x_eq,
        u_eq,
        dt,
        Q=None,
        R=None,
        G=None,
        x_hat_bar_0=None,
        P_hat_bar_0=None,
        discretization_method="euler",
        discretized_noise=False,
        name=None,
        ui_id=None,
    ):
        """
        Obtain a Kalman Filter system for a continuous-time plant after linearization
        at equilibrium point (x_eq, u_eq)

        The input plant contains the deterministic forms of the forward and observation
        operators:

        ```
            dx/dt = f(x,u)
            y = g(x,u)
        ```

        Note: Only plants with one vector-valued input and one vector-valued output
        are currently supported. Furthermore, the plant LeafSystem/Diagram should have
        only one vector-valued integrator.

        A plant with disturbances of the following form is then considered:

        ```
            dx/dt = f(x,u) + G w                        --- (C1)
            y = g(x,u) +  v                             --- (C2)
        ```

        where:

            `w` represents the process noise,
            `v` represents the measurement noise,

        and

        ```
            E(w) = E(v) = 0
            E(ww') = Q
            E(vv') = R
            E(wv') = N = 0
        ```

        This plant with disturbances is linearized (only `f` and `g`) around the
        equilibrium point to obtain:

        ```
            d/dt (x_bar) = A x_bar + B u_bar + G w
            y_bar = C x_bar + D u_bar + v
        ```

        where,

        ```
            x_bar = x - x_eq
            u_bar = u - u_eq
            y_bar = y - y_eq
            y_eq = g(x_eq, u_eq)
        ```

        The linearized plant is then discretized via `euler` or `zoh` method to obtain:

        ```
            x_bar[n+1] = Ad x_bar[n] + Bd u_bar[n] + Gd w[n]         --- (L1)
            y_bar[n] = Cd x_bar[n] + Dd u_bar[n] + v[n]              --- (L2)

            E(w[n]) = E(v[n]) = 0
            E(w[n]w'[n]) = Qd
            E(v[n]v'[n]) = Rd
            E(w[n]v'[n]) = Nd = 0
        ```

        Note: If `discretized_noise` is True, then it is assumed that the user is
        providing Gd, Qd and Rd. If False, then Qd and Rd are computed from
        continuous-time Q, R, and G, and Gd is set to Identity matrix.

        A Kalman Filter estimator for the system of equations (L1) and (L2) is
        created and returned. This filter operates on the `x_bar`, `u_bar`, and
        `y_bar` variables, i.e. relative to the equilibrium point.

        This returned system will have

        Input ports:
            (0) u_bar[n] : control vector at timestep n, relative to equilibrium
            (1) y_bar[n] : measurement vector at timestep n, relative to equilibrium

        Output ports:
            (1) x_hat_bar[n] : state vector estimate at timestep n, relative to
                               equilibrium

        Parameters:
            plant : a `Plant` object which can be a LeafSystem or a Diagram.
            x_eq: ndarray
                Equilibrium state vector for discretization
            u_eq: ndarray
                Equilibrium control vector for discretization
            dt: float
                Time step for the discretization.
            Q: ndarray
                Process noise covariance matrix. If `None`, Identity matrix of size
                compatible with `G` and the linearized system's `A` is assumed.
            R: ndarray
                Measurement noise covariance matrix. If `None`, Identity matrix of size
                compatible with linearized system's `C` and `A` is assumed.
            G: ndarray
                Process noise matrix. If `None`, `G=B` is assumed making disturbances
                additive to control vector `u`, i.e. `u_disturbed = u_orig + w`.
            x_hat_bar_0: ndarray
                Initial state estimate, relative to equilibrium.
                If `None`, an array of zeros is assumed.
            P_hat_bar_0: ndarray
                Initial covariance matrix estimate for state, relative to equilibrium.
                If `None`, an Identity matrix is assumed.
            discretization_method: str ("euler" or "zoh")
                Method to discretize the continuous-time plant. Default is "euler".
            discretized_noise: bool
                Whether the user is directly providing Gd, Qd and Rd. Default is False.
                If True, `G`, `Q`, and `R` are assumed to be Gd, Qd, and Rd,
                respectively.

        Returns:
            Tuple `(y_eq, kf)`, where `y_eq` is the plant output at the
            equilibrium point and `kf` is the local `KalmanFilter`.
        """
        (
            y_eq,
            Ad,
            Bd,
            Cd,
            Dd,
            Gd,
            Qd,
            Rd,
        ) = linearize_and_discretize_continuous_plant(
            plant, x_eq, u_eq, dt, Q, R, G, discretization_method, discretized_noise
        )

        check_shape_compatibilities(Ad, Bd, Cd, Dd, Gd, Qd, Rd)

        nx = x_eq.size

        if x_hat_bar_0 is None:
            x_hat_bar_0 = jnp.zeros(nx)

        if P_hat_bar_0 is None:
            P_hat_bar_0 = jnp.eye(nx)

        # Instantiate a Kalman Filter for the linearized plant
        kf = KalmanFilter(
            dt,
            Ad,
            Bd,
            Cd,
            Dd,
            Gd,
            Qd,
            Rd,
            x_hat_bar_0,
            P_hat_bar_0,
            name=name,
            ui_id=ui_id,
        )

        return y_eq, kf

    ##############################################
    # Make global filter for a continuous plant  #
    ##############################################

    @staticmethod
    @with_resolved_parameters
    def global_filter_for_continuous_plant(
        plant,
        x_eq,
        u_eq,
        dt,
        Q=None,
        R=None,
        G=None,
        x_hat_0=None,
        P_hat_0=None,
        discretization_method="euler",
        discretized_noise=False,
        name=None,
        ui_id=None,
    ):
        """
        Obtain a global Kalman Filter estimator for a continuous-time plant.

        See docs for `for_continuous_plant`, which returns the local Kalman
        Filter. This method additionally converts the local Kalman Filter to a
        global estimator. See docs for `make_global_estimator_from_local` for details.

        Unlike `for_continuous_plant`, the initial state estimate `x_hat_0` is
        given in absolute coordinates. It is shifted by `x_eq` before being
        passed to the local filter (zeros if `None`). `P_hat_0` is passed
        through unchanged (Identity if `None`).

        Returns:
            The estimator returned by `make_global_estimator_from_local`.
        """
        (
            y_eq,
            Ad,
            Bd,
            Cd,
            Dd,
            Gd,
            Qd,
            Rd,
        ) = linearize_and_discretize_continuous_plant(
            plant, x_eq, u_eq, dt, Q, R, G, discretization_method, discretized_noise
        )

        check_shape_compatibilities(Ad, Bd, Cd, Dd, Gd, Qd, Rd)

        nx = x_eq.size

        if x_hat_0 is None:
            x_hat_bar_0 = jnp.zeros(nx)
        else:
            x_hat_bar_0 = x_hat_0 - x_eq

        if P_hat_0 is None:
            P_hat_bar_0 = jnp.eye(nx)
        else:
            P_hat_bar_0 = P_hat_0

        # Instantiate a Kalman Filter for the linearized plant
        local_kf = KalmanFilter(
            dt,
            Ad,
            Bd,
            Cd,
            Dd,
            Gd,
            Qd,
            Rd,
            x_hat_bar_0,
            P_hat_bar_0,
            name=name + "_local" if name is not None else None,
        )

        global_kf = make_global_estimator_from_local(
            local_kf,
            x_eq,
            u_eq,
            y_eq,
            name=name,
            ui_id=ui_id,
        )

        return global_kf

for_continuous_plant(plant, x_eq, u_eq, dt, Q=None, R=None, G=None, x_hat_bar_0=None, P_hat_bar_0=None, discretization_method='euler', discretized_noise=False, name=None, ui_id=None) staticmethod

Obtain a Kalman Filter system for a continuous-time plant after linearization at equilibrium point (x_eq, u_eq)

The input plant contains the deterministic forms of the forward and observation operators:

    dx/dt = f(x,u)
    y = g(x,u)

Note: Only plants with one vector-valued input and one vector-valued output are currently supported. Furthermore, the plant LeafSystem/Diagram should have only one vector-valued integrator.

A plant with disturbances of the following form is then considered:

    dx/dt = f(x,u) + G w                        --- (C1)
    y = g(x,u) +  v                             --- (C2)

where:

`w` represents the process noise,
`v` represents the measurement noise,

and

    E(w) = E(v) = 0
    E(ww') = Q
    E(vv') = R
    E(wv') = N = 0

This plant with disturbances is linearized (only f and g) around the equilibrium point to obtain:

    d/dt (x_bar) = A x_bar + B u_bar + G w
    y_bar = C x_bar + D u_bar + v

where,

    x_bar = x - x_eq
    u_bar = u - u_eq
    y_bar = y - y_eq
    y_eq = g(x_eq, u_eq)

The linearized plant is then discretized via euler or zoh method to obtain:

    x_bar[n+1] = Ad x_bar[n] + Bd u_bar[n] + Gd w[n]         --- (L1)
    y_bar[n] = Cd x_bar[n] + Dd u_bar[n] + v[n]              --- (L2)

    E(w[n]) = E(v[n]) = 0
    E(w[n]w'[n]) = Qd
    E(v[n]v'[n]) = Rd
    E(w[n]v'[n]) = Nd = 0

Note: If discretized_noise is True, then it is assumed that the user is providing Gd, Qd and Rd. If False, then Qd and Rd are computed from continuous-time Q, R, and G, and Gd is set to Identity matrix.

A Kalman Filter estimator for the system of equations (L1) and (L2) is created and returned. This filter operates on the x_bar, u_bar, and y_bar variables, i.e. relative to the equilibrium point.

This returned system will have

Input ports

(0) u_bar[n] : control vector at timestep n, relative to equilibrium
(1) y_bar[n] : measurement vector at timestep n, relative to equilibrium

Output ports

(1) x_hat_bar[n] : state vector estimate at timestep n, relative to equilibrium

Parameters:

Name Type Description Default
plant

a Plant object which can be a LeafSystem or a Diagram.

required
x_eq

ndarray Equilibrium state vector for discretization

required
u_eq

ndarray Equilibrium control vector for discretization

required
dt

float Time step for the discretization.

required
Q

ndarray Process noise covariance matrix. If None, Identity matrix of size compatible with G and the linearized system's A is assumed.

None
R

ndarray Measurement noise covariance matrix. If None, Identity matrix of size compatible with linearized system's C and A is assumed.

None
G

ndarray Process noise matrix. If None, G=B is assumed making disturbances additive to control vector u, i.e. u_disturbed = u_orig + w.

None
x_hat_bar_0

ndarray Initial state estimate, relative to equilibrium. If None, an array of zeros is assumed.

None
P_hat_bar_0

ndarray Initial covariance matrix estimate for state, relative to equilibrium. If None, an Identity matrix is assumed.

None
discretization_method

str ("euler" or "zoh") Method to discretize the continuous-time plant. Default is "euler".

'euler'
discretized_noise

bool Whether the user is directly providing Gd, Qd and Rd. Default is False. If True, G, Q, and R are assumed to be Gd, Qd, and Rd, respectively.

False
Source code in collimator/library/state_estimators/kalman_filter.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
@staticmethod
@with_resolved_parameters
def for_continuous_plant(
    plant,
    x_eq,
    u_eq,
    dt,
    Q=None,
    R=None,
    G=None,
    x_hat_bar_0=None,
    P_hat_bar_0=None,
    discretization_method="euler",
    discretized_noise=False,
    name=None,
    ui_id=None,
):
    """
    Obtain a Kalman Filter system for a continuous-time plant after linearization
    at equilibrium point (x_eq, u_eq)

    The input plant contains the deterministic forms of the forward and observation
    operators:

    ```
        dx/dt = f(x,u)
        y = g(x,u)
    ```

    Note: Only plants with one vector-valued input and one vector-valued output
    are currently supported. Furthermore, the plant LeafSystem/Diagram should have
    only one vector-valued integrator.

    A plant with disturbances of the following form is then considered:

    ```
        dx/dt = f(x,u) + G w                        --- (C1)
        y = g(x,u) +  v                             --- (C2)
    ```

    where:

        `w` represents the process noise,
        `v` represents the measurement noise,

    and

    ```
        E(w) = E(v) = 0
        E(ww') = Q
        E(vv') = R
        E(wv') = N = 0
    ```

    This plant with disturbances is linearized (only `f` and `g`) around the
    equilibrium point to obtain:

    ```
        d/dt (x_bar) = A x_bar + B u_bar + G w
        y_bar = C x_bar + D u_bar + v
    ```

    where,

    ```
        x_bar = x - x_eq
        u_bar = u - u_eq
        y_bar = y - y_eq
        y_eq = g(x_eq, u_eq)
    ```

    The linearized plant is then discretized via `euler` or `zoh` method to obtain:

    ```
        x_bar[n+1] = Ad x_bar[n] + Bd u_bar[n] + Gd w[n]         --- (L1)
        y_bar[n] = Cd x_bar[n] + Dd u_bar[n] + v[n]              --- (L2)

        E(w[n]) = E(v[n]) = 0
        E(w[n]w'[n]) = Qd
        E(v[n]v'[n]) = Rd
        E(w[n]v'[n]) = Nd = 0
    ```

    Note: If `discretized_noise` is True, then it is assumed that the user is
    providing Gd, Qd and Rd. If False, then Qd and Rd are computed from
    continuous-time Q, R, and G, and Gd is set to Identity matrix.

    A Kalman Filter estimator for the system of equations (L1) and (L2) is
    created and returned. This filter operates on the `x_bar`, `u_bar`, and
    `y_bar` variables, i.e. relative to the equilibrium point.

    This returned system will have

    Input ports:
        (0) u_bar[n] : control vector at timestep n, relative to equilibrium
        (1) y_bar[n] : measurement vector at timestep n, relative to equilibrium

    Output ports:
        (1) x_hat_bar[n] : state vector estimate at timestep n, relative to
                           equilibrium

    Parameters:
        plant : a `Plant` object which can be a LeafSystem or a Diagram.
        x_eq: ndarray
            Equilibrium state vector for discretization
        u_eq: ndarray
            Equilibrium control vector for discretization
        dt: float
            Time step for the discretization.
        Q: ndarray
            Process noise covariance matrix. If `None`, Identity matrix of size
            compatible with `G` and the linearized system's `A` is assumed.
        R: ndarray
            Measurement noise covariance matrix. If `None`, Identity matrix of size
            compatible with linearized system's `C` and `A` is assumed.
        G: ndarray
            Process noise matrix. If `None`, `G=B` is assumed making disturbances
            additive to control vector `u`, i.e. `u_disturbed = u_orig + w`.
        x_hat_bar_0: ndarray
            Initial state estimate, relative to equilibrium.
            If `None`, an array of zeros is assumed.
        P_hat_bar_0: ndarray
            Initial covariance matrix estimate for state, relative to equilibrium.
            If `None`, an Identity matrix is assumed.
        discretization_method: str ("euler" or "zoh")
            Method to discretize the continuous-time plant. Default is "euler".
        discretized_noise: bool
            Whether the user is directly providing Gd, Qd and Rd. Default is False.
            If True, `G`, `Q`, and `R` are assumed to be Gd, Qd, and Rd,
            respectively.

    Returns:
        Tuple `(y_eq, kf)`, where `y_eq` is the plant output at the
        equilibrium point and `kf` is the local `KalmanFilter`.
    """
    (
        y_eq,
        Ad,
        Bd,
        Cd,
        Dd,
        Gd,
        Qd,
        Rd,
    ) = linearize_and_discretize_continuous_plant(
        plant, x_eq, u_eq, dt, Q, R, G, discretization_method, discretized_noise
    )

    check_shape_compatibilities(Ad, Bd, Cd, Dd, Gd, Qd, Rd)

    nx = x_eq.size

    if x_hat_bar_0 is None:
        x_hat_bar_0 = jnp.zeros(nx)

    if P_hat_bar_0 is None:
        P_hat_bar_0 = jnp.eye(nx)

    # Instantiate a Kalman Filter for the linearized plant
    kf = KalmanFilter(
        dt,
        Ad,
        Bd,
        Cd,
        Dd,
        Gd,
        Qd,
        Rd,
        x_hat_bar_0,
        P_hat_bar_0,
        name=name,
        ui_id=ui_id,
    )

    return y_eq, kf

global_filter_for_continuous_plant(plant, x_eq, u_eq, dt, Q=None, R=None, G=None, x_hat_0=None, P_hat_0=None, discretization_method='euler', discretized_noise=False, name=None, ui_id=None) staticmethod

See docs for for_continuous_plant, which returns the local Kalman Filter. This method additionally converts the local Kalman Filter to a global estimator. See docs for make_global_estimator_from_local for details.

Source code in collimator/library/state_estimators/kalman_filter.py
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
@staticmethod
@with_resolved_parameters
def global_filter_for_continuous_plant(
    plant,
    x_eq,
    u_eq,
    dt,
    Q=None,
    R=None,
    G=None,
    x_hat_0=None,
    P_hat_0=None,
    discretization_method="euler",
    discretized_noise=False,
    name=None,
    ui_id=None,
):
    """
    Build a global Kalman Filter estimator for a continuous-time plant.

    This linearizes and discretizes the plant exactly as `for_continuous_plant`
    does, wraps the result in a local `KalmanFilter`, and converts that local
    filter to a global estimator via `make_global_estimator_from_local` (see
    its docs for details).

    The initial estimate `x_hat_0` is in absolute coordinates. It is shifted
    by `x_eq` for the local filter (zeros when `None`). `P_hat_0` is used
    as-is (Identity when `None`).
    """
    y_eq, Ad, Bd, Cd, Dd, Gd, Qd, Rd = linearize_and_discretize_continuous_plant(
        plant, x_eq, u_eq, dt, Q, R, G, discretization_method, discretized_noise
    )
    discrete_model = (Ad, Bd, Cd, Dd, Gd, Qd, Rd)
    check_shape_compatibilities(*discrete_model)

    n_states = x_eq.size
    # The local filter works in deviation coordinates around the equilibrium.
    x_hat_bar_0 = jnp.zeros(n_states) if x_hat_0 is None else x_hat_0 - x_eq
    P_hat_bar_0 = jnp.eye(n_states) if P_hat_0 is None else P_hat_0

    local_name = None if name is None else name + "_local"
    local_kf = KalmanFilter(
        dt,
        *discrete_model,
        x_hat_bar_0,
        P_hat_bar_0,
        name=local_name,
    )

    return make_global_estimator_from_local(
        local_kf, x_eq, u_eq, y_eq, name=name, ui_id=ui_id
    )

LTISystem

Bases: LTISystemBase

Continuous-time linear time-invariant system.

Implements the following system of ODEs:

    ẋ = Ax + Bu
    y = Cx + Du
Input ports

(0) u: Input vector of size m

Output ports

(0) y: Output vector of size p. Note that this is feedthrough from the input port if and only if D is nonzero.

Parameters:

Name Type Description Default
A

State matrix of size n x n

required
B

Input matrix of size n x m

required
C

Output matrix of size p x n

required
D

Feedthrough matrix of size p x m

required
initialize_states

Initial state vector of size n (default: 0)

None
Source code in collimator/library/linear_system.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
class LTISystem(LTISystemBase):
    """Continuous-time linear time-invariant system.

    Implements the following system of ODEs:
    ```
        ẋ = Ax + Bu
        y = Cx + Du
    ```

    Input ports:
        (0) u: Input vector of size m

    Output ports:
        (0) y: Output vector of size p.  Note that this is feedthrough from the input
            port if and only if D is nonzero.

    Parameters:
        A: State matrix of size n x n
        B: Input matrix of size n x m
        C: Output matrix of size p x n
        D: Feedthrough matrix of size p x m
        initialize_states: Initial state vector of size n (default: 0)
    """

    @parameters(dynamic=["A", "B", "C", "D"], static=["initialize_states"])
    def __init__(self, A, B, C, D, initialize_states=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Declare the single output port (y) and the single continuous state
        # (x). Both are configured in `_init_state`.
        self._output_port_idx = self.declare_output_port(self._eval_output)
        self._continuous_state_id = self.declare_continuous_state()

    def _init_state(self, A, B, C, D, initialize_states=None):
        """Configure the output port and continuous state from the matrices."""
        super()._init_state(A, B, C, D, initialize_states)
        # Vector outputs default to zeros; a single output defaults to a scalar.
        output_default = cnp.zeros(self.p) if self.p > 1 else 0.0
        self.configure_output_port(
            self._output_port_idx,
            self._eval_output,
            default_value=output_default,
            requires_inputs=self.is_feedthrough,
        )
        self.configure_continuous_state(
            self._continuous_state_id,
            ode=self.ode,
            default_value=self.initialize_states,
        )

    def initialize(self, A, B, C, D, initialize_states=None, **kwargs):
        """Rebuild the state and push the resolved matrices into the parameters."""
        self._init_state(A, B, C, D, initialize_states)
        for key in ("A", "B", "C", "D"):
            self.parameters[key].set(getattr(self, key))

    def _eval_output(self, time, state, *inputs, **params):
        """Output callback: evaluate y = Cx (+ Du) using the current parameters."""
        C = params["C"]
        D = params["D"]
        # NOTE(review): the current parameter values are stored on the
        # instance, where `ss` reads them. If this callback runs under JAX
        # tracing, this may store traced values. Confirm that this is intended.
        self.C, self.D = C, D
        return self._eval_output_base(C, D, state, *inputs)

    def _eval_output_base(self, C, D, state, *inputs):
        """Compute y = C x, adding D u when the system is feedthrough."""
        x_vec = cnp.atleast_1d(state.continuous_state)
        y = cnp.matmul(C, x_vec)

        if self.is_feedthrough:
            (u,) = inputs
            y += cnp.matmul(D, cnp.atleast_1d(u))

        # A single scalar output is returned as a scalar, not a 1-element array.
        return cnp.atleast_1d(y)[0] if self.scalar_output else y

    def ode(self, time, state, u, **params):
        """Continuous-state derivative ẋ = A x + B u."""
        x = state.continuous_state
        # NOTE(review): same instance-attribute side effect as `_eval_output`.
        self.A, self.B = params["A"], params["B"]
        drift = cnp.matmul(self.A, cnp.atleast_1d(x))
        forcing = cnp.matmul(self.B, cnp.atleast_1d(u))
        return drift + forcing

    @property
    def ss(self):
        """State-space representation of the system."""
        return control.ss(self.A, self.B, self.C, self.D)

ss property

State-space representation of the system.

LTISystemDiscrete

Bases: LTISystemBase

Discrete-time linear time-invariant system.

Implements the following system of difference equations:

    x[k+1] = A x[k] + B u[k]
    y[k] = C x[k] + D u[k]
Input ports

(0) u[k]: Input vector of size m

Output ports

(0) y[k]: Output vector of size p. Note that this is feedthrough from the input port if and only if D is nonzero.

Parameters:

Name Type Description Default
A

State matrix of size n x n

required
B

Input matrix of size n x m

required
C

Output matrix of size p x n

required
D

Feedthrough matrix of size p x m

required
dt

Sampling period

required
initialize_states

Initial state vector of size n (default: 0)

None
Source code in collimator/library/linear_system.py
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
class LTISystemDiscrete(LTISystemBase):
    """Discrete-time linear time-invariant system.

    Implements the following system of difference equations:
    ```
        x[k+1] = A x[k] + B u[k]
        y[k] = C x[k] + D u[k]
    ```

    Input ports:
        (0) u[k]: Input vector of size m

    Output ports:
        (0) y[k]: Output vector of size p.  Note that this is feedthrough from the
                  input port if and only if D is nonzero.

    Parameters:
        A: State matrix of size n x n
        B: Input matrix of size n x m
        C: Output matrix of size p x n
        D: Feedthrough matrix of size p x m
        dt: Sampling period
        initialize_states: Initial state vector of size n (default: 0)
    """

    @parameters(dynamic=["A", "B", "C", "D"], static=["initialize_states"])
    def __init__(self, A, B, C, D, dt, initialize_states=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.dt = dt
        # State update x[k+1] = A x[k] + B u[k], evaluated once per period.
        self.declare_periodic_update(
            self._update,
            period=dt,
            offset=0.0,
        )

        # Single output port (y).  It is configured in `_init_state`, which has
        # access to `self.p` and `self.is_feedthrough`.
        self._output_port_idx = self.declare_output_port(self._eval_output)

    def _init_state(self, A, B, C, D, initialize_states=None):
        """Set up the matrices, the discrete state (x), and the output port."""
        super()._init_state(A, B, C, D, initialize_states)
        self.declare_discrete_state(
            default_value=self.initialize_states,
        )  # Single discrete state (x)
        self.configure_output_port(
            self._output_port_idx,
            self._eval_output,
            period=self.dt,
            offset=0.0,
            default_value=cnp.zeros(self.p) if self.p > 1 else 0.0,
            # The output only depends on the input when there is a D u term.
            requires_inputs=self.is_feedthrough,
        )

    def initialize(self, A, B, C, D, initialize_states=None, **kwargs):
        """(Re)initialize the block and write the matrices back to `self.parameters`."""
        self._init_state(A, B, C, D, initialize_states)
        self.parameters["A"].set(self.A)
        self.parameters["B"].set(self.B)
        self.parameters["C"].set(self.C)
        self.parameters["D"].set(self.D)

    def _eval_output(self, time, state, *inputs, **params):
        """Evaluate y[k] = C x[k] (+ D u[k] when the block is feedthrough)."""
        x = state.discrete_state
        # Keep the latest C, D on the instance (read by `ss`).
        self.C, self.D = params["C"], params["D"]
        y = cnp.matmul(self.C, cnp.atleast_1d(x))

        if self.is_feedthrough:
            (u,) = inputs
            # Not `+=`: an in-place add on NumPy arrays raises if D u has a
            # wider dtype than C x.
            y = y + cnp.matmul(self.D, cnp.atleast_1d(u))

        # Handle the special case of scalar output.  `atleast_1d` keeps the
        # indexing valid if `matmul` returned a 0-d value (e.g. for a 1-D C).
        if self.scalar_output:
            y = cnp.atleast_1d(y)[0]

        return y

    def _update(self, time, state, u, **params):
        """Periodic state update: return x[k+1] = A x[k] + B u[k]."""
        x = state.discrete_state
        # Keep the latest A, B on the instance (read by `ss`).
        self.A, self.B = params["A"], params["B"]
        Ax = cnp.matmul(self.A, cnp.atleast_1d(x))
        Bu = cnp.matmul(self.B, cnp.atleast_1d(u))
        return Ax + Bu

    @property
    def ss(self):
        """State-space representation of the system."""
        return control.ss(self.A, self.B, self.C, self.D, self.dt)

ss property

State-space representation of the system.

LinearDiscreteTimeMPC

Bases: LeafSystem

Model predictive control for a linear discrete-time system.

Notes

This block is feedthrough, meaning that every time the output port is evaluated, the solver is run. This is in order to avoid a "data flow" delay between the solver and the output port (see also the PIDDiscrete block). This means that either the input or output signal should be discrete-time in order for the block to work as intended. Ideally, the output signal should be passed to a zero-order hold block so that the solver only needs to be run once per step.

Source code in collimator/library/mpc.py
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
class LinearDiscreteTimeMPC(LeafSystem):
    """Model predictive control for a linear discrete-time system.

    Notes:
        This block is _feedthrough_, meaning that every time the output port is
        evaluated, the solver is run.  This is in order to avoid a "data flow" delay
        between the solver and the output port (see also the `PIDDiscrete` block).
        This means that either the input or output signal should be discrete-time in
        order for the block to work as intended.  Ideally, the output signal should
        be passed to a zero-order hold block so that the solver only needs to be run
        once per step.

        The control bounds ``lbu``/``ubu`` may be scalars, length-m vectors (the
        same bound at every horizon step) or length-N*m vectors (one bound per
        step and input, step-major).
    """

    def __init__(
        self,
        lin_sys,
        Q,
        R,
        N,
        dt,
        x_ref,
        lbu=-np.inf,
        ubu=np.inf,
        name=None,
        warm_start=False,
    ):
        """Build the MPC block.

        Args:
            lin_sys: System providing ``A`` and ``B``.  It is discretized here
                with a forward-Euler step of size ``dt``.
            Q: State cost matrix (n x n), applied at every horizon step.
            R: Input cost matrix (m x m), applied at every horizon step.
            N: Number of steps in the prediction horizon.
            dt: Discretization step and output update period.
            x_ref: Target state, imposed as an equality constraint on the last
                state of the horizon.
            lbu, ubu: Lower/upper bounds on the control input.
            name: Optional block name.
            warm_start: Not yet supported.  ``True`` raises ``NotImplementedError``
                when the solver runs.
        """
        super().__init__(name=name)
        # NOTE(review): presumably needed so lin_sys.A / lin_sys.B are available.
        lin_sys.create_context()
        self.n = lin_sys.A.shape[0]
        self.m = lin_sys.B.shape[1]
        self.N = N
        self.warm_start = warm_start

        # Convert to discrete time with Euler discretization
        A = jnp.eye(self.n) + dt * lin_sys.A
        B = dt * lin_sys.B

        # Input: current state (x0)
        self.declare_input_port()

        self.solve, _ = self._make_solver(A, B, Q, R, lbu, ubu, N, x_ref)

        # Declare a feedthrough output port for the solver
        self.declare_output_port(
            self.solve,
            requires_inputs=True,
            period=dt,
            offset=0.0,
        )

    def _make_solver(self, A, B, Q, R, lbu, ubu, N, xf):
        """Build the horizon QP and return a jitted solve function.

        The decision vector stacks ``[x[0], u[0], x[1], u[1], ..., x[N-1], u[N-1]]``,
        and the problem is posed in the ``lb <= L z <= ub`` form that
        ``jaxopt.BoxOSQP`` accepts.

        Returns:
            A tuple ``(solve, init_params)``.  ``solve`` is the jitted output
            callback.  ``init_params`` is an initial OSQP state, which is unused
            because warm starting is not supported.
        """
        from jax.experimental import sparse

        n = self.n
        m = self.m

        # Identity matrices of state and control dimension
        I_A = jnp.eye(n)
        I_B = jnp.eye(m)

        def e(k):
            """Unit vector in the kth direction"""
            return jnp.zeros(N).at[k].set(1.0)

        # Block-diagonal cost: Q on every state block, R on every input block.
        blocks = [Q, R] * N
        P = linalg.block_diag(*blocks)

        # The initial condition constraint is x[0] = x0
        L0 = jnp.eye(n, N * (n + m))

        # The defect constraint for step k is
        #    0 = (A * x[k] + B * u[k]) - x[k+1]
        L_defect = jnp.vstack(
            [
                jnp.kron(e(k), jnp.hstack([A, B]))
                + jnp.kron(e(k + 1), jnp.hstack([-I_A, 0 * B]))
                for k in range(N - 1)
            ]
        )

        # Constraint on terminal state
        Lf = jnp.kron(e(N - 1), jnp.hstack([I_A, 0 * B]))

        # Constraints on the control input (rows ordered step-major, then input)
        L_input = jnp.vstack(
            [jnp.kron(e(k), jnp.hstack([0 * B.T, I_B])) for k in range(N)]
        )

        # Stack the constraint matrices and define bounds
        #  lb <= Lx <= ub
        L = jnp.vstack([L0, L_defect, Lf, L_input])

        def _expand_input_bound(bound):
            """Broadcast a control bound to one entry per (step, input) pair."""
            bound = jnp.asarray(bound)
            if bound.size == N * m:
                # Already one value per (step, input): keep step-major order.
                return bound.reshape(N * m)
            # Scalar or per-input bound: repeat it for every horizon step.
            return jnp.broadcast_to(bound, (N, m)).reshape(N * m)

        lbu_full = _expand_input_bound(lbu)
        ubu_full = _expand_input_bound(ubu)

        def _get_bounds(x0):
            """Equality bounds on x[0], the defects and x[N-1]; box bounds on u."""
            defect_zeros = jnp.zeros(L_defect.shape[0])
            lb = jnp.hstack([x0, defect_zeros, xf, lbu_full])
            ub = jnp.hstack([x0, defect_zeros, xf, ubu_full])
            return lb, ub

        # The objective has no linear term.
        c = jnp.zeros(N * (n + m))

        P_sp = sparse.BCOO.fromdense(P)
        L_sp = sparse.BCOO.fromdense(L)

        @jax.jit
        def _matvec_Q(params_Q, x):
            """Objective matrix-vector product P @ x (params_Q is unused)."""
            return P_sp @ x

        @jax.jit
        def _matvec_A(params_A, x):
            """Constraint matrix-vector product L @ x (params_A is unused)."""
            return L_sp @ x

        self.qp = jaxopt.BoxOSQP(matvec_Q=_matvec_Q, matvec_A=_matvec_A)

        lb, ub = _get_bounds(xf)
        z0 = jnp.zeros(N * (n + m))
        init_params = self.qp.init_params(
            z0, params_obj=(None, c), params_eq=None, params_ineq=(lb, ub)
        )

        def _solve(time, state, x0):
            """Solve the horizon QP from state x0 and return the first control u[0]."""
            lb, ub = _get_bounds(x0)

            if self.warm_start:
                raise NotImplementedError(
                    "Warm start not yet supported for JAX MPC block"
                )

            osqp_params = self.qp.run(
                init_params=None,
                params_obj=(None, c),
                params_ineq=(lb, ub),
            ).params

            # Column k of the reshaped solution is [x[k]; u[k]].
            xu_traj = osqp_params.primal[0].reshape((n + m, N), order="F")

            # Return the first control value only
            return xu_traj[n:, 0]

        return jax.jit(_solve), init_params

LinearDiscreteTimeMPC_OSQP

Bases: LeafSystem

Same as LinearDiscreteTimeMPC, but using the OSQP solver. This is an example of a case where a traced array gets passed to a function that doesn't know how to handle it.

Source code in collimator/library/mpc.py
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
class LinearDiscreteTimeMPC_OSQP(LeafSystem):
    """Model predictive control for a linear discrete-time system, using OSQP.

    Same problem as `LinearDiscreteTimeMPC`, but solved with the non-JAX `osqp`
    package.  This is an example of a case where a traced array gets passed
    to a function that doesn't know how to handle it; the solver call is
    therefore wrapped in `jax.pure_callback`.

    Notes:
        The control bounds ``lbu``/``ubu`` may be scalars, length-m vectors (the
        same bound at every horizon step) or length-N*m vectors (one bound per
        step and input, step-major).
    """

    def __init__(
        self,
        lin_sys,
        Q,
        R,
        N,
        dt,
        x_ref,
        lbu=-np.inf,
        ubu=np.inf,
        name=None,
    ):
        super().__init__(name=name)
        # NOTE(review): presumably needed so lin_sys.A / lin_sys.B are available.
        lin_sys.create_context()
        self.n = lin_sys.A.shape[0]
        self.m = lin_sys.B.shape[1]
        self.N = N

        # Convert to discrete time with Euler discretization
        A = jnp.eye(self.n) + dt * lin_sys.A
        B = dt * lin_sys.B

        self._make_solver(A, B, Q, R, lbu, ubu, N, x_ref)

        # Input: current state (x0)
        self.declare_input_port()

        # Shape/dtype of the flat [x; u] trajectory returned by `solve`.
        self._result_template = jnp.zeros((self.n + self.m) * self.N)

        # Wrap the solve call as a JAX "pure callback" so that it can call
        # arbitrary non-JAX Python code (in this case OSQP).
        self._solve = partial(jax.pure_callback, self.solve, self._result_template)

        self.declare_output_port(
            self._output,
            period=dt,
            offset=0.0,
            requires_inputs=True,
        )

    def _extract_u_opt(self, xu_flat_traj):
        """Return the first control u[0] from the flat [x; u] trajectory."""
        xu_traj = np.reshape(xu_flat_traj, (self.n + self.m, self.N), order="F")

        # Split solution into states and controls
        u_opt = xu_traj[self.n :, :]

        # Return current best projected action
        return u_opt[:, 0]

    def _output(self, time, state, *inputs):
        """Output callback: solve the QP (unless `time` is inf) and return u[0]."""
        args = (time, state, *inputs)
        u_flat_traj = cond(jnp.isinf(time), self._dummy_solve, self._solve, *args)
        return self._extract_u_opt(u_flat_traj)

    def _dummy_solve(self, _time, _state, *_inputs, **_params):
        """Safeguard for reconstructing the results during ODE solver minor steps.

        `_output` routes here instead of calling OSQP when `time` is infinite,
        which avoids passing `inf` values to the external solver.  It returns
        an all-`inf` array of the same shape as a real solution.
        """
        return jnp.full(self._result_template.shape, jnp.inf)

    def solve(self, time, state, x0):
        """Run OSQP from initial state x0 and return the flat [x; u] trajectory."""
        # pylint: disable=not-callable
        lb, ub = self.get_bounds(x0)

        self.solver.update(l=np.array(lb), u=np.array(ub))

        # Solve problem
        sol = self.solver.solve()

        # `jax.pure_callback` expects the declared result dtype.  OSQP's NumPy
        # output may differ (e.g. float64 vs float32 when JAX x64 mode is off).
        return np.asarray(sol.x, dtype=self._result_template.dtype)

    def _make_solver(self, A, B, Q, R, lbu, ubu, N, xf):
        """Set up the OSQP problem for the horizon and a jitted bounds builder.

        The decision vector stacks ``[x[0], u[0], ..., x[N-1], u[N-1]]``, and the
        constraints are expressed as ``lb <= L z <= ub``.
        """
        from scipy import sparse

        n = self.n
        m = self.m

        # Identity matrices of state and control dimension
        I_A = jnp.eye(n)
        I_B = jnp.eye(m)

        def e(k):
            """Unit vector in the kth direction"""
            return jnp.zeros(N).at[k].set(1.0)

        # Block-diagonal cost: Q on every state block, R on every input block.
        blocks = [Q, R] * N
        P = linalg.block_diag(*blocks)

        # The initial condition constraint is x[0] = x0
        L0 = jnp.eye(n, N * (n + m))

        # The defect constraint for step k is
        #    0 = (A * x[k] + B * u[k]) - x[k+1]
        L_defect = jnp.vstack(
            [
                jnp.kron(e(k), jnp.hstack([A, B]))
                + jnp.kron(e(k + 1), jnp.hstack([-I_A, 0 * B]))
                for k in range(N - 1)
            ]
        )

        # Constraint on terminal state
        Lf = jnp.kron(e(N - 1), jnp.hstack([I_A, 0 * B]))

        # Constraints on the control input (rows ordered step-major, then input)
        L_input = jnp.vstack(
            [jnp.kron(e(k), jnp.hstack([0 * B.T, I_B])) for k in range(N)]
        )

        # Stack the constraint matrices and define bounds
        #  lb <= Lx <= ub
        L = jnp.vstack([L0, L_defect, Lf, L_input])

        def _expand_input_bound(bound):
            """Broadcast a control bound to one entry per (step, input) pair."""
            bound = jnp.asarray(bound)
            if bound.size == N * m:
                # Already one value per (step, input): keep step-major order.
                return bound.reshape(N * m)
            # Scalar or per-input bound: repeat it for every horizon step.
            return jnp.broadcast_to(bound, (N, m)).reshape(N * m)

        lbu_full = _expand_input_bound(lbu)
        ubu_full = _expand_input_bound(ubu)

        def get_bounds(x0):
            """Equality bounds on x[0], the defects and x[N-1]; box bounds on u."""
            defect_zeros = jnp.zeros(L_defect.shape[0])
            lb = jnp.hstack([x0, defect_zeros, xf, lbu_full])
            ub = jnp.hstack([x0, defect_zeros, xf, ubu_full])
            return lb, ub

        self.get_bounds = jax.jit(get_bounds)
        self.solver = osqp.OSQP()

        lb, ub = get_bounds(jnp.zeros(n))  # Initialize solver with dummy variables
        self.solver.setup(
            P=sparse.csc_matrix(P),
            A=sparse.csc_matrix(L),
            l=np.array(lb),
            u=np.array(ub),
            verbose=False,
        )

LinearQuadraticRegulator

Bases: FeedthroughBlock

Linear Quadratic Regulator (LQR) for a continuous-time system: dx/dt = A x + B u. Computes the optimal control input u = -K x, where u minimises the cost function over [0, ∞): J = ∫(x.T Q x + u.T R u) dt.

Input ports

(0) x: state vector of the system.

Output ports

(0) u: optimal control vector.

Parameters:

Name Type Description Default
A

Array State matrix of the system.

required
B

Array Input matrix of the system.

required
Q

Array State cost matrix.

required
R

Array Input cost matrix.

required
Source code in collimator/library/lqr.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class LinearQuadraticRegulator(FeedthroughBlock):
    """
    Linear Quadratic Regulator (LQR) for a continuous-time system:
            dx/dt = A x + B u.
    Computes the optimal control input:
            u = -K x,
    where u minimises the cost function over [0, ∞):
            J = ∫(x.T Q x + u.T R u) dt.

    Input ports:
        (0) x: state vector of the system.

    Output ports:
        (0) u: optimal control vector.

    Parameters:
        A: Array
            State matrix of the system.
        B: Array
            Input matrix of the system.
        Q: Array
            State cost matrix.
        R: Array
            Input cost matrix.

    Notes:
        The gain ``K`` is computed once, in the constructor, with ``control.lqr``
        and stored as ``self.K``.  The output function reads ``self.K`` on every
        call.
    """

    def __init__(self, A, B, Q, R, *args, **kwargs):
        # control.lqr returns a 3-tuple; only the gain (first element) is used.
        self.K, _S, _E = control.lqr(A, B, Q, R)
        super().__init__(lambda x: jnp.matmul(-self.K, x), *args, **kwargs)

Logarithm

Bases: FeedthroughBlock

Compute the logarithm of the input signal.

This block dispatches to jax.numpy.log, jax.numpy.log2, or jax.numpy.log10, so the semantics, broadcasting rules, etc. are the same. See the JAX docs for details:

- https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log.html
- https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log2.html
- https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log10.html

Input ports

(0) The input signal.

Output ports

(0) The logarithm of the input signal.

Parameters:

Name Type Description Default
base

One of "natural", "2", or "10". Determines the base of the logarithm. The default is "natural".

'natural'
Source code in collimator/library/primitives.py
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
class Logarithm(FeedthroughBlock):
    """Compute the logarithm of the input signal.

    This block dispatches to `jax.numpy.log`, `jax.numpy.log2`, or `jax.numpy.log10`,
    so the semantics, broadcasting rules, etc. are the same.  See the JAX docs for
    details:
        https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log.html
        https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log2.html
        https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log10.html

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The logarithm of the input signal.

    Parameters:
        base:
            One of "natural", "2", or "10". Determines the base of the logarithm.
            The default is "natural".
    """

    @parameters(static=["base"])
    def __init__(self, base="natural", **kwargs):
        # No operation yet: `initialize` installs it with `replace_op`.
        super().__init__(None, **kwargs)

    def initialize(self, base="natural"):
        """Validate `base` and install the matching logarithm as the block op.

        Raises:
            BlockParameterError: if `base` is not "natural", "2" or "10".
        """
        log_by_base = {
            "10": cnp.log10,
            "2": cnp.log2,
            "natural": cnp.log,
        }
        if base not in log_by_base:
            # NOTE(review): unlike sibling blocks, this error is reported with
            # `parameter_name` rather than `system=self`.
            choices = ", ".join(log_by_base)
            raise BlockParameterError(
                message=f"Logarithm block {self.name} has invalid selection {base} for 'base'. Valid selections: "
                + choices,
                parameter_name="base",
            )
        chosen_log = log_by_base[base]
        self.replace_op(chosen_log)

LogicalOperator

Bases: LeafSystem

Apply a boolean function elementwise to the input signals.

This block implements the following boolean functions
  • "or": same as np.logical_or
  • "and": same as np.logical_and
  • "not": same as np.logical_not
  • "nor": equivalent to np.logical_not(np.logical_or(in_0,in_1))
  • "nand": equivalent to np.logical_not(np.logical_and(in_0,in_1))
  • "xor": same as np.logical_xor
Input ports

(0,1) The input signals. If numeric, they are interpreted as boolean types (so 0 is False and any other value is True).

Output ports

(0) The result of the logical operation, a boolean-valued signal.

Parameters:

Name Type Description Default
function

The boolean function to apply. One of "or", "and", "not", "nor", "nand", or "xor".

required
Events

An event is triggered when the output changes from True to False or vice versa.

Source code in collimator/library/primitives.py
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
class LogicalOperator(LeafSystem):
    """Apply a boolean function elementwise to the input signals.

    This block implements the following boolean functions:
        - "or": same as np.logical_or
        - "and": same as np.logical_and
        - "not": same as np.logical_not
        - "nor": equivalent to np.logical_not(np.logical_or(in_0,in_1))
        - "nand": equivalent to np.logical_not(np.logical_and(in_0,in_1))
        - "xor": same as np.logical_xor

    Input ports:
        (0,1) The input signals.  If numeric, they are interpreted as boolean
            types (so 0 is False and any other value is True).

    Output ports:
        (0) The result of the logical operation, a boolean-valued signal.

    Parameters:
        function:
            The boolean function to apply. One of "or", "and", "not", "nor", "nand",
            or "xor".

    Events:
        An event is triggered when the output changes from True to False or vice versa.
    """

    @parameters(static=["function"])
    def __init__(self, function, **kwargs):
        super().__init__(**kwargs)
        # "not" is unary; every other operator gets a second input port.
        port_count = 1 if function == "not" else 2
        for _ in range(port_count):
            self.declare_input_port()
        self._output_port_idx = self.declare_output_port(
            None,
            prerequisites_of_calc=[port.ticket for port in self.input_ports],
            requires_inputs=True,
        )

    def initialize(self, function):
        """Validate `function` and bind the matching elementwise operation.

        Raises:
            BlockParameterError: if `function` is unknown, or if it would change
                the number of input ports declared in `__init__`.
        """
        self.function = function
        dispatch = {
            "or": self._or,
            "and": self._and,
            "not": self._not,
            "xor": self._xor,
            "nor": self._nor,
            "nand": self._nand,
        }
        if function not in dispatch:
            options = ", ".join(dispatch)
            raise BlockParameterError(
                message=f"LogicalOperator block {self.name} has invalid selection {function} for 'function'. Valid options: "
                + options,
                system=self,
            )

        # The port count was fixed in `__init__`, so switching between the
        # unary "not" and a binary operator is rejected here.
        port_count = len(self.input_ports)
        if function != "not" and port_count < 2:
            raise BlockParameterError(
                message=f"Can't change logical operator from 'not' to {function} for block {self.name}",
                system=self,
            )
        if function == "not" and port_count > 1:
            raise BlockParameterError(
                message=f"Can't change logical operator from {function} to 'not' for block {self.name}",
                system=self,
            )

        self._func = dispatch[function]

        self.configure_output_port(
            self._output_port_idx,
            self._func,
            prerequisites_of_calc=[port.ticket for port in self.input_ports],
            requires_inputs=True,
        )

    def _edge_detection(self, time, state, *inputs, **params):
        """Zero-crossing signal: +1.0 where the output is True, -1.0 elsewhere."""
        return cnp.where(self._func(time, state, *inputs, **params), 1.0, -1.0)

    @staticmethod
    def _operands(inputs):
        """Return the first two inputs converted to arrays."""
        return cnp.array(inputs[0]), cnp.array(inputs[1])

    def _or(self, time, state, *inputs, **parameters):
        lhs, rhs = self._operands(inputs)
        return cnp.logical_or(lhs, rhs)

    def _and(self, time, state, *inputs, **parameters):
        lhs, rhs = self._operands(inputs)
        return cnp.logical_and(lhs, rhs)

    def _not(self, time, state, *inputs, **parameters):
        (operand,) = inputs
        return cnp.logical_not(cnp.array(operand))

    def _xor(self, time, state, *inputs, **parameters):
        lhs, rhs = self._operands(inputs)
        return cnp.logical_xor(lhs, rhs)

    def _nor(self, time, state, *inputs, **parameters):
        lhs, rhs = self._operands(inputs)
        either = cnp.logical_or(lhs, rhs)
        return cnp.logical_not(either)

    def _nand(self, time, state, *inputs, **parameters):
        lhs, rhs = self._operands(inputs)
        both = cnp.logical_and(lhs, rhs)
        return cnp.logical_not(both)

    def initialize_static_data(self, context):
        # Declare a zero-crossing event (at most once) when `is_discontinuity`
        # reports the output as a discontinuity, so ODE solvers do not try to
        # integrate across the switch.
        needs_event = not self.has_zero_crossing_events and is_discontinuity(
            self.output_ports[0]
        )
        if needs_event:
            self.declare_zero_crossing(self._edge_detection, direction="crosses_zero")

        return super().initialize_static_data(context)

LogicalReduce

Bases: FeedthroughBlock

Apply a boolean reduce function to the elements of the input signal.

This block implements the following boolean functions
  • "any": Output is True if any input element is True.
  • "all": Output is True if all input elements are True.
Input ports

(0) The input signal. If numeric, they are interpreted as boolean types (so 0 is False and any other value is True).

Output ports

(0) The result of the logical operation, a boolean-valued signal.

Parameters:

Name Type Description Default
function

The boolean function to apply. One of "any", "all".

required
axis

Axis or axes along which a logical OR/AND reduction is performed.

None
Events

An event is triggered when the output changes from True to False or vice versa.

Source code in collimator/library/primitives.py
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
class LogicalReduce(FeedthroughBlock):
    """Apply a boolean reduce function to the elements of the input signal.

    This block implements the following boolean functions:
        - "any": Output is True if any input element is True.
        - "all": Output is True if all input elements are True.

    Input ports:
        (0) The input signal.  If numeric, they are interpreted as boolean
            types (so 0 is False and any other value is True).

    Output ports:
        (0) The result of the logical operation, a boolean-valued signal.

    Parameters:
        function:
            The boolean function to apply. One of "any", "all".
        axis:
            Axis or axes along which a logical OR/AND reduction is performed:
            an integer, a tuple/list of integers, or None (reduce over all
            elements).

    Events:
        An event is triggered when the output changes from True to False or vice versa.
    """

    @parameters(static=["function", "axis"])
    def __init__(self, function, axis=None, **kwargs):
        super().__init__(None, **kwargs)

    def initialize(self, function, axis=None):
        """Validate the parameters and install the reduction as the block op.

        Raises:
            BlockParameterError: if `function` is not "any" or "all".
        """
        self.function = function
        if axis is None:
            self.axis = None
        elif isinstance(axis, (tuple, list)):
            # Several axes: normalize to a tuple of ints for any/all.
            self.axis = tuple(int(a) for a in axis)
        else:
            self.axis = int(axis)
        func_lookup = {
            "any": self._any,
            "all": self._all,
        }
        if function not in func_lookup:
            raise BlockParameterError(
                message=f"LogicalReduce block {self.name} has invalid selection {function} for 'function'. Valid options: "
                + ", ".join(func_lookup),
                system=self,
            )

        self._func = func_lookup[function]
        self.replace_op(self._func)

    def _edge_detection(self, _time, _state, *inputs, **_params):
        """Zero-crossing signal: +1.0 where the output is True, -1.0 elsewhere."""
        # Reduce the single input exactly as the output op does.  Passing the
        # whole `inputs` tuple would add a leading axis and change the result
        # whenever `axis` is not None.
        (u,) = inputs
        outp = self._func(u)
        return cnp.where(outp, 1.0, -1.0)

    def _any(self, inputs):
        """Logical OR reduction of the input along `self.axis`."""
        return cnp.any(cnp.array(inputs), axis=self.axis)

    def _all(self, inputs):
        """Logical AND reduction of the input along `self.axis`."""
        return cnp.all(cnp.array(inputs), axis=self.axis)

    def initialize_static_data(self, context):
        # Add a zero-crossing event so ODE solvers can't try to integrate
        # through a discontinuity.  For efficiency, only do this if the output
        # is fed to an ODE block
        if not self.has_zero_crossing_events and is_discontinuity(self.output_ports[0]):
            self.declare_zero_crossing(self._edge_detection, direction="crosses_zero")

        return super().initialize_static_data(context)

LookupTable1d

Bases: FeedthroughBlock

Interpolate the input signal into a static lookup table.

If a function y = f(x) is sampled at a set of points (x_i, y_i), then this block will interpolate the input signal x to compute the output signal y. The behavior is modeled after scipy.interpolate.interp1d but is implemented in JAX. Available interpolation modes are:

- "linear": Linear interpolation using jax.interp.
- "nearest": Nearest-neighbor interpolation.
- "flat": Flat interpolation.

Input ports

(0) The input signal, which is used as the interpolation coordinate.

Output ports

(0) The interpolated output signal.

Parameters:

Name Type Description Default
input_array

The array of input values at which the output values are provided.

required
output_array

The array of output values.

required
interpolation

One of "linear", "nearest", or "flat". Determines the type of interpolation performed by the block.

required
Notes

Currently restricted to 1D input and output data. This may be expanded to support multi-dimensional output arrays in the future.

Source code in collimator/library/primitives.py
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
class LookupTable1d(FeedthroughBlock):
    """Interpolate the input signal into a static lookup table.

    If a function `y = f(x)` is sampled at a set of points `(x_i, y_i)`, then this
    block will interpolate the input signal `x` to compute the output signal `y`.
    The behavior is modeled after `scipy.interpolate.interp1d` but is implemented
    in JAX.  Available interpolation modes are:
        - "linear": Linear interpolation using `interp` (as in `jax.numpy.interp`).
        - "nearest": Nearest-neighbor interpolation.
        - "flat": Flat (zero-order hold) interpolation. Returns `y_i` for
          `x_i <= x < x_{i+1}`, `y_0` for inputs below the second breakpoint and
          the last table value for inputs at or above the last breakpoint.

    Input ports:
        (0) The input signal, which is used as the interpolation coordinate.

    Output ports:
        (0) The interpolated output signal.

    Parameters:
        input_array:
            The array of input values at which the output values are provided.
            Expected to be monotonically increasing (as required by `interp` and
            assumed by the "flat" mode).
        output_array:
            The array of output values. Must have the same length as
            `input_array`.
        interpolation:
            One of "linear", "nearest", or "flat". Determines the type of interpolation
            performed by the block.

    Raises:
        ValueError: If either array is not 1D, the arrays are empty or differ in
            length, or `interpolation` is not a supported mode.

    Notes:
        Currently restricted to 1D input and output data.  This may be expanded to
        support multi-dimensional output arrays in the future.
    """

    @parameters(static=["input_array", "output_array", "interpolation"])
    def __init__(self, input_array, output_array, interpolation, **kwargs):
        super().__init__(None, **kwargs)

    def initialize(self, input_array, output_array, interpolation):
        self.input_array = cnp.array(input_array)
        self.output_array = cnp.array(output_array)
        if len(self.input_array.shape) != 1:
            raise ValueError(
                f"LookupTable1d block {self.name} input_array must be 1D, got shape "
                f"{self.input_array.shape}"
            )
        if len(self.output_array.shape) != 1:
            raise ValueError(
                f"LookupTable1d block {self.name} output_array must be 1D, got shape "
                f"{self.output_array.shape}"
            )
        # The "nearest" and "flat" modes compute an index into `input_array` and
        # use it to index `output_array`. A length mismatch would therefore read
        # the wrong (or, under JAX, silently clamped) entry. Reject it up front.
        if len(self.input_array) != len(self.output_array):
            raise ValueError(
                f"LookupTable1d block {self.name} input_array and output_array must "
                f"have the same length, got {len(self.input_array)} and "
                f"{len(self.output_array)}"
            )
        if len(self.input_array) == 0:
            raise ValueError(
                f"LookupTable1d block {self.name} input_array and output_array must "
                "not be empty"
            )
        # Largest valid table index, used to clip lookups in "nearest" mode.
        self.max_i = len(self.input_array) - 1

        func_lookup = {
            "linear": self._lookup_linear,
            "nearest": self._lookup_nearest,
            "flat": self._lookup_flat,
        }
        if interpolation not in func_lookup:
            raise ValueError(
                f"LookupTable1d block {self.name} has invalid selection {interpolation} "
                "for 'interpolation'"
            )
        self.replace_op(func_lookup[interpolation])

    def _lookup_linear(self, x):
        return cnp.interp(x, self.input_array, self.output_array)

    def _lookup_nearest(self, x):
        # Index of the breakpoint closest to x.
        i = cnp.argmin(cnp.abs(self.input_array - x))
        i = cnp.clip(i, 0, self.max_i)
        return self.output_array[i]

    def _lookup_flat(self, x):
        # `argmin(x >= input_array)` is the first breakpoint strictly greater than
        # x, so subtracting 1 gives the last breakpoint <= x. When x is at or
        # above every breakpoint, argmin of an all-True mask is 0 and the index
        # becomes -1, i.e. the last table entry.
        i = cnp.where(
            x < self.input_array[1],
            0,
            cnp.argmin(x >= self.input_array) - 1,
        )
        return self.output_array[i]

LookupTable2d

Bases: LeafSystem

Interpolate the input signals into a static lookup table.

The behavior is modeled on scipy.interpolate.interp2d but is implemented in JAX. The only currently implemented interpolation mode is "linear". The input arrays must be 1D and the output array must be 2D.

Input ports

(0) The first input signal, used as the first interpolation coordinate. (1) The second input signal, used as the second interpolation coordinate.

Output ports

(0) The interpolated output signal.

Parameters:

Name Type Description Default
input_x_array

The array of input values at which the output values are provided, corresponding to the first input signal. Must be 1D

required
input_y_array

The array of input values at which the output values are provided, corresponding to the second input signal. Must be 1D

required
output_table_array

The array of output values. Must be 2D with shape (m, n), where m = len(input_x_array) and n = len(input_y_array).

required
interpolation

Only "linear" is supported.

'linear'
Source code in collimator/library/primitives.py
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
class LookupTable2d(LeafSystem):
    """Interpolate the input signals into a static lookup table.

    The behavior is modeled on `scipy.interpolate.interp2d` but is implemented
    in JAX.  The only currently implemented interpolation mode is "linear". The
    input arrays must be 1D and the output array must be 2D.

    Input ports:
        (0) The first input signal, used as the first interpolation coordinate.
        (1) The second input signal, used as the second interpolation coordinate.

    Output ports:
        (0) The interpolated output signal.

    Parameters:
        input_x_array:
            The array of input values at which the output values are provided,
            corresponding to the first input signal. Must be 1D
        input_y_array:
            The array of input values at which the output values are provided,
            corresponding to the second input signal. Must be 1D
        output_table_array:
            The array of output values. Must be 2D with shape `(m, n)`, where
            `m = len(input_x_array)` and `n = len(input_y_array)`.
        interpolation:
            Only "linear" is supported.

    Raises:
        ValueError: If an array has the wrong rank, or the table shape does not
            match the breakpoint array lengths.
        NotImplementedError: If `interpolation` is anything other than "linear".
    """

    @parameters(
        static=["input_x_array", "input_y_array", "output_table_array", "interpolation"]
    )
    def __init__(
        self,
        input_x_array,
        input_y_array,
        output_table_array,
        interpolation="linear",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Input port 0 carries the x coordinate and input port 1 the y coordinate.
        for _ in range(2):
            self.declare_input_port()

        input_tickets = [port.ticket for port in self.input_ports]
        # The output callback is attached later in `initialize`, once the tables
        # have been validated.
        self._output_port_idx = self.declare_output_port(
            None,
            prerequisites_of_calc=input_tickets,
            requires_inputs=True,
        )

    def initialize(
        self, input_x_array, input_y_array, output_table_array, interpolation
    ):
        x_points = cnp.array(input_x_array)
        y_points = cnp.array(input_y_array)
        table = cnp.array(output_table_array)

        # Rank checks: x breakpoints, then y breakpoints, then the table.
        rank_requirements = (
            ("input_x_array", x_points, 1),
            ("input_y_array", y_points, 1),
            ("output_table_array", table, 2),
        )
        for label, arr, rank in rank_requirements:
            if len(arr.shape) != rank:
                raise ValueError(
                    f"LookupTable2d block {self.name} {label} must be {rank}D, "
                    f"got shape {arr.shape}"
                )

        expected_shape = (len(x_points), len(y_points))
        if table.shape != expected_shape:
            raise ValueError(
                f"LookupTable2d block {self.name} output_table_array must have "
                f"shape (len(input_x_array), len(input_y_array)), got shape {table.shape}"
            )

        if interpolation != "linear":
            raise NotImplementedError(
                f"LookupTable2d block {self.name} only supports linear interpolation."
            )

        # Bind the static tables so the output callback only supplies (x, y).
        self._compute_output = partial(cnp.interp2d, x_points, y_points, table)

        self.configure_output_port(
            self._output_port_idx,
            self._output,
            prerequisites_of_calc=[port.ticket for port in self.input_ports],
            requires_inputs=True,
        )

    def _output(self, _time, _state, *inputs, **params):
        x_value, y_value = inputs
        return self._compute_output(x_value, y_value)

MJX

Bases: MuJoCoBase

A system that wraps a MuJoCo model and provides a continuous-time ODE LeafSystem. Currently only supports a single body system.

Input ports

(0) The control input vector control.

Output ports

(0) The generalized position coordinates qpos. (1) The generalized velocity coordinates qvel. (2) The actuator coordinates act. (3) The sensor data sensor_data (if enabled). (4) The video output video as RGB frames of shape (H,W,3) (if enabled). (5) A fake output port, present only if vHIL=True and outputs Array(0.0). (6+) Custom output ports, defined with user-specified python scripts.

Parameters:

Name Type Description Default
file_name str

The path to the MuJoCo XML model file.

required
dt float

If None, collimator's internal solver will be used and this block can be considered as a continuous block. If set, the model will be run in a discrete mode with the specified timestep, using MJX's solver, more like Co-Simulation. In that case, it might be favorable to set use_mjx=False.

None
key_frame_0 int | str

The keyframe to initialize the model from.

None
qpos_0 Array

The initial generalized position coordinates.

None
qvel_0 Array

The initial generalized velocity coordinates.

None
act_0 Array

The initial actuator coordinates.

None
enable_sensor_data bool

Whether to output the sensor data to an optional port named 'sensor_data'.

False
enable_video_output bool

Whether to output the rendered video frames to an optional port named 'video'.

False
video_size tuple[int, int]

The size of the video output frames as a (H,W) tuple.

None
enable_mocap_pos bool

Whether to enable the mocap_pos input port for motion capture tracking.

False
vHIL bool

Whether to run in virtual hardware-in-the-loop mode.

False
vHIL_dt float

The timestep for the virtual hardware-in-the-loop mode.

0.01

Notes: (i) _model and _data refer to MuJoCo's mjModel and mjData objects respectively; model and data are the corresponding MJX objects. (ii) Because MJX has not yet implemented sensor data, sensordata output is computed via a pure callback to MuJoCo. This can be expensive.

Source code in collimator/library/mujoco.py
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
class MJX(MuJoCoBase):
    """
    A system that wraps a MuJoCo model and provides a continuous-time ODE LeafSystem.
    Currently only supports a single body system.

    Input ports:
        (0) The control input vector `control`.

    Output ports:
        (0) The generalized position coordinates `qpos`.
        (1) The generalized velocity coordinates `qvel`.
        (2) The actuator coordinates `act`.
        (3) The sensor data `sensor_data` (if enabled).
        (4) The video output `video` as RGB frames of shape (H,W,3) (if enabled).
        (5) A fake output port, present only if `vHIL=True` and outputs Array(0.0).
        (6+) Custom output ports, defined with user-specified python scripts.

    Parameters:
        file_name (str):
            The path to the MuJoCo XML model file.

        dt (float, optional):
            If None (or 0), collimator's internal solver will be used and this block
            can be considered as a continuous block. If set, the model will be run in
            a discrete mode with the specified timestep, using MJX's solver, more like
            Co-Simulation. In that case, it might be favorable to set `use_mjx=False`.

        key_frame_0 (int|str, optional):
            The keyframe to initialize the model from.

        qpos_0 (Array, optional):
            The initial generalized position coordinates.

        qvel_0 (Array, optional):
            The initial generalized velocity coordinates.

        act_0 (Array, optional):
            The initial actuator coordinates.

        enable_sensor_data (bool, optional):
            Whether to output the sensor data to an optional port named 'sensor_data'.

        enable_video_output (bool, optional):
            Whether to output the rendered video frames to an optional port named 'video'.

        video_size (tuple[int, int], optional):
            The size of the video output frames as a (H,W) tuple.

        enable_mocap_pos (bool, optional):
            Whether to enable the mocap_pos input port for motion capture tracking.

        custom_output_scripts (dict[str, str], optional):
            User-specified python scripts defining the custom output ports (6+).
            Forwarded to the base class and to `_declare_custom_output_ports`.

        vHIL (bool, optional):
            Whether to run in virtual hardware-in-the-loop mode.

        vHIL_dt (float, optional):
            The timestep for the virtual hardware-in-the-loop mode.

    Notes:
    (i) `_model` and `_data` refer to MuJoCo's `mjModel` and `mjData` objects
    respectively. `model` and `data` are the corresponding MJX objects.
    (ii) `sensordata` output is computed via a pure callback to MuJoCo, since MJX
    has not yet implemented this aspect. This can be expensive.
    """

    @parameters(
        static=[
            "file_name",
            "dt",
            "key_frame_0",
            "qpos_0",
            "qvel_0",
            "act_0",
            "enable_sensor_data",
            "enable_video_output",
            "video_size",
            "enable_mocap_pos",
            "vHIL",
            "vHIL_dt",
        ]
    )
    def __init__(
        self,
        file_name: str,
        dt: float = None,
        key_frame_0: int | str = None,
        qpos_0: Array = None,
        qvel_0: Array = None,
        act_0: Array = None,
        enable_sensor_data=False,
        enable_video_output=False,
        video_size: tuple[int, int] = None,
        enable_mocap_pos=False,
        custom_output_scripts: dict[str, str] = None,
        vHIL=False,
        vHIL_dt=0.01,
        **kwargs,
    ):
        super().__init__(
            use_mjx=True,
            file_name=file_name,
            dt=dt,
            key_frame_0=key_frame_0,
            qpos_0=qpos_0,
            qvel_0=qvel_0,
            act_0=act_0,
            enable_sensor_data=enable_sensor_data,
            enable_video_output=enable_video_output,
            video_size=video_size,
            enable_mocap_pos=enable_mocap_pos,
            custom_output_scripts=custom_output_scripts,
            vHIL=vHIL,
            vHIL_dt=vHIL_dt,
            **kwargs,
        )

        # Convert the host-side MuJoCo model/data into MJX objects. MJX does not
        # support every MJCF feature, so surface a targeted hint before re-raising.
        try:
            self.model = mjx.put_model(self._model)
            self.data = mjx.put_data(self._model, self._data)
        except NotImplementedError as e:
            logger.error(
                "This robot model uses features not implemented in MJX. "
                "Please try the MuJoCo block instead (toggle use_mjx to false), "
                "or modify the MJCF file.",
                **logdata(block=self, exception=f"{type(e).__name__}: {str(e)}"),
            )
            raise e

        if dt is None or dt == 0:
            # Continuous mode: the state vector is [qpos, qvel, act] and
            # collimator integrates `_ode`.
            self.dt = None
            logger.info(
                "MuJoCo MJX block is running in continuous mode and will use "
                "Collimator's solver.",
                **logdata(block=self),
            )

            state_0 = jnp.concatenate([self.qpos_0, self.qvel_0, self.act_0])
            self.declare_continuous_state(ode=self._ode, default_value=state_0)

        else:
            # Discrete mode: the full MJX `Data` is held in a periodic cache and
            # advanced with `mjx.step` every `dt`.
            self.dt = dt
            known_solver_names = {
                mujoco.mjtSolver.mjSOL_CG: "Conjugate Gradient",
                mujoco.mjtSolver.mjSOL_NEWTON: "Newton",
            }
            logger.info(
                "MuJoCo MJX block is running in discrete mode with dt=%s, "
                "this will use MJX's solver '%s' and not Collimator's solver.",
                dt,
                known_solver_names.get(
                    self.model.opt.solver,
                    str(self.model.opt.solver),
                ),
                **logdata(block=self),
            )

            callback_index = self.declare_cache(
                self._step_cache_cb,
                default_value=self.data,
                period=dt,
                offset=0.0,
                requires_inputs=True,
            )
            self.mjx_data_cache_index = self.callbacks[callback_index].cache_index

        self.declare_output_port(
            self._output_qpos,
            default_value=self.qpos_0,
            requires_inputs=False,
            name="qpos",
        )

        self.declare_output_port(
            self._output_qvel,
            default_value=self.qvel_0,
            requires_inputs=False,
            name="qvel",
        )

        self.declare_output_port(
            self._output_act,
            default_value=self.act_0,
            requires_inputs=False,
            name="act",
        )

        if enable_sensor_data:
            logger.warning(
                "Sensor data output with MJX might be very slow. Consider switching "
                "to the non-MJX MuJoCo block for better performance.",
                **logdata(block=self),
            )
            self._declare_sensor_data_port(self.dt)

        if enable_video_output:
            logger.warning(
                "Video output with MJX might be very slow. Consider switching "
                "to the non-MJX MuJoCo block for better performance.",
                **logdata(block=self),
            )
            self._declare_video_output_port(video_size)

        if vHIL:
            self._declare_vhil_fake_output_port(vHIL_dt)

        self._declare_custom_output_ports(custom_output_scripts, self.dt)

    def _cached_data(self, state: LeafState) -> mjxData:
        # Only valid in discrete mode, where `mjx_data_cache_index` is set.
        return state.cache[self.mjx_data_cache_index]

    def _qpos(self, state: LeafState):
        if self.dt is not None:
            return self._cached_data(state).qpos

        return state.continuous_state[self.qpos_start : self.qpos_end]

    def _qvel(self, state: LeafState):
        if self.dt is not None:
            return self._cached_data(state).qvel

        return state.continuous_state[self.qvel_start : self.qvel_end]

    def _act(self, state: LeafState):
        if self.dt is not None:
            return self._cached_data(state).act

        return state.continuous_state[self.act_start : self.act_end]

    def _ode(self, time, state, *inputs, **parameters):
        # Implementation of the ODE when running model in continuous mode, with
        # collimator's internal solver.

        qpos = self._qpos(state)
        qvel = self._qvel(state)
        act = self._act(state)

        ctrl = inputs[0]

        # FIXME: Should we normalize the quaternions here to avoid numerical drift?
        # qpos = self.normalize_qpos_quat(qpos)

        model, data = self.model, self.data
        data = data.replace(time=time, qpos=qpos, qvel=qvel, act=act, ctrl=ctrl)

        data = mjx.forward(model, data)

        qvel_dot = data.qacc
        qpos_dot = position_derivatives(model.jnt_type, qpos, qvel)
        act_dot = data.act_dot

        state_dot = jnp.concatenate([qpos_dot, qvel_dot, act_dot])

        return state_dot

    def _step_cache_cb(self, time, state: LeafState, *inputs, **parameters):
        # Implementation of the ODE when running model in discrete mode, with
        # MJX's solver. This is like the non-MJX variant or an FMU.

        # TODO: try a version wrapped with io_callback to compare
        # compilation times. Splitting the compute graph between collimator
        # and mjx could bring improvements, but quite obviously at the cost
        # of any usefulness of mjx over mujoco (autodiff, vmap, ...).

        ctrl = inputs[0]

        data = self._cached_data(state)
        data = data.replace(time=time, ctrl=ctrl)
        data = mjx.step(self.model, data)

        return data

    def _output_qpos(self, time, state, *inputs, **parameters):
        qpos = self._qpos(state)
        qpos_normalized_quats = self.normalize_qpos_quat(qpos)
        return qpos_normalized_quats

    def _output_qvel(self, time, state, *inputs, **parameters):
        return self._qvel(state)

    def _output_act(self, time, state, *inputs, **parameters):
        return self._act(state)

    def _mj_forward(self, time, qpos, qvel, act, ctrl=None):
        # Writes into the shared host-side `self._data` in place and runs
        # `mujoco.mj_forward`; intended to be called from host callbacks.
        data = self._data
        data.time = time
        data.qpos[:] = qpos
        data.qvel[:] = qvel
        data.act[:] = act
        if ctrl is not None:
            data.ctrl[:] = ctrl
        mujoco.mj_forward(self._model, data)
        return data

    def _pure_callback_sensordata(self, time, qpos, qvel, act, ctrl):
        data = self._mj_forward(time, qpos, qvel, act, ctrl)
        return data.sensordata

    def _output_sensor_data(self, time, state, *inputs, **parameters):
        qpos = self._qpos(state)
        qvel = self._qvel(state)
        act = self._act(state)
        ctrl = inputs[0] if inputs else None

        qpos = self.normalize_qpos_quat(qpos)

        return jax.pure_callback(
            self._pure_callback_sensordata,
            self.pure_callback_sensordata_result_type,
            time,
            qpos,
            qvel,
            act,
            ctrl,
        )

    def normalize_qpos_quat(self, qpos):
        """
        Normalize the quaternion components of the generalized position coordinates.
        """
        qpos_normalized, qi = [], 0

        for jnt_typ in self.model.jnt_type:
            if jnt_typ == mjx_types.JointType.FREE:
                trans = qpos[qi : qi + 3]
                quat = qpos[qi + 3 : qi + 7]
                norm_quat = mjx_math.normalize(quat)
                qpos_normalized.append(jnp.concatenate([trans, norm_quat]))
                qi = qi + 7
            elif jnt_typ == mjx_types.JointType.BALL:
                quat = qpos[qi : qi + 4]
                norm_quat = mjx_math.normalize(quat)
                qpos_normalized.append(norm_quat)
                qi = qi + 4
            elif jnt_typ in (mjx_types.JointType.HINGE, mjx_types.JointType.SLIDE):
                trans = qpos[qi]
                qpos_normalized.append(trans[None])
                qi = qi + 1
            else:
                raise RuntimeError(f"unrecognized joint type: {jnt_typ}")

        return jnp.concatenate(qpos_normalized) if qpos_normalized else jnp.empty((0,))

    def _update_viewer(self, time, state, *inputs, **parameters):
        ctrl = inputs[0]
        return jax.pure_callback(
            self._pure_callback_update_viewer,
            self.pure_callback_update_result_type,
            ctrl,
        )

    def _pure_callback_update_viewer(self, ctrl):
        # Host-side vHIL step: sync the viewer, apply the control and advance the
        # separate vHIL MuJoCo model/data by one `mj_step`.
        # NOTE(review): if `self.viewer` is a `mujoco.viewer` passive handle,
        # `is_running` is a method and this condition is always truthy; it should
        # probably be `self.viewer.is_running()`. Confirm the viewer type.
        if self.viewer.is_running:
            self.viewer.sync()
            self.data_vhil.ctrl[:] = ctrl
            mujoco.mj_step(self.model_vhil, self.data_vhil)
        return jnp.array(0.0)

    def _output_video_discrete(self, time, state, *inputs, **parameters):
        # `io_callback` invokes the callback as `callback(*args)`, so its
        # signature must accept every positional argument passed below. `inputs`
        # is passed explicitly rather than captured by closure, as in
        # `_output_video`, so the callback receives concrete host values.
        def _discrete_cb(_time, state, inputs):
            mjx_data = self._cached_data(state)
            data = mjx.get_data(self._model, mjx_data)
            if self.enable_mocap_pos:
                data.mocap_pos[:] = inputs[1]
            return self.render(data)

        return io_callback(_discrete_cb, self._video_default, time, state, inputs)

    def _output_video(self, time, state, *inputs, **parameters):
        if self.dt is not None:
            return self._output_video_discrete(time, state, *inputs, **parameters)

        def _continuous_cb(time, qpos, qvel, act, inputs):
            data = self._mj_forward(time, qpos, qvel, act)
            # NOTE(review): mocap_pos is written after `_mj_forward`; whether the
            # rendered frame reflects it depends on whether `render` recomputes
            # kinematics. Confirm.
            if self.enable_mocap_pos:
                data.mocap_pos[:] = inputs[1]
            return self.render(data)

        qpos = self._qpos(state)
        qvel = self._qvel(state)
        act = self._act(state)

        qpos = self.normalize_qpos_quat(qpos)

        return io_callback(
            _continuous_cb, self._video_default, time, qpos, qvel, act, inputs
        )

normalize_qpos_quat(qpos)

Normalize the quaternion components of the generalized position coordinates.

Source code in collimator/library/mujoco.py
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
def normalize_qpos_quat(self, qpos):
    """
    Normalize the quaternion components of the generalized position coordinates.

    Walks the joints of `self.model` in order, consuming the matching slice of
    `qpos` for each one. FREE joints take 3 translation entries plus a 4-entry
    quaternion, BALL joints a 4-entry quaternion, and HINGE/SLIDE joints a single
    entry. Only quaternion parts are normalized (via `mjx_math.normalize`).

    Raises:
        RuntimeError: If a joint type other than FREE, BALL, HINGE or SLIDE
            is encountered.
    """
    pieces = []
    offset = 0

    for joint_type in self.model.jnt_type:
        if joint_type == mjx_types.JointType.FREE:
            position = qpos[offset : offset + 3]
            rotation = mjx_math.normalize(qpos[offset + 3 : offset + 7])
            pieces.append(jnp.concatenate([position, rotation]))
            offset += 7
        elif joint_type == mjx_types.JointType.BALL:
            pieces.append(mjx_math.normalize(qpos[offset : offset + 4]))
            offset += 4
        elif joint_type in (mjx_types.JointType.HINGE, mjx_types.JointType.SLIDE):
            # Promote the scalar coordinate to a length-1 vector for concatenation.
            pieces.append(qpos[offset][None])
            offset += 1
        else:
            raise RuntimeError(f"unrecognized joint type: {joint_type}")

    if not pieces:
        return jnp.empty((0,))
    return jnp.concatenate(pieces)

MLP

Bases: FeedthroughBlock

A feedforward neural network block representing an Equinox multi-layer perceptron (MLP). The output y of the MLP is computed as

    y = MLP(x, theta)

where theta are the parameters of the MLP, and x is the input to the MLP. This block is differentiable w.r.t. the MLP parameters theta. Note that theta does not include the hyperparameters representing the architecture of the MLP.

Input ports

(0) The input to the MLP.

Output ports

(0) The output of the MLP.

Parameters:

Name Type Description Default
in_size int

The dimension of the input to the MLP.

None
out_size int

The dimension of the output of the MLP.

None
width_size int

The width of every hidden layer of the MLP.

None
depth int

The depth of the MLP. This represents the number of hidden layers, including the output layer.

None
seed int

The seed for the random number generator for initialization of the MLP parameters (weights and biases of every layer). If None, a random 32-bit seed will be generated.

None
activation_str str

The activation function to use after each internal layer of the MLP. Possible values are "relu", "sigmoid", "tanh", "elu", "swish", "rbf", and "identity". Default is "relu".

'relu'
final_activation_str str

The activation function to use for the output layer of the MLP. Possible values are "relu", "sigmoid", "tanh", "elu", "swish", "rbf", and "identity". Default is "identity".

'identity'
use_bias bool

Whether to add a bias to the internal layers of the MLP. Default is True.

True
use_final_bias bool

Whether to add a bias to the output layer of the MLP. Default is True.

True
file_name str

Optional file name containing the serialized parameters of the MLP. If provided, the parameters are loaded from the file, and set as the parameters of the MLP. Default is None.

None
Source code in collimator/library/nn.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
class MLP(FeedthroughBlock):
    """
    A feedforward neural network block representing an Equinox multi-layer
    perceptron (MLP). The output `y` of the MLP is computed as

    ```
        y = MLP(x, theta)
    ```

    where `theta` are the parameters of the MLP, and `x` is the input to the MLP.
    This block is differentiable w.r.t. the MLP parameters `theta`. Note that `theta`
    does not include the hyperparameters representing the architecture of the MLP.

    Input ports:
        (0) The input to the MLP.

    Output ports:
        (0) The output of the MLP.

    Parameters:
        in_size (int):
            The dimension of the input to the MLP.
        out_size (int):
            The dimension of the output of the MLP.
        width_size (int):
            The width of every hidden layer of the MLP.
        depth (int):
            The depth of the MLP. This represents the number of hidden layers,
            including the output layer.
        seed (int):
            The seed for the random number generator for initialization of the
            MLP parameters (weights and biases of every layer).
            If None, a random 32-bit seed will be generated.
        activation_str (str):
            The activation function to use after each internal layer of the MLP.
            Possible values are "relu", "sigmoid", "tanh", "elu", "swish", "rbf",
            and "identity". Default is "relu".
        final_activation_str (str):
            The activation function to use for the output layer of the MLP.
            Possible values are "relu", "sigmoid", "tanh", "elu", "swish", "rbf",
            and "identity". Default is "identity".
        use_bias (bool):
            Whether to add a bias to the internal layers of the MLP.
            Default is True.
        use_final_bias (bool):
            Whether to add a bias to the output layer of the MLP.
            Default is True.
        file_name (str):
            Optional file name containing the serialized parameters of the MLP.
            If provided, the parameters are loaded from the file, and set as the
            parameters of the MLP. Default is None.
    """

    @parameters(
        static=[
            "in_size",
            "out_size",
            "width_size",
            "depth",
            "seed",
            "activation_str",
            "final_activation_str",
            "use_bias",
            "use_final_bias",
            "file_name",
        ],
    )
    def __init__(
        self,
        in_size=None,
        out_size=None,
        width_size=None,
        depth=None,
        seed=None,
        activation_str="relu",
        final_activation_str="identity",
        use_bias=True,
        use_final_bias=True,
        file_name=None,
        **kwargs,
    ):
        """
        See https://docs.kidger.site/equinox/examples/serialisation/ for the
        rationale behind this implementation. Activation functions cannot be
        serialized, so each one is selected by a string naming one of a finite
        set of options.

        No operation is passed to the parent constructor (`None`); the network
        and its operation are set up in `initialize`.
        """
        super().__init__(None, **kwargs)

    def initialize(
        self,
        in_size=None,
        out_size=None,
        width_size=None,
        depth=None,
        seed=None,
        activation_str="relu",
        final_activation_str="identity",
        use_bias=True,
        use_final_bias=True,
        file_name=None,
        mlp_params=None,
    ):
        """Build the Equinox MLP and install it as this block's operation.

        Constructs an `eqx.nn.MLP` from the architecture hyperparameters,
        optionally overwrites its parameters with those deserialized from
        `file_name`, then splits the model into its array leaves (declared or
        updated as the dynamic parameter `mlp_params`) and its static structure
        (kept on `self.mlp_static`). The block op recombines the two on each call.

        Raises:
            ValueError: if any of `in_size`, `out_size`, `width_size` or
                `depth` is not provided.
        """
        # FIXME: mlp_params will always be overwritten so it can't be optimized for now.

        if in_size is None or out_size is None or width_size is None or depth is None:
            raise ValueError("Must specify in_size, out_size, width_size, and depth.")

        # Cast to native Python types for safety: parameter values are not
        # guaranteed to arrive as plain ints/bools (e.g. after JSON parsing).
        in_size = int(in_size)
        out_size = int(out_size)
        width_size = int(width_size)
        depth = int(depth)
        use_bias = bool(use_bias)
        use_final_bias = bool(use_final_bias)

        # file_name may come as an empty string through json parsing
        if file_name == "":
            file_name = None

        # TODO: Add more activation functions from
        # https://jax.readthedocs.io/en/latest/jax.nn.html and update the schema
        def _match_activation(activation_str):
            """Map an activation name to its function; unknown names -> identity."""
            # Normalize to str so that non-str (e.g. array-wrapped) names can be
            # looked up in the dict below instead of failing as unhashable.
            activation_str = str(activation_str)
            activation_mapping = {
                "relu": jax.nn.relu,
                "sigmoid": jax.nn.sigmoid,
                "tanh": jnp.tanh,
                "elu": jax.nn.elu,
                "swish": jax.nn.silu,
                "rbf": lambda x: jnp.exp(-(x**2)),
                "identity": lambda x: x,
            }
            if activation_str not in activation_mapping:
                warnings.warn(
                    f"Provided activation function {activation_str} not recognized. "
                    "Using Identity function as activation."
                )
            return activation_mapping.get(activation_str, lambda x: x)

        seed = (
            np.random.randint(0, 2**32, dtype=np.int64) if seed is None else int(seed)
        )
        self.key = random.PRNGKey(seed)

        self.mlp = eqx.nn.MLP(
            in_size,
            out_size,
            width_size,
            depth,
            key=self.key,
            activation=_match_activation(activation_str),
            final_activation=_match_activation(final_activation_str),
            use_bias=use_bias,
            use_final_bias=use_final_bias,
        )

        if file_name is not None:
            # Load serialized leaves into the freshly built model; the model
            # provides the expected tree structure.
            with open(file_name, "rb") as fp:
                self.mlp = eqx.tree_deserialise_leaves(fp, self.mlp)

        # partition into a pytree of params and static components
        mlp_params, self.mlp_static = eqx.partition(self.mlp, eqx.is_array)

        if "mlp_params" in self.dynamic_parameters:
            self.dynamic_parameters["mlp_params"].set(mlp_params)
        else:
            self.declare_dynamic_parameter("mlp_params", mlp_params, as_array=False)

        def _eval_MLP(inputs, **parameters):
            # Recombine the (possibly updated) params with the static structure.
            mlp_params = parameters["mlp_params"]
            mlp = eqx.combine(mlp_params, self.mlp_static)
            return mlp(inputs)

        self.replace_op(_eval_MLP)

    def serialize(self, file_name, mlp_params=None):
        """
        Serialize only the parameters of the MLP. Note that the hyperparameters
        representing the architecture of the MLP are not serialized. This is because
        of the following envisioned use cases:
        (i) The user may train the Equinox MLP outside of Collimator. In this case,
        it seems unnecessary to force the user to serialize the hyperparameters of the
        MLP in the strict form chosen by Collimator. It would seem much easier
        for the user to just input these hyperparameters when creating the MLP block
        in Collimator UI, and upload the naturally produced serialized parameters file
        by Equinox.
        (ii) The user may want to train the Equinox MLP within Collimator in a notebook,
        and then use the block within Collimator UI. In this case, while serialization of
        the hyperparameters of the MLP would be a little more convenient compared
        to manually inputting the hyperparameters in the UI, it seems like a small
        convenience relative to the disadvantages of (i). Ideally the user should be
        able to use the API to push the learnt parameters.
        (iii) When we support training in the UI, the hyperparameters are naturally
        serialized with `declare_configuraton_parameters`, and thus, in this case too,
        only serialization of the MLP parameters is necessary.

        The optional `mlp_params` argument makes it possible to train models in
        a notebook and easily serialize them for use in the UI. When it is None,
        the model built in `initialize` (`self.mlp`) is written instead.
        """
        # Ensures the block has been initialized so `self.mlp` exists
        # (NOTE(review): inferred from usage; confirm create_context's role).
        self.create_context()

        if mlp_params is None:
            mlp = self.mlp
        else:
            mlp = eqx.combine(mlp_params, self.mlp_static)
        with open(file_name, "wb") as f:
            eqx.tree_serialise_leaves(f, mlp)

__init__(in_size=None, out_size=None, width_size=None, depth=None, seed=None, activation_str='relu', final_activation_str='identity', use_bias=True, use_final_bias=True, file_name=None, **kwargs)

see https://docs.kidger.site/equinox/examples/serialisation/ for rationale of implementation here. We can't serialize the activation function, so we serialize a string representing a selection for activation function amongst a finite set of options.

Source code in collimator/library/nn.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
@parameters(
    static=[
        "in_size",
        "out_size",
        "width_size",
        "depth",
        "seed",
        "activation_str",
        "final_activation_str",
        "use_bias",
        "use_final_bias",
        "file_name",
    ],
)
def __init__(
    self,
    in_size=None,
    out_size=None,
    width_size=None,
    depth=None,
    seed=None,
    activation_str="relu",
    final_activation_str="identity",
    use_bias=True,
    use_final_bias=True,
    file_name=None,
    **kwargs,
):
    """
    Create the MLP block.

    Only the block shell is constructed here: `None` is passed as the
    operation to the parent constructor. The hyperparameters are declared as
    static parameters through the `@parameters` decorator; the network is
    presumably built from them later (in `initialize` on the class).

    See https://docs.kidger.site/equinox/examples/serialisation/ for the
    rationale behind this implementation. Activation functions cannot be
    serialized, so each one is selected by a string naming one of a finite
    set of options.
    """
    super().__init__(None, **kwargs)

serialize(file_name, mlp_params=None)

Serialize only the parameters of the MLP. Note that the hyperparameters representing the architecture of the MLP are not serialized. This is because of the following envisioned use cases: (i) The user may train the Equinox MLP outside of Collimator. In this case, it seems unnecessary to force the user to serialize the hyperparameters of the MLP in the strict form chosen by Collimator. It would seem much easier for the user to just input these hyperparameters when creating the MLP block in Collimator UI, and upload the naturally produced serialized parameters file by Equinox. (ii) The user may want to train the Equinox MLP within Collimator in a notebook, and then use the block within Collimator UI. In this case, while serialization of the hyperparameters of the MLP would be a little more convenient compared to manually inputting the hyperparameters in the UI, it seems like a small convenience relative to the disadvantages of (i). Ideally the user should be able to use the API to push the learnt parameters. (iii) When we support training in the UI, the hyperparameters are naturally serialized with declare_configuraton_parameters, and thus, in this case too, only serialization of the MLP parameters is necessary.

The optional mlp_params argument makes it possible to train models in a notebook and easily serialize them for use in the UI.

Source code in collimator/library/nn.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def serialize(self, file_name, mlp_params=None):
    """Write the MLP's parameter leaves to `file_name` using Equinox.

    Only the learnable parameters are written; the architecture
    hyperparameters are deliberately left out. The envisioned use cases:

    (i) The MLP is trained with Equinox outside of Collimator. Forcing the
        user to also serialize the hyperparameters in a Collimator-specific
        form seems unnecessary; it is easier to enter them when creating the
        MLP block in the Collimator UI and upload the parameter file Equinox
        naturally produces.
    (ii) The MLP is trained within Collimator in a notebook and then used in
        the Collimator UI. Serializing the hyperparameters would be only a
        little more convenient than typing them into the UI, a small gain
        compared with the disadvantages described in (i). Ideally the user
        pushes the learnt parameters through the API.
    (iii) Once training is supported in the UI, the hyperparameters are
        naturally serialized with `declare_configuraton_parameters`, so again
        only the MLP parameters need serializing.

    Args:
        file_name: Destination path, opened in binary write mode.
        mlp_params: Optional parameter pytree (e.g. freshly trained in a
            notebook) to combine with the block's static structure before
            writing. When None, `self.mlp` is written as is.
    """
    self.create_context()

    model = (
        self.mlp
        if mlp_params is None
        else eqx.combine(mlp_params, self.mlp_static)
    )
    with open(file_name, "wb") as stream:
        eqx.tree_serialise_leaves(stream, model)

MatrixConcatenation

Bases: ReduceBlock

Concatenate two matrices along a given axis.

Dispatches to jax.numpy.concatenate, so see the JAX docs for details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.concatenate.html

Parameters:

Name Type Description Default
axis

The axis along which the matrices are concatenated. 0 for vertical and 1 for horizontal. Default is 0.

0
Input ports

(0, 1) The input matrices A and B

Output ports

(0) The concatenation of the input matrices, e.g. [A,B].

Source code in collimator/library/primitives.py
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
class MatrixConcatenation(ReduceBlock):
    """Join two matrices along a chosen axis.

    The concatenation is delegated to `cnp.concatenate`, which mirrors
    `jax.numpy.concatenate`; see the JAX docs for details:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.concatenate.html

    Args:
        axis: The axis along which the matrices are joined: 0 stacks them
            vertically and 1 horizontally. Default is 0.

    Input ports:
        (0, 1) The input matrices `A` and `B`.

    Output ports:
        (0) The concatenation of the input matrices, e.g. `[A,B]`.
    """

    @parameters(static=["axis"])
    def __init__(self, n_in=2, axis=0, **kwargs):
        # Exactly two operands are supported.
        if n_in != 2:
            raise ValueError(
                "MatrixConcatenation block only supports two input matrices."
            )
        # No op yet: `initialize` installs it once `axis` is known.
        super().__init__(2, None, **kwargs)

    def initialize(self, axis):
        def _concat_pair(operands):
            first, second = operands[0], operands[1]
            return cnp.concatenate((first, second), axis=int(axis))

        self.replace_op(_concat_pair)

MatrixInversion

Bases: FeedthroughBlock

Compute the matrix inverse of the input signal.

Dispatches to jax.numpy.linalg.inv, so see the JAX docs for details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.inv.html

Input ports

(0) The input matrix.

Output ports

(0) The inverse of the input matrix.

Source code in collimator/library/primitives.py
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
class MatrixInversion(FeedthroughBlock):
    """Compute the matrix inverse of the input signal.

    Dispatches to `jax.numpy.linalg.inv` (through `cnp.linalg.inv`), so see
    the JAX docs for details:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.inv.html

    Input ports:
        (0) The input matrix.

    Output ports:
        (0) The inverse of the input matrix.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(cnp.linalg.inv, *args, **kwargs)

MatrixMultiplication

Bases: ReduceBlock

Compute the matrix product of the input signals.

Dispatches to jax.numpy.matmul, so see the JAX docs for details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.matmul.html

Input ports

(0, 1) The input matrices A and B

Output ports

(0) The matrix product of the input matrices: A @ B.

Source code in collimator/library/primitives.py
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
class MatrixMultiplication(ReduceBlock):
    """Multiply the two input matrices.

    The product is delegated to `cnp.matmul`, which mirrors
    `jax.numpy.matmul`; see the JAX docs for details:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.matmul.html

    Input ports:
        (0, 1) The input matrices `A` and `B`.

    Output ports:
        (0) The matrix product of the input matrices: `A @ B`.
    """

    def __init__(
        self,
        n_in=2,
        **kwargs,
    ):
        # The product is only defined here for exactly two operands.
        if n_in != 2:
            raise ValueError(
                "MatrixMultiplication block only supports two input signals."
            )

        def _matmul_pair(operands):
            lhs, rhs = operands[0], operands[1]
            return cnp.matmul(lhs, rhs)

        super().__init__(2, _matmul_pair, **kwargs)

MatrixTransposition

Bases: FeedthroughBlock

Compute the matrix transpose of the input signal.

Dispatches to jax.numpy.transpose, so see the JAX docs for details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.transpose.html

Input ports

(0) The input matrix.

Output ports

(0) The transpose of the input matrix.

Source code in collimator/library/primitives.py
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
class MatrixTransposition(FeedthroughBlock):
    """Transpose the input matrix.

    The work is delegated to `cnp.transpose`, which mirrors
    `jax.numpy.transpose`; see the JAX docs for details:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.transpose.html

    Input ports:
        (0) The input matrix.

    Output ports:
        (0) The transpose of the input matrix.
    """

    def __init__(self, *args, **kwargs):
        transpose_op = cnp.transpose
        super().__init__(transpose_op, *args, **kwargs)

MinMax

Bases: ReduceBlock

Return the extremum of the input signals.

Input ports

(0..n_in-1) The input signals.

Output ports

(0) The minimum or maximum of the input signals.

Parameters:

Name Type Description Default
operator

One of "min" or "max". Determines whether the block returns the minimum or maximum of the input signals.

required
Events

An event is triggered when the extreme input signal changes. For example, if the block is configured as a "max" block with two inputs and the second signal becomes greater than the first, a zero-crossing event will be triggered.

Source code in collimator/library/primitives.py
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
class MinMax(ReduceBlock):
    """Return the extremum of the input signals.

    Input ports:
        (0..n_in-1) The input signals.

    Output ports:
        (0) The minimum or maximum of the input signals.

    Parameters:
        operator:
            One of "min" or "max". Determines whether the block returns the minimum
            or maximum of the input signals.

    Events:
        An event is triggered when the extreme input signal changes. For example,
        if the block is configured as a "max" block with two inputs and the second
        signal becomes greater than the first, a zero-crossing event will be triggered.
        For efficiency the event is only declared when `is_discontinuity` reports
        that the output port requires it.
    """

    @parameters(static=["operator"])
    def __init__(self, n_in, operator, **kwargs):
        # The reduction op is installed later by `initialize`.
        super().__init__(n_in, None, **kwargs)

    def initialize(self, operator):
        """Select the min/max reduction and its matching zero-crossing guard.

        Raises:
            BlockParameterError: if `operator` is neither "min" nor "max".
        """
        func_lookup = {
            "max": self._max,
            "min": self._min,
        }
        if operator not in func_lookup:
            # NOTE(review): `system=self` is not passed here. The historical
            # reason (this check used to run before `super().__init__()`) no
            # longer obviously applies now that it lives in `initialize`.
            raise BlockParameterError(
                message=f"MinMax block {self.name} has invalid selection {operator} for 'operator'. Valid options: "
                + ", ".join(func_lookup),
                parameter_name="operator",
            )

        self.operator = operator

        self.replace_op(func_lookup[operator])

        guard_lookup = {
            "max": self._max_guard,
            "min": self._min_guard,
        }

        self._guard = guard_lookup[operator]

    def _min(self, inputs):
        return cnp.min(cnp.array(inputs))

    def _max(self, inputs):
        return cnp.max(cnp.array(inputs))

    def _min_guard(self, _time, _state, *inputs, **_params):
        # Index of the currently smallest input (as float); a change in this
        # value is what the "edge_detection" zero-crossing reacts to.
        return cnp.argmin(cnp.array(inputs)).astype(float)

    def _max_guard(self, _time, _state, *inputs, **_params):
        # Index of the currently largest input (as float); see `_min_guard`.
        return cnp.argmax(cnp.array(inputs)).astype(float)

    def initialize_static_data(self, context):
        # Add a zero-crossing event so ODE solvers can't try to integrate
        # through a discontinuity. For efficiency, only do this if the output
        # is fed to an ODE block, as reported by `is_discontinuity`. (The port
        # object itself is always truthy, so it must not be tested directly.)
        if not self.has_zero_crossing_events and is_discontinuity(
            self.output_ports[0]
        ):
            self.declare_zero_crossing(self._guard, direction="edge_detection")

        return super().initialize_static_data(context)

ModelicaFMU

Bases: LeafSystem

Source code in collimator/library/fmu_import.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
class ModelicaFMU(LeafSystem):
    # Should we pass parameter overrides via kwargs? Sounds like it could conflict
    # in some rare cases (eg. dt, name...). The corresponding definition in
    # block_interface.py is pretty fragile in this regard.
    def __init__(
        self,
        file_name,
        dt,
        name=None,
        input_names: list[str] = None,
        output_names: list[str] = None,
        parameters: dict = None,
        start_time: float = 0.0,
        **kwargs,
    ):
        """Load and execute an FMU for Co-Simulation.

        Args:
            file_name (str): path to FMU file
            dt (float): stepsize for FMU simulation
            name (str, optional): name of block
            input_names (list[str], optional): if set, only expose these inputs
            output_names (list[str], optional): if set, only expose these outputs
            parameters (dict, optional): dictionary of parameter overrides
            kwargs: ignored
        """
        try:
            super().__init__(name=name)
            self._init(
                file_name,
                dt,
                name=name or f"fmu_{self.system_id}",
                input_names=input_names,
                output_names=output_names,
                parameters=parameters,
                start_time=start_time,
            )
        except Exception as e:
            logger.error(
                "Failed to initialize FMU block %s (%s): %s", name, self.system_id, e
            )
            raise BlockInitializationError(str(e), system=self)

    @parameters(static=["file_name"])
    def _init(
        self,
        file_name,
        dt,
        name: str,
        input_names: list[str] = None,
        output_names: list[str] = None,
        parameters: dict = None,
        start_time: float = 0.0,
    ):
        self.dt = dt

        # read the model description
        model_description = fmpy.read_model_description(file_name)

        # extract the FMU
        unzipdir = fmpy.extract(file_name)

        self.fmu = fmu = fmi2.FMU2Slave(
            guid=model_description.guid,
            unzipDirectory=unzipdir,
            modelIdentifier=model_description.coSimulation.modelIdentifier,
            instanceName=name,
        )

        # initialize
        fmu.instantiate()
        # setup and set startTime before entering initialization mode per FMI 2.0.4 section 2.1.6.
        fmu.setupExperiment(startTime=start_time)
        # enter initialization mode before get/set params per FMI 2.0.4 section 4.2.4.
        fmu.enterInitializationMode()

        # collect the value references
        self.fmu_inputs: list[ValueReference] = []
        self.fmu_outputs: list[ValueReference] = []

        inputs_by_name: dict[str, ScalarVariable] = {}
        outputs_by_name: dict[str, ScalarVariable] = {}
        variable_by_id: dict[int, ScalarVariable] = {}

        # FIXME: we rely on the XML file here, but collimator uses a similar
        # JSON file with altered variable names.
        # TODO: implement support for parsing that file and mapping from
        # collimator json name to/from xml name properly.
        def _compatible_param_name(name):
            return name.replace(".", "_")

        for variable in model_description.modelVariables:
            if variable.causality == "input":
                variable_by_id[variable.valueReference] = variable
                inputs_by_name[variable.name] = variable
            elif variable.causality == "output":
                variable_by_id[variable.valueReference] = variable
                outputs_by_name[variable.name] = variable
            elif variable.causality == "parameter" and parameters is not None:
                compat_name = _compatible_param_name(variable.name)
                parameter_value = parameters.get(compat_name, None)
                if parameter_value is None:
                    continue

                logger.debug(
                    "Setting parameter #%d '%s' <%s>: %s %s",
                    variable.valueReference,
                    variable.name,
                    variable.type,
                    parameter_value,
                    type(parameter_value),
                )

                # Values at this point have been wrapped into np.ndarray of
                # shape () via wildcat's JSON parsing. Enumerations are ints.
                match variable.type:
                    case "Boolean":
                        parameter_value = bool(parameter_value)
                        fmu.setBoolean([variable.valueReference], [parameter_value])
                    case "Integer":
                        parameter_value = int(parameter_value)
                        fmu.setInteger([variable.valueReference], [parameter_value])
                    case "Real":
                        parameter_value = float(parameter_value)
                        fmu.setReal([variable.valueReference], [parameter_value])
                    case "String":
                        parameter_value = str(parameter_value)
                        fmu.setString([variable.valueReference], [parameter_value])
                    case "Enumeration":
                        parameter_value = int(parameter_value)
                        fmu.setInteger([variable.valueReference], [parameter_value])
                    case _:
                        # not implemented
                        raise BlockInitializationError(
                            f"Unsupported type for parameter {variable.name} in "
                            + f"FMU block {name}: {variable.type}",
                            system=self,
                        )

        # If input_names or output_names are set, we filter out the variables
        # exposed as I/O ports to match those. This so that the ports in model.json
        # actually match those in the FMU.
        # NOTE: Maybe this is unnecessarily complicated.
        if input_names is not None:
            for name in input_names:
                if name not in inputs_by_name:
                    raise BlockInitializationError(
                        f"Input port {name} found on the block { name} "
                        + f"but not found in FMU {file_name}",
                        system=self,
                    )
                variable = inputs_by_name[name]
                self.fmu_inputs.append(variable.valueReference)
                self.declare_input_port(name=variable.name)
        else:
            for name, variable in inputs_by_name.items():
                self.fmu_inputs.append(variable.valueReference)
                self.declare_input_port(name=name)

        if output_names is not None:
            for name in output_names:
                if name not in outputs_by_name:
                    raise BlockInitializationError(
                        f"Input port {name} found on the block { name} "
                        + f"but not found in FMU {file_name}",
                        system=self,
                    )
                variable = outputs_by_name[name]
                self.fmu_outputs.append(variable.valueReference)
        else:
            for name, variable in outputs_by_name.items():
                self.fmu_outputs.append(variable.valueReference)

        # exit initialization mode after get/set params per FMI 2.0.4 section 4.2.4.
        fmu.exitInitializationMode()

        # Declare a discrete state component for each of the output variables
        self._create_discrete_state_type(fmu, self.fmu_outputs, variable_by_id)

        # Create the default discrete state values
        default_values = {}

        for output_ref in self.fmu_outputs:
            variable = variable_by_id[output_ref]
            match variable.type:
                case "Boolean":
                    start_value = fmu.getBoolean([variable.valueReference])[0]
                case "Integer" | "Enumeration":
                    start_value = fmu.getInteger([variable.valueReference])[0]
                case "Real":
                    start_value = fmu.getReal([variable.valueReference])[0]
                case _:
                    raise NotImplementedError(
                        f"Unsupported type for output port {variable.name} in FMU: {variable.type}"
                    )
            default_values[variable.name] = start_value

        # Map the default values to array-like types so that they have shape and dtype
        default_state = jax.tree_util.tree_map(
            cnp.asarray, self.DiscreteStateType(**default_values)
        )
        self.declare_discrete_state(default_value=default_state, as_array=False)

        # Declare an output port for each of the output variables
        def _make_output_callback(o_port_name):
            def _output(time, state, *inputs, **parameters):
                return getattr(state.discrete_state, o_port_name)

            return _output

        for o_port_name in default_values:
            self.declare_output_port(
                _make_output_callback(o_port_name),
                name=o_port_name,
                prerequisites_of_calc=[DependencyTicket.xd],
                requires_inputs=False,
            )

        # The step function acts as a periodic update that will update all components
        # of the discrete state.
        def _step(time, state, *inputs):
            args = (time, state, *inputs)
            # Use the io_callback so that we can call the untraceable FMU object
            return io_callback(self.exec_step, default_state, *args)

        self.declare_periodic_update(
            _step,
            period=dt,
            offset=dt,
        )

    def _create_discrete_state_type(self, fmu, fmu_outputs, variables):
        self.state_names = [variables[output_ref].name for output_ref in fmu_outputs]
        self.DiscreteStateType = namedtuple("DiscreteState", self.state_names)

    def exec_step(self, time, state, *inputs, **parameters):
        # NOTE: We should get the fmu from the context in order to build a pure
        # function but it is very unlikely this would ever work with FMUs since
        # they have their own internal hidden state. More context here:
        # https://github.com/collimator-ai/collimator/pull/5330/files#r1419062533
        # Also look at that PR to see the previous implementation (it worked with
        # a single I/O port).

        try:
            fmu = self.fmu

            # Note: although it may appear that the order of operations below is
            # backwards, e.g. 1] get_outputs, 2] set_inputs, 3] step, this is
            # actually intentional.
            # Explanation by example assuming 1sec update intervals.
            # The reason get_outputs happens before set_inputs and 'step, is that
            # at t=0, the fmu outputs are already at t=0, so we can just read them.
            # Then, the fmu should get inputs at t=0, and use those to take a step
            # to t=1. The step operation, using inputs at t=0, puts the fmu in a
            # state where it outputs are now at t=1. This we cannot read them until
            # next update interval at t=1.

            # Retrieve the outputs
            fmu_out = fmu.getReal(self.fmu_outputs)
            # Match the outputs with their names in the discrete state
            xd = {name: value for name, value in zip(self.state_names, fmu_out)}

            # Set inputs
            fmu.setReal(self.fmu_inputs, list(inputs))
            # Advance the FMU in time
            fmu.doStep(currentCommunicationPoint=time, communicationStepSize=self.dt)

        except fmi2.FMICallException as e:
            logger.error(
                "Failed to run FMU block %s (%s): %s", self.name, self.system_id, e
            )
            raise BlockRuntimeError(str(e), system=self) from e

        xd = jax.tree_util.tree_map(cnp.asarray, xd)

        return self.DiscreteStateType(**xd)

__init__(file_name, dt, name=None, input_names=None, output_names=None, parameters=None, start_time=0.0, **kwargs)

Load and execute an FMU for Co-Simulation.

Parameters:

Name Type Description Default
file_name str

path to FMU file

required
dt float

stepsize for FMU simulation

required
name str

name of block

None
input_names list[str]

if set, only expose these inputs

None
output_names list[str]

if set, only expose these outputs

None
parameters dict

dictionary of parameter overrides

None
kwargs

ignored

{}
Source code in collimator/library/fmu_import.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def __init__(
    self,
    file_name,
    dt,
    name=None,
    input_names: list[str] = None,
    output_names: list[str] = None,
    parameters: dict = None,
    start_time: float = 0.0,
    **kwargs,
):
    """Load and execute an FMU for Co-Simulation.

    Args:
        file_name (str): path to FMU file
        dt (float): stepsize for FMU simulation
        name (str, optional): name of block; defaults to ``fmu_<system_id>``
        input_names (list[str], optional): if set, only expose these inputs
        output_names (list[str], optional): if set, only expose these outputs
        parameters (dict, optional): dictionary of parameter overrides
        start_time (float, optional): start time forwarded to ``_init``.
            Defaults to 0.0.
        kwargs: ignored

    Raises:
        BlockInitializationError: if base-class construction or FMU setup in
            ``_init`` raises. The original exception is chained as the cause.
    """
    try:
        super().__init__(name=name)
        self._init(
            file_name,
            dt,
            name=name or f"fmu_{self.system_id}",
            input_names=input_names,
            output_names=output_names,
            parameters=parameters,
            start_time=start_time,
        )
    except Exception as e:
        # If the base-class constructor itself failed, `system_id` may not be
        # set yet; don't let an AttributeError here mask the real failure.
        system_id = getattr(self, "system_id", None)
        logger.error(
            "Failed to initialize FMU block %s (%s): %s", name, system_id, e
        )
        # Chain explicitly, matching the FMICallException handling elsewhere
        # in this block type.
        raise BlockInitializationError(str(e), system=self) from e

MuJoCo

Bases: MuJoCoBase

MuJoCo implementation without MJX.

Refer to MJX for the main docs.

Unlike the MJX variant of the block, this version uses the solver provided by mujoco itself and the physics are fully handled by mujoco. This behaves like a Co-Simulation environment.

This variant may be used to speed up compilation times or in situations where full JAX is not available or practical.

Source code in collimator/library/mujoco.py
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
class MuJoCo(MuJoCoBase):
    """MuJoCo implementation without MJX.

    Refer to MJX for the main docs.

    Unlike the MJX variant of the block, this version uses the solver provided
    by mujoco itself and the physics are fully handled by mujoco. This behaves
    like a Co-Simulation environment.

    This variant may be used to speed up compilation times or in situations where
    full JAX is not available or practical.

    This class declares the "qpos", "qvel" and "act" output ports, all with
    period and offset ``dt``. Sensor-data, video, vHIL and custom-script ports
    are declared afterwards depending on the constructor flags. Those helpers
    are not defined in this class (presumably on MuJoCoBase).
    """

    def __init__(
        self,
        file_name: str,
        dt: float = 0.01,
        key_frame_0: int | str = None,
        qpos_0: Array = None,
        qvel_0: Array = None,
        act_0: Array = None,
        enable_sensor_data=False,
        enable_video_output=False,
        video_size: tuple[int, int] = None,
        enable_mocap_pos=False,
        custom_output_scripts: dict[str, str] = None,
        vHIL=False,
        vHIL_dt=0.01,
        **kwargs,
    ):
        # All configuration is forwarded to the base class with MJX disabled.
        super().__init__(
            use_mjx=False,
            file_name=file_name,
            dt=dt,
            key_frame_0=key_frame_0,
            qpos_0=qpos_0,
            qvel_0=qvel_0,
            act_0=act_0,
            enable_sensor_data=enable_sensor_data,
            enable_video_output=enable_video_output,
            video_size=video_size,
            enable_mocap_pos=enable_mocap_pos,
            custom_output_scripts=custom_output_scripts,
            vHIL=vHIL,
            vHIL_dt=vHIL_dt,
            **kwargs,
        )

        # This output cb implements the call to _step and is the reference callback
        # that all other outputs will depend on.
        def _qpos_cb(time, state, *inputs, **parameters):
            # Host-side step, run through io_callback:
            # - write the control input (port 0) into data.ctrl;
            # - if enabled, write mocap positions (port 1);
            # - advance mujoco by one mj_step;
            # - return qpos with its quaternions normalized.
            def cb(inputs):
                self._data.ctrl = inputs[0]
                if enable_mocap_pos:
                    self._data.mocap_pos[:] = inputs[1]
                mujoco.mj_step(self._model, self._data)
                qpos_normalized_quats = self.normalize_qpos_quat(self._qpos())
                return qpos_normalized_quats

            # self.qpos_0 serves as the result shape/dtype template.
            return io_callback(cb, self.qpos_0, inputs)

        # The qpos port index. The other outputs list it as a prerequisite,
        # presumably so the stepping callback is evaluated before they read
        # from self._data.
        self._step_cache_index = self.declare_output_port(
            _qpos_cb,
            default_value=self.qpos_0,
            requires_inputs=True,
            offset=dt,
            period=dt,
            name="qpos",
        )

        # Reads the current velocities from self._data; no stepping here.
        def _qvel_cb(time, state, *inputs, **parameters):
            return io_callback(self._qvel, self.qvel_0)

        self.declare_output_port(
            _qvel_cb,
            default_value=self.qvel_0,
            requires_inputs=True,
            offset=dt,
            period=dt,
            name="qvel",
            prerequisites_of_calc=[self._step_cache_index],
        )

        # Reads the current actuator activations from self._data; no stepping.
        def _act_cb(time, state, *inputs, **parameters):
            return io_callback(self._act, self.act_0)

        self.declare_output_port(
            _act_cb,
            default_value=self.act_0,
            requires_inputs=True,
            offset=dt,
            period=dt,
            name="act",
            prerequisites_of_calc=[self._step_cache_index],
        )

        # Optional outputs, toggled by constructor flags.
        if enable_sensor_data:
            self._declare_sensor_data_port(dt)
        if enable_video_output:
            self._declare_video_output_port(video_size)
        if vHIL:
            self._declare_vhil_fake_output_port(vHIL_dt)
        self._declare_custom_output_ports(
            custom_output_scripts, dt, requires_inputs=False
        )

    # def post_simulation_finalize(self) -> None:
    #     # FIXME this should not be here but I had a "too many files opened" error
    #     self._model = None
    #     self._data = None
    #     return super().post_simulation_finalize()

    # The accessors below return arrays straight from self._data. The `state`
    # argument is accepted but ignored, since the physics state lives in mujoco.
    def _qpos(self, state=None):
        return self._data.qpos

    def _qvel(self, state=None):
        return self._data.qvel

    def _act(self, state=None):
        return self._data.act

    def _sensordata(self):
        return self._data.sensordata

    # Normalizes the quaternion entries of `qpos` via mujoco.mj_normalizeQuat.
    # The array is modified in place, and the same array is returned.
    # NOTE(review): `_qpos()` returns `self._data.qpos` itself. If mujoco exposes
    # that as a view, the step callback normalizes the simulator's qpos in place.
    # Confirm this is intended.
    def normalize_qpos_quat(self, qpos):
        mujoco.mj_normalizeQuat(self._model, qpos)
        return qpos

    # Renders `data`, defaulting to this block's live self._data.
    def render(self, data=None):
        if data is None:
            data = self._data
        return super().render(data)

    # Video output callback. The inner cb ignores `time` and renders the
    # current self._data. self._video_default is the result shape template.
    def _output_video(self, time, state, *inputs, **parameters):
        def cb(time):
            return self.render(self._data)

        return io_callback(
            cb,
            self._video_default,
            time,
        )

Multiplexer

Bases: ReduceBlock

Stack the input signals into a single output signal.

Dispatches to jax.numpy.hstack, so see the JAX docs for details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.hstack.html

Input ports

(0..n_in-1) The input signals.

Output ports

(0) The stacked output signal.

Source code in collimator/library/primitives.py
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
class Multiplexer(ReduceBlock):
    """Combine all input signals into one output by stacking them.

    The reduction operation is `jax.numpy.hstack`; see the JAX documentation
    for the exact stacking rules:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.hstack.html

    Input ports:
        (0..n_in-1) The input signals.

    Output ports:
        (0) The stacked output signal.
    """

    def __init__(self, n_in, *args, **kwargs):
        stack_op = cnp.hstack
        super().__init__(n_in, stack_op, *args, **kwargs)

Offset

Bases: FeedthroughBlock

Add a constant offset or bias to the input signal.

Given an input signal u and offset value b, this will return y = u + b.

Input ports

(0) The input signal.

Output ports

(0) The input signal plus the offset.

Parameters:

Name Type Description Default
offset

The constant offset to add to the input signal.

required
Source code in collimator/library/primitives.py
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
class Offset(FeedthroughBlock):
    """Shift the input signal by a constant bias.

    For input `u` and offset `b`, the block outputs `y = u + b`.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The input signal plus the offset.

    Parameters:
        offset:
            The constant offset to add to the input signal.
    """

    @parameters(dynamic=["offset"])
    def __init__(self, offset, *args, **kwargs):
        super().__init__(self._add_offset, *args, **kwargs)

    @staticmethod
    def _add_offset(x, offset):
        # The feedthrough function receives the offset value as an argument.
        return x + offset

    def initialize(self, offset):
        """No configuration needed: the offset reaches `_add_offset` as an argument."""

PID

Bases: LTISystem

Continuous-time PID controller.

The PID controller is implemented as a state-space system with matrices (A, B, C, D), which are then used to create a (second-order) LTISystem. Note that this only supports single-input, single-output PID controllers.

The PID controller implements the following control law:

    u = kp * e + ki * ∫e + kd * ė

where e is the error signal, and ∫e and ė are the integral and derivative of the error signal, respectively.

With a filter coefficient of n (to make the transfer function proper), the state-space form of the system is:

A = [[0, 1], [0, -n]]
B = [[0], [1]]
C = [[ki * n, (kp * n + ki) - (kp + kd * n) * n]]
D = [[kp + kd * n]]

Since D is nonzero, the block is feedthrough.

Input ports

(0) e: Error signal (scalar)

Output ports

(0) u: Control signal (scalar)

Parameters:

Name Type Description Default
kp

Proportional gain

required
ki

Integral gain

required
kd

Derivative gain

required
n

Derivative filter coefficient

required
initial_state

Initial state of the integral term (default: 0)

0.0
Source code in collimator/library/linear_system.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
class PID(LTISystem):
    """Continuous-time PID controller.

    The controller is realized as a second-order LTISystem built from the
    state-space matrices (A, B, C, D) below. Only single-input, single-output
    operation is supported.

    Control law:
    ```
        u = kp * e + ki * ∫e + kd * ė
    ```
    where e is the error signal, and ∫e and ė are the integral and derivative
    of the error signal, respectively.

    A derivative filter coefficient `n` keeps the transfer function proper,
    which gives the state-space form:
    ```
    A = [[0, 1], [0, -n]]
    B = [[0], [1]]
    C = [[ki * n, (kp * n + ki) - (kp + kd * n) * n]]
    D = [[kp + kd * n]]
    ```

    Because D is nonzero, the block is feedthrough.

    Input ports:
        (0) e: Error signal (scalar)

    Output ports:
        (0) u: Control signal (scalar)

    Parameters:
        kp: Proportional gain
        ki: Integral gain
        kd: Derivative gain
        n: Derivative filter coefficient
        initial_state: Initial value of the first state component, which
            integrates the second state (default: 0)
    """

    @parameters(
        dynamic=["kp", "ki", "kd", "n"],
        static=["initial_state"],
    )
    def __init__(
        self,
        kp,
        ki,
        kd,
        n,
        initial_state=0.0,
        enable_external_initial_state=False,
        **kwargs,
    ):
        if enable_external_initial_state:
            raise NotImplementedError(
                "External initial state not yet implemented for PID"
            )

        abcd = self._get_abcd(kp, ki, kd, n)
        x0 = cnp.array([initial_state, 0.0])
        super().__init__(*abcd, initialize_states=x0, **kwargs)

    def _get_abcd(self, kp, ki, kd, n):
        """Build the controllable-canonical (A, B, C, D) matrices for the gains."""
        direct_gain = kp + kd * n
        A = cnp.array([[0.0, 1.0], [0.0, -n]])
        B = cnp.array([[0.0], [1.0]])
        C = cnp.array([ki * n, (kp * n + ki) - direct_gain * n])
        D = cnp.array([direct_gain])
        return A, B, C, D

    def _pid_update_matrices(self, params):
        """Rebuild the matrices from the gains in `params` and store them on self."""
        gains = (params["kp"], params["ki"], params["kd"], params["n"])
        # NOTE(review): this writes per-call (possibly traced) values onto the
        # instance from inside evaluation callbacks. It is kept as-is because
        # other LTISystem code may read self.A..self.p; confirm before changing.
        (self.A, self.B, self.C, self.D, self.n, self.m, self.p) = _reshape(
            *self._get_abcd(*gains)
        )

    def _eval_output(self, time, state, *inputs, **params):
        self._pid_update_matrices(params)
        return self._eval_output_base(self.C, self.D, state, *inputs)

    def ode(self, time, state, u, **params):
        self._pid_update_matrices(params)
        return super().ode(time, state, u, A=self.A, B=self.B)

    def initialize(self, kp, ki, kd, n, initial_state, **kwargs):
        x0 = cnp.array([initial_state, 0.0])
        self._init_state(*self._get_abcd(kp, ki, kd, n), x0)

PIDDiscrete

Bases: LeafSystem

Discrete-time PID controller.

This block implements a discrete-time PID controller with a first-order approximation to the integrated error and an optional derivative filter. The integrated error term is computed as:

    e_int[k+1] = e_int[k] + e[k] * dt

where e is the error signal and dt is the sampling period. The derivative term is computed in the same way as for the DerivativeDiscrete block, including filter options described there. With the running error integral e_int and current estimate of the time derivative of the error e_dot, the output is:

    u[k] = kp * e[k] + ki * e_int[k] + kd * e_dot[k]

Input ports

(0) The error signal.

Output ports

(0) The control signal computed by the PID algorithm.

Parameters:

Name Type Description Default
kp

The proportional gain (scalar)

1.0
ki

The integral gain (scalar)

1.0
kd

The derivative gain (scalar)

1.0
dt

The sampling period of the block.

required
initial_state

The initial value of the running error integral. Default is 0.

0.0
enable_external_initial_state

Source for the value used for the integrator initial state. True=from inport, False=from the initial_state parameter.

False
filter_type

One of "none", "forward", "backward", or "bilinear". Determines the type of filter used to estimate the derivative of the error signal. Default is "none". See DerivativeDiscrete documentation for details.

'none'
filter_coefficient

The filter coefficient for the derivative filter. Default is 1.0. See DerivativeDiscrete documentation for details.

1.0
Source code in collimator/library/primitives.py
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433
2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
2540
2541
2542
2543
2544
2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
2579
2580
2581
2582
2583
2584
2585
2586
2587
2588
2589
2590
2591
2592
2593
2594
2595
2596
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
class PIDDiscrete(LeafSystem):
    """Discrete-time PID controller.

    This block implements a discrete-time PID controller with a first-order
    approximation to the integrated error and an optional derivative filter.
    The integrated error term is computed as:
    ```
        e_int[k+1] = e_int[k] + e[k] * dt
    ```
    where `e` is the error signal and `dt` is the sampling period.  The derivative
    term is computed in the same way as for the DerivativeDiscrete block, including
    filter options described there.  With the running error integral `e_int` and
    current estimate of the time derivative of the error `e_dot`, the output is:
    ```
        u[k] = kp * e[k] + ki * e_int[k] + kd * e_dot[k]
    ```

    Input ports:
        (0) The error signal.

    Output ports:
        (0) The control signal computed by the PID algorithm.

    Parameters:
        kp:
            The proportional gain (scalar)
        ki:
            The integral gain (scalar)
        kd:
            The derivative gain (scalar)
        dt:
            The sampling period of the block.
        initial_state:
            The initial value of the running error integral.  Default is 0.
        enable_external_initial_state:
            Source for the value used for the integrator initial state. True=from inport,
            False=from the initial_state parameter.
        filter_type:
            One of "none", "forward", "backward", or "bilinear".  Determines the type of
            filter used to estimate the derivative of the error signal.  Default is
            "none".  See DerivativeDiscrete documentation for details.
        filter_coefficient:
            The filter coefficient for the derivative filter.  Default is 1.0.  See
            DerivativeDiscrete documentation for details.
    """

    class DiscreteStateType(NamedTuple):
        # Running error integral e_int[k].
        integral: Array
        # Recursive filter memory for the derivative estimate:
        # previous error sample e[k-1].
        e_prev: Array
        # Previous derivative estimate e_dot[k-1]. It is only read when a
        # recursive filter is configured; otherwise it is carried through unchanged.
        e_dot_prev: Array

    @parameters(
        static=["filter_type", "filter_coefficient"],
        dynamic=["kp", "ki", "kd", "initial_state"],
    )
    def __init__(
        self,
        dt,
        kp=1.0,
        ki=1.0,
        kd=1.0,
        initial_state=0.0,
        enable_external_initial_state=False,
        filter_type="none",
        filter_coefficient=1.0,
        **kwargs,
    ):
        """Declare ports and the periodic update; `initialize` configures them.

        `dt` is stored and later used as both the update period and the output
        period. If `enable_external_initial_state` is True, a second input port
        is declared to supply the initial value of the integral.
        """
        super().__init__(**kwargs)
        self.dt = dt
        self.input_index = self.declare_input_port()

        self.enable_external_initial_state = enable_external_initial_state
        self.initial_state_index = None
        if enable_external_initial_state:
            self.initial_state_index = self.declare_input_port()

        # Declare the periodic update
        self._periodic_update_idx = self.declare_periodic_update()

        # Declare an output port for the control signal
        self.control_output = self.declare_output_port()

        # NOTE:
        # An extra output port for the derivative value is not strictly necessary,
        # but the filtered estimate could be reused elsewhere.  Also, having the
        # previous value saved in the discrete output component of state would allow
        # it to be reused in the recursive filter without recomputing it as part of
        # the update step, a minor efficiency gain.  The tradeoff is an extra event
        # that has to be handled.  This implementation uses one output event and
        # re-does the derivative calculation when a recursive filter is used, but
        # we could always do it the other way in the future.

    def initialize(
        self,
        kp,
        ki,
        kd,
        initial_state,
        filter_type,
        filter_coefficient,
    ):
        """Configure state, update event, derivative filter and output port.

        `initial_state` seeds the integral state and is also used as the output
        port's default value. The error history starts at zero.
        """
        # Declare an internal discrete state
        self.declare_discrete_state(
            default_value=self.DiscreteStateType(
                integral=initial_state,
                e_prev=0.0,
                e_dot_prev=0.0,
            ),
            as_array=False,
        )

        self.configure_periodic_update(
            self._periodic_update_idx,
            self._update,
            period=self.dt,
            offset=0.0,
        )

        # Determine the coefficients of the filter, if applicable.
        # The filter is a pair (b, a) of two-element arrays, and the filter
        # equation is:
        # a0*y[k] + a1*y[k-1] = b0*u[k] + b1*u[k-1]
        self.filter_type = filter_type
        self.filter = derivative_filter(
            N=filter_coefficient, dt=self.dt, filter_type=filter_type
        )

        self.configure_output_port(
            self.control_output,
            self._output,
            period=self.dt,
            offset=0.0,
            default_value=initial_state,
            prerequisites_of_calc=[DependencyTicket.xd, self.input_ports[0].ticket],
        )

    def reset_default_values(self, **dynamic_parameters):
        """Re-seed the default state and output from `initial_state`.

        The error history (`e_prev`, `e_dot_prev`) is reset to zero.
        """
        self.configure_discrete_state_default_value(
            self.DiscreteStateType(
                integral=dynamic_parameters["initial_state"],
                e_prev=0.0,
                e_dot_prev=0.0,
            ),
            as_array=False,
        )
        self.configure_output_port_default_value(
            self.control_output, dynamic_parameters["initial_state"]
        )

    def _eval_derivative(self, _time, state, *inputs, **_params):
        """Return the derivative estimate e_dot[k] for the current error sample.

        With IIR coefficients ``(b, a) = self.filter``:
          - recursive filters:
            ``e_dot[k] = (b0*e[k] + b1*e[k-1] - a1*e_dot[k-1]) / a0``
          - ``filter_type == "none"``:
            ``e_dot[k] = (b0*e[k] + b1*e[k-1]) / a0``
        """
        # Filtered derivative estimate

        e = inputs[self.input_index]  # Error signal from upstream
        e_prev = state.discrete_state.e_prev
        b, a = self.filter  # IIR filter coefficients

        # If the filter is recursive we need to reuse the previous derivative
        # estimate.
        if self.filter_type != "none":
            # Filtered estimate of the time derivative
            e_dot_prev = state.discrete_state.e_dot_prev

            # New estimate of the time derivative of the error signal
            e_dot = (b[0] * e + b[1] * e_prev - a[1] * e_dot_prev) / a[0]

        else:
            # Standard finite difference approximation - no recursion
            e_dot = (b[0] * e + b[1] * e_prev) / a[0]

        return e_dot

    def _update(self, time, state, *inputs, **params):
        """Periodic update: advance the integral and shift the error history.

        The integral uses a forward-Euler step, ``e_int + e * dt``.
        """
        e = inputs[self.input_index]  # Error signal from upstream

        # Integrated error signal
        e_int = state.discrete_state.integral

        # Update the derivative estimate if needed for a recursive filter.
        if self.filter_type != "none":
            e_dot = self._eval_derivative(time, state, *inputs, **params)
        else:
            # This state entry isn't used for the finite difference estimator.
            # Can just keep the original value as a placeholder.
            e_dot = state.discrete_state.e_dot_prev

        # Update the internal state
        return self.DiscreteStateType(
            integral=e_int + e * self.dt, e_prev=e, e_dot_prev=e_dot
        )

    def _eval_control(self, e, e_int, e_dot, **params):
        """Apply the PID law ``u = kp*e + ki*e_int + kd*e_dot``."""
        # Calculate the control signal for the PID control law
        kp, ki, kd = params["kp"], params["ki"], params["kd"]
        u = kp * e + ki * e_int + kd * e_dot
        return u

    def _output(self, time, state, *inputs, **params):
        """Compute the control output.

        Uses the current error, the stored (pre-update) integral, and a freshly
        computed derivative estimate.
        """
        e = inputs[self.input_index]  # Error signal from upstream
        e_int = state.discrete_state.integral
        e_dot = self._eval_derivative(time, state, *inputs, **params)
        return self._eval_control(e, e_int, e_dot, **params)

    def check_types(
        self,
        context,
        error_collector: ErrorCollector = None,
    ):
        """Check the input signal against the integral state via `check_state_type`."""
        u = self.eval_input(context)
        xd = context[self.system_id].discrete_state.integral
        check_state_type(
            self,
            inp_data=u,
            state_data=xd,
            error_collector=error_collector,
        )

    def initialize_static_data(self, context):
        """Set the initial state from the input port, if specified via config"""
        if self.initial_state_index is not None:
            try:
                # Seed the integral from the external port; the error history
                # starts at zero. Both the stored default and this block's
                # subcontext are updated.
                initial_state = self.eval_input(context, self.initial_state_index)
                default_value = self.DiscreteStateType(
                    integral=initial_state,
                    e_prev=0.0,
                    e_dot_prev=0.0,
                )
                self._default_discrete_state = default_value
                local_context = context[self.system_id].with_discrete_state(
                    default_value
                )
                context = context.with_subcontext(self.system_id, local_context)

            except UpstreamEvalError:
                # The diagram has only been partially created.  Defer the
                # inference of the initial state until the upstream block has been
                # connected.
                logger.debug(
                    "PID_Discrete.initialize_static_data: UpstreamEvalError. "
                    "Continuing without default value initialization."
                )
        return super().initialize_static_data(context)

initialize_static_data(context)

Set the initial state from the input port, if specified via config

Source code in collimator/library/primitives.py
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
def initialize_static_data(self, context):
    """Seed the integral state from the initial-state input port, when enabled.

    This only acts when an external initial-state port was declared
    (``self.initial_state_index`` is not None). In that case:

    - The port value becomes the ``integral`` of a new default discrete state,
      with ``e_prev`` and ``e_dot_prev`` set to zero.
    - That default is stored on the block and written into this block's
      subcontext.
    - If the upstream value cannot be evaluated yet, a debug message is logged
      and the context is left unchanged.

    In every case it finishes by delegating to the parent implementation.
    """
    port_idx = self.initial_state_index
    if port_idx is None:
        return super().initialize_static_data(context)

    try:
        seed = self.eval_input(context, port_idx)
        seeded_state = self.DiscreteStateType(
            integral=seed,
            e_prev=0.0,
            e_dot_prev=0.0,
        )
        self._default_discrete_state = seeded_state
        updated_sub = context[self.system_id].with_discrete_state(seeded_state)
        context = context.with_subcontext(self.system_id, updated_sub)
    except UpstreamEvalError:
        # The diagram is only partially built (upstream block not connected
        # yet), so seeding is deferred.
        logger.debug(
            "PID_Discrete.initialize_static_data: UpstreamEvalError. "
            "Continuing without default value initialization."
        )

    return super().initialize_static_data(context)

Power

Bases: FeedthroughBlock

Raise the input signal to a constant power.

Dispatches to jax.numpy.power, so see the JAX docs for details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.power.html

For input signal u with exponent p, the output will be y = u ** p.

Input ports

(0) The input signal.

Output ports

(0) The input signal raised to the power of the exponent.

Parameters:

Name Type Description Default
exponent

The exponent to which the input signal is raised.

required
Source code in collimator/library/primitives.py
2633
2634
2635
2636
2637
2638
2639
2640
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
2652
2653
2654
2655
2656
2657
2658
2659
2660
2661
2662
2663
2664
2665
2666
2667
2668
2669
2670
2671
2672
class Power(FeedthroughBlock):
    """Raise the input signal to a fixed exponent.

    For input `u` and exponent `p`, the output is `y = u ** p`. The semantics
    follow `jax.numpy.power`; see the JAX docs for details:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.power.html

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The input signal raised to the power of the exponent.

    Parameters:
        exponent:
            The exponent to which the input signal is raised.
    """

    @parameters(static=["exponent"])
    def __init__(self, exponent, **kwargs):
        super().__init__(self._func, **kwargs)

        # The exponent is a static configuration value, not a context
        # parameter, so it is not differentiable. For `y = u ** p`, the
        # linearization with respect to `p` is `dy = y * log(u) * dp`. The
        # log(u) factor yields NaN during backprop whenever the input is
        # non-positive. Supporting gradients here would need special handling
        # of that log term (e.g. gradient clipping). Tracked in WC-306.
        self.exponent = exponent

    def initialize(self, exponent):
        self.exponent = exponent

    def _func(self, *inputs, **params):
        [signal] = inputs
        return signal**self.exponent

Product

Bases: ReduceBlock

Compute the product and/or quotient of the input signals.

The block will multiply or divide the input signals, depending on the specified operators. For example, if the block has three inputs u1, u2, and u3 and is configured with operators="**/", then the output signal will be y = u1 * u2 / u3. By default, the block will multiply all of the input signals.

Input ports

(0..n_in-1) The input signals.

Output ports

(0) The product and/or quotient of the input signals.

Parameters:

Name Type Description Default
n_in

The number of input ports.

required
operators

A string of length n_in specifying the operators to apply to each of the input signals. Each character in the string must be either "*" or "/". The default is "*".

None
denominator_limit

Currently unsupported

None
divide_by_zero_behavior

Currently unsupported

None
Source code in collimator/library/primitives.py
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
class Product(ReduceBlock):
    """Compute the product and/or quotient of the input signals.

    The block will multiply or divide the input signals, depending on the specified
    operators.  For example, if the block has three inputs `u1`, `u2`, and `u3` and
    is configured with operators="**/", then the output signal will be
    `y = u1 * u2 / u3`.  By default, the block will multiply all of the input signals.
    If every operator is "/", the numerator is the empty product 1, so e.g.
    operators="/" on a single input yields its reciprocal.

    Input ports:
        (0..n_in-1) The input signals.

    Output ports:
        (0) The product and/or quotient of the input signals.

    Parameters:
        n_in:
            The number of input ports.
        operators:
            A string of length `n_in` specifying the operators to apply to each of
            the input signals.  Each character in the string must be either "*" or "/".
            The default (None) multiplies all inputs.
        denominator_limit:
            Currently unsupported
        divide_by_zero_behavior:
            Currently unsupported
    """

    @parameters(static=["operators", "denominator_limit", "divide_by_zero_behavior"])
    def __init__(
        self,
        n_in,
        operators=None,  # Expect "**/*", etc
        denominator_limit=None,
        divide_by_zero_behavior=None,
        **kwargs,
    ):
        super().__init__(n_in, None, **kwargs)

    def initialize(
        self,
        operators=None,  # Expect "**/*", etc
        denominator_limit=None,
        divide_by_zero_behavior=None,
    ):
        """Validate `operators` and install the matching reduction function.

        Raises:
            BlockParameterError: if `operators` contains characters other than
                "*" and "/", or if its length differs from the number of inputs.
        """
        if operators is not None and any(char not in {"*", "/"} for char in operators):
            raise BlockParameterError(
                message=f"Product block {self.name} has invalid operators {operators}. Can only contain '*' and '/'",
                system=self,
                parameter_name="operators",
            )

        # A length mismatch would otherwise silently ignore trailing inputs,
        # or index past the last input.
        n_inputs = len(self.input_ports)
        if operators is not None and len(operators) != n_inputs:
            raise BlockParameterError(
                message=f"Product block {self.name} has invalid operators {operators}. Expected one operator per input ({n_inputs}), got {len(operators)}",
                system=self,
                parameter_name="operators",
            )

        if operators is not None and "/" in operators:
            num_indices = [idx for idx, op in enumerate(operators) if op == "*"]
            den_indices = [idx for idx, op in enumerate(operators) if op == "/"]

            def _func(inputs):
                # Select inputs by position in Python. `cnp.take` with
                # `cnp.array([])` would get a float-typed index array when
                # there are no "*" operators, and take rejects non-integer indices.
                den = cnp.prod(cnp.array([inputs[i] for i in den_indices]), axis=0)
                if not num_indices:
                    # Empty numerator: the product of no factors is 1.
                    return 1.0 / den
                num = cnp.prod(cnp.array([inputs[i] for i in num_indices]), axis=0)
                return num / den

        else:

            def _func(inputs):
                return cnp.prod(cnp.array(inputs), axis=0)

        self.replace_op(_func)

ProductOfElements

Bases: FeedthroughBlock

Compute the product of the elements of the input signal.

Dispatches to jax.numpy.prod, so see the JAX docs for details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.prod.html

Input ports

(0) The input signal.

Output ports

(0) The product of the elements of the input signal.

Source code in collimator/library/primitives.py
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
class ProductOfElements(FeedthroughBlock):
    """Compute the product of the elements of the input signal.

    Dispatches to `jax.numpy.prod`, so see the JAX docs for details:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.prod.html

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The product of the elements of the input signal.
    """

    def __init__(self, *args, **kwargs):
        # The feedthrough operation is the backend's `prod` applied directly
        # to the single input; everything else is forwarded unchanged.
        elementwise_product = cnp.prod
        super().__init__(elementwise_product, *args, **kwargs)

Pulse

Bases: SourceBlock

A periodic pulse signal.

Given amplitude a, pulse width w (expressed as a fraction of the period), and period p, the output signal is:

    y(t) = a if (t % p) / p < w else 0

where % is the modulo operator.

Input ports

None

Output ports

(0) The pulse signal.

Parameters:

Name Type Description Default
amplitude

The amplitude of the pulse signal.

1.0
pulse_width

The fraction of the period during which the pulse is "high".

0.5
period

The period of the pulse signal.

1.0
phase_delay

Currently unsupported.

0.0
Source code in collimator/library/primitives.py
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
2797
2798
2799
2800
2801
2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
2813
2814
2815
2816
2817
2818
2819
2820
2821
2822
2823
2824
2825
2826
2827
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
2839
2840
2841
2842
2843
2844
2845
class Pulse(SourceBlock):
    """A periodic pulse signal.

    Given amplitude `a`, pulse width `w` (a fraction of the period), and period
    `p`, the output signal is:
    ```
        y(t) = a if (t % p) / p < w else 0
    ```
    where `%` is the modulo operator.

    Input ports:
        None

    Output ports:
        (0) The pulse signal.

    Parameters:
        amplitude:
            The amplitude of the pulse signal.
        pulse_width:
            The fraction of the period during which the pulse is "high".
        period:
            The period of the pulse signal.
        phase_delay:
            Currently unsupported.
    """

    @parameters(dynamic=["amplitude", "pulse_width", "period", "phase_delay"])
    def __init__(
        self, amplitude=1.0, pulse_width=0.5, period=1.0, phase_delay=0.0, **kwargs
    ):
        super().__init__(self._func, **kwargs)

        # Initialize the floating-point tolerance.  This will be machine epsilon
        # for the floating point type of the time variable (determined in the
        # static initialization step).
        self.eps = 0.0

        if abs(phase_delay) > 1e-9:
            warnings.warn("Warning. Pulse block phase_delay not implemented.")

        # Add dummy events so that the ODE solver doesn't try to integrate
        # through the discontinuities: one periodic update for the rising edges
        # and one for the falling edges. The update callbacks only write `True`
        # into a discrete state that `_func` never reads.
        self.declare_discrete_state(default_value=False)
        self._dummy_periodic_update_idx = self.declare_periodic_update()
        self._periodic_update_idx = self.declare_periodic_update()

    def initialize(self, amplitude, pulse_width, period, phase_delay):
        """Schedule the rising- and falling-edge events for the given timing."""
        if abs(phase_delay) > 1e-9:
            warnings.warn("Warning. Pulse block phase_delay not implemented.")

        # Rising edges occur at t = k * period for k >= 1 (t = 0 is the start
        # of the simulation, so no event is needed there).
        self.configure_periodic_update(
            self._dummy_periodic_update_idx,
            lambda *args, **kwargs: True,
            period=period,
            offset=period,
        )

        # Falling edges occur at t = (k + pulse_width) * period for k >= 0.
        # The first one lies inside the first period, so the offset must be
        # `period * pulse_width`; an offset of `period + period * pulse_width`
        # would let the solver integrate through the first falling edge.
        self.configure_periodic_update(
            self._periodic_update_idx,
            lambda *args, **kwargs: True,
            period=period,
            offset=period * pulse_width,
        )

    def _func(self, time, **parameters):
        # Add a floating-point tolerance to the modulo operation to avoid
        # accuracy issues when the time is an "exact" multiple of the period.
        period_fraction = (
            cnp.remainder(time + self.eps, parameters["period"]) / parameters["period"]
        )
        return cnp.where(
            period_fraction >= parameters["pulse_width"],
            0.0,
            parameters["amplitude"],
        )

    def initialize_static_data(self, context):
        # Determine machine epsilon for the type of the time variable
        self.eps = 2 * cnp.finfo(cnp.result_type(context.time)).eps
        return super().initialize_static_data(context)

PyTorch

Bases: LeafSystem

Block to perform inference with a pre-trained PyTorch model saved as TorchScript.

The input to the block should be of compatible type and shape expected by the TorchScript. For example, if the TorchScript model expects a torch.float32 tensor of shape (3, 224, 224), the input to the block should be a jax.numpy array of shape (3, 224, 224) of dtype jnp.float32.

For output types, if no casting is specified through the cast_outputs_to_dtype parameter, the output of the block will have the same dtype as the TorchScript model output, but expressed as jax.numpy types. For example, if the TorchScript model outputs a torch.float32 tensor, the output of the block will be a jax.numpy array of dtype jnp.float32.

If casting is specified through the cast_outputs_to_dtype parameter, all outputs of the block will be cast to that jax.numpy dtype.

Input ports

(i) The ith input to the model.

Output ports

(j) The jth output of the model.

Parameters:

Name Type Description Default
file_name str

Path to the model Torchscript .pt file.

required
num_inputs int

The number of inputs to the model. Only required for TorchScript models.

1
num_outputs int

The number of outputs of the model.

1
cast_outputs_to_dtype str

The dtype to cast all the outputs of the block to. Must correspond to a jax.numpy datatype. For example, "float32", "float64", "int32", "int64".

None
add_batch_dim_to_inputs bool

Whether to add a new first dimension to the inputs before evaluating the TorchScript model. This is useful when the model expects a batch dimension.

False
Source code in collimator/library/predictor.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class PyTorch(LeafSystem):
    """
    Block to perform inference with a pre-trained PyTorch model saved as TorchScript.

    The input to the block should be of compatible type and shape expected by
    the TorchScript. For example, if the TorchScript model expects
    a `torch.float32` tensor of shape `(3, 224, 224)`, the input to the block should be
    a `jax.numpy` array of shape (3, 224, 224) of dtype `jnp.float32`.

    For output types, if no casting is specified through the `cast_outputs_to_dtype`
    parameter, the output of the block will have the same dtype as the TorchScript
    model output, but expressed as `jax.numpy` types. For example, if the
    TorchScript model outputs a `torch.float32` tensor, the output of the block will be
    a `jax.numpy` array of dtype `jnp.float32`.

    If casting is specified through the `cast_outputs_to_dtype` parameter, all outputs
    of the block will be cast to that `jax.numpy` dtype.

    Input ports:
        (i) The ith input to the model.

    Output ports:
        (j) The jth output of the model.

    Parameters:
        file_name (str):
            Path to the model Torchscript `.pt` file.

        num_inputs (int):
            The number of inputs to the model. Only required for TorchScript models.

        num_outputs (int):
            The number of outputs of the model.

        cast_outputs_to_dtype (str):
            The dtype to cast all the outputs of the block to. Must correspond to a
            `jax.numpy` datatype. For example, "float32", "float64", "int32", "int64".

        add_batch_dim_to_inputs (bool):
            Whether to add a new first dimension to the inputs before evaluating the
            TorchScript model. This is useful when the model expects a batch
            dimension.
    """

    EXTRA_FILES = {}

    @parameters(
        static=[
            "file_name",
            "num_inputs",
            "num_outputs",
            "cast_outputs_to_dtype",
            "add_batch_dim_to_inputs",
        ]
    )
    def __init__(
        self,
        file_name,
        num_inputs=1,
        num_outputs=1,
        cast_outputs_to_dtype=None,
        add_batch_dim_to_inputs=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self._num_inputs = num_inputs
        self._num_outputs = num_outputs

        for _ in range(num_inputs):
            self.declare_input_port()

        # A factory binds `output_index` per port; a plain closure in the loop
        # would late-bind and make every port return the last output.
        def _make_output_callback(output_index):
            def _output_callback(time, state, *inputs, **params):
                outputs = self._evaluate_output(time, state, *inputs, **params)
                return outputs[output_index]

            return _output_callback

        for output_index in range(num_outputs):
            self.declare_output_port(
                _make_output_callback(output_index),
                requires_inputs=True,
            )

    def initialize(
        self,
        file_name,
        num_inputs=1,
        num_outputs=1,
        cast_outputs_to_dtype=None,
        add_batch_dim_to_inputs=False,
    ):
        """Load the TorchScript model and record output-casting options.

        Raises:
            ValueError: If `num_inputs` or `num_outputs` differ from the values
                used to declare the ports in `__init__`.
        """
        if num_inputs != self._num_inputs:
            raise ValueError("num_inputs can't be changed after initialization")
        if num_outputs != self._num_outputs:
            raise ValueError("num_outputs can't be changed after initialization")

        self.dtype_output = (
            getattr(jnp, cast_outputs_to_dtype)
            if cast_outputs_to_dtype is not None
            else None
        )

        self.add_batch_dim_to_inputs = add_batch_dim_to_inputs
        self.model = torch.jit.load(file_name)
        self.model.eval()

    def initialize_static_data(self, context):
        """Infer the output shapes and dtypes of the ML model."""
        # If building as part of a subsystem, this may not be fully connected yet.
        # That's fine, as long as it is connected by root context creation time.
        # This probably isn't a good long-term solution:
        #   see https://collimator.atlassian.net/browse/WC-51
        try:
            inputs = self.collect_inputs(context)
            outputs_jax = self._pure_callback(*inputs)

            self.pure_callback_result_type = [
                jax.ShapeDtypeStruct(x.shape, x.dtype) for x in outputs_jax
            ]
        except UpstreamEvalError:
            logger.debug(
                "PyTorch.initialize_static_data: UpstreamEvalError. "
                "Continuing without default value initialization."
            )
        return super().initialize_static_data(context)

    def _evaluate_output(self, time, state, *inputs, **params):
        # The model runs outside of JAX tracing via a host callback; the
        # result shapes/dtypes come from `initialize_static_data`.
        return jax.pure_callback(
            self._pure_callback,
            self.pure_callback_result_type,
            *inputs,
        )

    def _pure_callback(self, *inputs):
        """Run the TorchScript model on host arrays and return JAX arrays."""
        inputs_casted = [torch.tensor(np.array(item)) for item in inputs]

        if self.add_batch_dim_to_inputs:
            inputs_casted = [x.unsqueeze(0) for x in inputs_casted]

        # Inference only: disable autograd. Otherwise outputs of a model whose
        # parameters require grad also require grad, and such tensors cannot
        # be converted to NumPy (and hence JAX) arrays.
        with torch.no_grad():
            outputs = self.model(*inputs_casted)

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        # Detach and move to host memory before conversion so GPU tensors work.
        outputs_host = [
            x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x
            for x in outputs
        ]

        outputs_jax = (
            [jnp.array(x, self.dtype_output) for x in outputs_host]
            if self.dtype_output is not None
            else [jnp.array(x) for x in outputs_host]
        )
        return outputs_jax

initialize_static_data(context)

Infer the output shapes and dtypes of the ML model.

Source code in collimator/library/predictor.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def initialize_static_data(self, context):
    """Infer the output shapes and dtypes of the ML model.

    Evaluates the block's current inputs, runs them once through
    `self._pure_callback`, and stores a `jax.ShapeDtypeStruct` for each
    returned output in `self.pure_callback_result_type`.

    If an `UpstreamEvalError` is raised while collecting inputs or running the
    model, a debug message is logged and `pure_callback_result_type` is left
    untouched. In every case the parent class's `initialize_static_data` result
    is returned.

    NOTE(review): this is a rendered copy of the `PyTorch.initialize_static_data`
    method. Its zero-argument `super()` only works when the function is defined
    inside the class body.
    """
    # If building as part of a subsystem, this may not be fully connected yet.
    # That's fine, as long as it is connected by root context creation time.
    # This probably isn't a good long-term solution:
    #   see https://collimator.atlassian.net/browse/WC-51
    try:
        inputs = self.collect_inputs(context)
        # Run the model eagerly once just to discover its output signature.
        outputs_jax = self._pure_callback(*inputs)

        self.pure_callback_result_type = [
            jax.ShapeDtypeStruct(x.shape, x.dtype) for x in outputs_jax
        ]
    except UpstreamEvalError:
        logger.debug(
            "PyTorch.initialize_static_data: UpstreamEvalError. "
            "Continuing without default value initialization."
        )
    return super().initialize_static_data(context)

QuadraticCost

Bases: ReduceBlock

LQR-type quadratic cost function for a state and input.

Computes the cost as x'Qx + u'Ru, where Q and R are the cost matrices. In order to compute a running cost, combine this with an Integrator or IntegratorDiscrete block.

Source code in collimator/library/costs_and_losses.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class QuadraticCost(ReduceBlock):
    """LQR-type quadratic cost function for a state and input.

    Computes the cost as x'Qx + u'Ru, where Q and R are the cost matrices.
    In order to compute a running cost, combine this with an `Integrator`
    or `IntegratorDiscrete` block.

    Input ports:
        (0) The state vector `x`.
        (1) The input vector `u`.

    Output ports:
        (0) The scalar cost x'Qx + u'Ru.

    Parameters:
        Q: State cost matrix. Any array-like (including nested lists).
        R: Input cost matrix. Any array-like (including nested lists).
        name: Optional block name.
        **kwargs: Forwarded to the base block (e.g. `ui_id`).
    """

    def __init__(self, Q, R, name=None, **kwargs):
        super().__init__(2, self._cost, name=name, **kwargs)
        # jax.numpy functions reject Python lists as array arguments, so
        # convert the cost matrices once here rather than failing at run time.
        self.Q = jnp.asarray(Q)
        self.R = jnp.asarray(R)

    def _cost(self, inputs):
        """Return x'Qx + u'Ru for the two inputs `(x, u)` as a squeezed array."""
        x, u = inputs
        J = jnp.dot(x, jnp.dot(self.Q, x)) + jnp.dot(u, jnp.dot(self.R, u))
        return J.squeeze()

QuanserHAL

Bases: LeafSystem

Hardware Abstraction Layer for Quanser hardware.

This block provides an interface to virtual or physical Quanser hardware. It requires that the Quanser hardware or QLabs simulator be properly configured and that the Quanser python library is available on the system path. See the Quanser documentation for more information.

To use an idealized model of the Qube Servo hardware, see the collimator.library.QubeServoModel block, which may be run without hardware or in the cloud-based simulation UI.

Input ports

(0) Control signal to the motor in volts

Output ports

(0) The observed rotor and pendulum angles in radians

Parameters:

Name Type Description Default
dt

The time step of the simulation.

required
version

The version of the Qube hardware (2 or 3). Defaults to 2. If explicitly set to None, version 2 is used with hardware=False and version 3 is used with hardware=True.

2
hardware

If True, connect to the physical hardware. If False, connect to the QLabs simulator.

False
name

The name of the system in the Collimator model.

'QuanserHAL'
Source code in collimator/library/quanser.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
class QuanserHAL(LeafSystem):
    """Hardware Abstraction Layer for Quanser hardware.

    This block provides an interface to virtual or physical Quanser hardware.
    It requires that the Quanser hardware or QLabs simulator be properly configured
    and that the Quanser python library is available on the system path.  See the
    Quanser documentation for more information.

    To use an idealized model of the Qube Servo hardware, see the
    `collimator.library.QubeServoModel` block, which may be run without hardware
    or in the cloud-based simulation UI.

    Input ports:
        (0) Control signal to the motor in volts

    Output ports:
        (0) The observed rotor and pendulum angles in radians

    Parameters:
        dt: The time step of the simulation.
        version: The version of the Qube hardware (2 or 3). Defaults to 2. If
            explicitly set to None, version 2 is used with `hardware=False` and
            version 3 is used with `hardware=True`.
        hardware:
            If True, connect to the physical hardware. If False, connect to the
            QLabs simulator.
        name: The name of the system in the Collimator model.

    Raises:
        ImportError: If `pal.products.qube` cannot be imported.
        ValueError: If `version` is not 2 or 3.
        RuntimeError: If the Qube reports no hardware card after connecting.
    """

    def __init__(self, dt, version=2, hardware=False, name="QuanserHAL", ui_id=None):
        super().__init__(name=name, ui_id=ui_id)

        # Defined before anything can fail so that `terminate()` and
        # `_restore_siginthandler()` never hit a missing attribute.
        self.prev_sigint_handler = None

        if version is None:
            version = 2 if not hardware else 3

        # Init QLabs
        # HACK for macOS - Quanser's code only works on Windows
        # Insert asyncio.windows_events fake module
        if sys.platform != "win32":
            _windows_events = types.ModuleType("windows_events")
            _windows_events.INFINITE = np.iinfo(np.uint32).max
            sys.modules["asyncio.windows_events"] = _windows_events

        try:
            from pal.products.qube import QubeServo2, QubeServo3
        except Exception as exc:
            raise ImportError(
                "Could not import QubeServo2 or QubeServo3 from pal.products.qube. "
                "Check that the hardware drivers are available on the system path."
            ) from exc

        if version not in [2, 3]:
            raise ValueError("version must be 2 or 3")

        QubeClass = QubeServo2 if version == 2 else QubeServo3

        self.qube = QubeClass(hardware=hardware, pendulum=1, frequency=1 / dt)

        # Verify the connection before taking over SIGINT, so that a failed
        # construction does not leave this block's handler installed.
        if self.qube.card is None:
            raise RuntimeError(
                "Could not find hardware. Try power-cycling and check connections."
            )
        print("Initialized Qube")

        self._setup_siginthandler()
        self.qube.write_led(color=[0, 1, 0])

        self.declare_input_port()  # Inputs are the control signals to the motor

        # Periodically send the control signals to the motor
        self.declare_periodic_update(
            self.step,
            period=dt,
            offset=0.0,
        )

        # Periodically read the sensor outputs
        self.declare_output_port(
            self.output,
            period=dt,
            offset=0.0,
            requires_inputs=False,
        )

    def __exit__(self, exc_type, exc_value, traceback):
        self.terminate()
        super().__exit__(exc_type, exc_value, traceback)

    def _setup_siginthandler(self):
        # `signal.signal` returns the previously installed handler, which is
        # restored (and chained to) later.
        self.prev_sigint_handler = signal.signal(signal.SIGINT, self._interrupt)

    def _restore_siginthandler(self):
        print("Restoring sigint handler")
        if self.prev_sigint_handler is not None:
            signal.signal(signal.SIGINT, self.prev_sigint_handler)
            self.prev_sigint_handler = None

    # This custom handler makes it possible to Interrupt the jupyter kernel
    # and still connect again to the Qube environment.
    def _interrupt(self, signum, frame):
        # Capture before `terminate()` clears it.
        prev_handler = self.prev_sigint_handler
        self.terminate()
        # The previous handler may be a Python callable or one of the
        # non-callable sentinels SIG_DFL / SIG_IGN.
        if callable(prev_handler):
            prev_handler(signum, frame)
        elif prev_handler == signal.SIG_DFL:
            # Default SIGINT disposition terminates; surface it as the usual
            # Python interrupt instead of crashing with a TypeError.
            raise KeyboardInterrupt
        # SIG_IGN (or None): nothing more to do.

    def terminate(self):
        """Restore the SIGINT handler and shut down the Qube connection once."""
        self._restore_siginthandler()
        if self.qube is not None:
            self.qube.write_led(color=[1, 1, 0])
            self.qube.terminate()
            self.qube = None

    def _impure_step(self, voltage):
        # Write the voltage to the Qube
        self.qube.write_voltage(voltage)

    def step(self, time, state, *inputs, **parameters):
        return io_callback(self._impure_step, None, *inputs)

    def _impure_output(self):
        # Read the sensor outputs
        self.qube.read_outputs()
        theta, alpha = self.qube.motorPosition, self.qube.pendulumPosition
        return jnp.array([theta, alpha])

    def output(self, time, state, *inputs, **parameters):
        # The zeros array only specifies the result shape/dtype for io_callback.
        return io_callback(self._impure_output, jnp.zeros(2))

Quantizer

Bases: FeedthroughBlock

Discretize the input signal into a set of discrete values.

Given an input signal u and a quantization step size interval, this block rounds the input signal to the nearest integer multiple of interval. The output signal will be y = interval * round(u / interval).

Input ports

(0) The continuous input signal. Output levels are spaced interval apart.

Output ports

(0) The quantized output signal, on the same scale as the input signal.

Parameters:

Name Type Description Default
interval

The quantization step size (resolution), i.e. the spacing between adjacent output values.

required
Source code in collimator/library/primitives.py
2848
2849
2850
2851
2852
2853
2854
2855
2856
2857
2858
2859
2860
2861
2862
2863
2864
2865
2866
2867
2868
2869
2870
2871
2872
2873
2874
class Quantizer(FeedthroughBlock):
    """Discretize the input signal onto a uniform grid of values.

    Given an input signal `u` and a quantization step size `interval`, this
    block rounds the input to the nearest integer multiple of `interval`.
    The output signal will be `y = interval * round(u / interval)`. Ties are
    resolved by the backend's `round`, which rounds half to even.

    Input ports:
        (0) The continuous input signal. Output levels are spaced `interval`
            apart.

    Output ports:
        (0) The quantized output signal, on the same scale as the input signal.

    Parameters:
        interval:
            The quantization step size (resolution), i.e. the spacing between
            adjacent output values.
    """

    @parameters(dynamic=["interval"])
    def __init__(self, interval, *args, **kwargs):
        super().__init__(
            lambda x, interval: interval * cnp.round(x / interval), *args, **kwargs
        )

    def initialize(self, interval):
        # Nothing to precompute: `interval` is passed to the operation on
        # every evaluation.
        pass

QubeServoModel

Bases: LeafSystem

Plant model for the Quanser Qube Servo Furuta Pendulum.

The Quanser Qube Servo is a pendulum controlled by a rotary arm. The rotary arm is actuated by a DC motor. The pendulum is free to rotate about the rotary arm.

The state of the system is given by the rotor angle (theta), pendulum angle (alpha), rotor angular velocity, and pendulum angular velocity. The input to the system is the voltage applied to the motor, which is converted to torque by a simple linear model.

Input ports

(0) The motor voltage signal

Output ports

(0) If full_state_output is False, the rotor angle and pendulum angle. Otherwise, will return the entire continuous state vector.

Parameters:

Name Type Description Default
x0

Initial state of the system [theta, alpha, theta_dot, alpha_dot]

[0.0, 0.0, 0.0, 0.0]
Rm

Motor resistance (Ohms)

8.4
km

Back-emf constant (V-s/rad)

0.042
mr

Rotary arm mass (kg)

0.095
Lr

Rotor arm length (m)

0.085
br

Rotor arm damping coefficient (N-m-s/rad)

0.0005
mp

Pendulum mass (kg)

0.024
Lp

Pendulum arm length (m)

0.129
bp

Pendulum damping coefficient (N-m-s/rad)

2.5e-05
g

Gravitational constant (m/s^2)

9.81
kr

Feedback control to send the rotor back to zero

0.0
full_state_output

If True, output the full state vector. Otherwise, only output the rotor and pendulum angles.

False
Source code in collimator/library/quanser.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
class QubeServoModel(LeafSystem):
    """Plant model for the Quanser Qube Servo Furuta Pendulum.

    The Quanser Qube Servo is a pendulum controlled by a rotary arm. The
    rotary arm is actuated by a DC motor. The pendulum is free to rotate about
    the rotary arm.

    The state of the system is given by the rotor angle (theta), pendulum angle
    (alpha), rotor angular velocity, and pendulum angular velocity. The input to
    the system is the voltage applied to the motor, which is converted to torque
    by a simple linear model.

    Input ports:
        (0) The motor voltage signal

    Output ports:
        (0) If `full_state_output` is False, the rotor angle and pendulum angle.
            Otherwise, will return the entire continuous state vector.

    Parameters:
        x0: Initial state of the system [theta, alpha, theta_dot, alpha_dot].
            Defaults to [0.0, 0.0, 0.0, 0.0].
        Rm: Motor resistance (Ohms)
        km: Back-emf constant (V-s/rad)
        mr: Rotary arm mass (kg)
        Lr: Rotor arm length (m)
        br: Rotor arm damping coefficient (N-m-s/rad)
        mp: Pendulum mass (kg)
        Lp: Pendulum arm length (m)
        bp: Pendulum damping coefficient (N-m-s/rad)
        g: Gravitational constant (m/s^2)
        kr: Feedback control to send the rotor back to zero
        full_state_output: If True, output the full state vector. Otherwise,
            only output the rotor and pendulum angles.

    """

    def __init__(
        self,
        x0=None,
        Rm=8.4,
        km=0.042,
        mr=0.095,
        Lr=0.085,
        br=5e-4,
        mp=0.024,
        Lp=0.129,
        bp=2.5e-5,
        g=9.81,
        kr=0.0,
        full_state_output=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Build the default per instance instead of sharing a mutable list
        # default across all instances.
        if x0 is None:
            x0 = [0.0, 0.0, 0.0, 0.0]

        self.declare_dynamic_parameter("Rm", Rm)
        self.declare_dynamic_parameter("km", km)
        self.declare_dynamic_parameter("mr", mr)
        self.declare_dynamic_parameter("Lr", Lr)
        self.declare_dynamic_parameter("br", br)
        self.declare_dynamic_parameter("mp", mp)
        self.declare_dynamic_parameter("Lp", Lp)
        self.declare_dynamic_parameter("bp", bp)
        self.declare_dynamic_parameter("g", g)
        self.declare_dynamic_parameter("kr", kr)

        # One input: motor voltage
        self.declare_input_port()

        self.declare_continuous_state(default_value=x0, ode=self._ode)

        if full_state_output:
            self.declare_continuous_state_output()

        else:
            # Only measure (theta, alpha): the first two state components
            def _obs_callback(time, state, *inputs, **parameters):
                return state.continuous_state[:2]

            self.declare_output_port(_obs_callback, requires_inputs=False)

    def _ode(self, time, state, *inputs, **parameters):
        """Return d/dt [theta, alpha, theta_dot, alpha_dot] from M ddq = B u - C dq - tau_g."""
        # Unpack state
        q, dq = state.continuous_state[:2], state.continuous_state[2:]
        theta, alpha = q  # Rotor angle, pendulum angle
        theta_dot, alpha_dot = dq

        # Unpack parameters
        Rm = parameters["Rm"]
        km = parameters["km"]
        mr = parameters["mr"]
        Lr = parameters["Lr"]
        br = parameters["br"]
        mp = parameters["mp"]
        Lp = parameters["Lp"]
        bp = parameters["bp"]
        kr = parameters["kr"]
        g = parameters["g"]

        lp = Lp / 2  # Pendulum center of mass

        # Moment of inertia of the rotor arm about the motor
        Jr = mr * Lr**2 / 3

        # Moment of inertia of the pendulum about the pivot point
        Jp = mp * Lp**2 / 3

        # Unpack inputs
        (u,) = inputs

        # Feedback control to send the rotor back to zero. Build a new array
        # rather than using `-=`: `atleast_1d` can return the caller's own
        # array, which an in-place update would silently modify (and which
        # fails outright for integer-typed inputs).
        u = cnp.atleast_1d(u) - kr * theta

        # Mass matrix
        M = cnp.array(
            [
                [Jr + Jp * cnp.sin(alpha) ** 2, -mp * lp * Lr * cnp.cos(alpha)],
                [-mp * lp * Lr * cnp.cos(alpha), Jp],
            ]
        )

        # Coriolis matrix (the back-EMF damping term km**2 / Rm is currently
        # disabled by the factor of 0)
        C = cnp.array(
            [
                [
                    Jp * cnp.sin(2 * alpha) * alpha_dot + br + 0 * km**2 / Rm,
                    mp * lp * Lr * cnp.sin(alpha) * alpha_dot,
                ],
                [-0.5 * Jp * cnp.sin(2 * alpha) * theta_dot, bp],
            ]
        )

        # Gravity vector
        tau_g = cnp.array([0, mp * g * lp * cnp.sin(alpha)])

        # Input matrix
        B = cnp.array(
            [
                [km / Rm],
                [0],
            ]
        )

        # State space representation
        ddq = cnp.linalg.solve(M, B @ u - (C @ dq + tau_g))
        return cnp.concatenate([dq, ddq])

Ramp

Bases: SourceBlock

Output a linear ramp signal in time.

Given a slope m, a start value y0, and a start time t0, the output signal is:

    y(t) = m * (t - t0) + y0 if t >= t0 else y0

where t is the current simulation time.

Input ports

None

Output ports

(0) The ramp signal.

Parameters:

Name Type Description Default
start_value

The value of the output signal at the start time.

0.0
slope

The slope of the ramp signal.

1.0
start_time

The time at which the ramp signal begins.

1.0
Source code in collimator/library/primitives.py
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
3030
3031
3032
3033
3034
3035
3036
3037
3038
3039
3040
3041
class Ramp(SourceBlock):
    """Output a linear ramp signal in time.

    Given a slope `m`, a start value `y0`, and a start time `t0`, the output signal is:
    ```
        y(t) = m * (t - t0) + y0 if t >= t0 else y0
    ```
    where `t` is the current simulation time.

    Input ports:
        None

    Output ports:
        (0) The ramp signal.

    Parameters:
        start_value:
            The value of the output signal at (and before) the start time.
        slope:
            The slope of the ramp signal.
        start_time:
            The time at which the ramp signal begins.
    """

    @parameters(dynamic=["start_value", "slope", "start_time"])
    def __init__(self, start_value=0.0, slope=1.0, start_time=1.0, **kwargs):
        # The output is computed from the current time and the dynamic
        # parameters only; there is no internal state.
        super().__init__(self._func, **kwargs)

    def initialize(self, start_value, slope, start_time):
        """No setup needed: `_func` reads every parameter on each evaluation."""

    def _func(self, time, **parameters):
        offset = parameters["start_value"]
        onset = parameters["start_time"]
        rising = parameters["slope"] * (time - onset) + offset
        # Hold the start value until the ramp begins, then follow the line.
        return cnp.where(time >= onset, rising, offset)

RandomNumber

Bases: LeafSystem

Discrete-time random number generator.

Generates independent, identically distributed random numbers at each time step. Dispatches to jax.random for the actual random number generation.

Supported distributions include "ball", "cauchy", "choice", "dirichlet", "exponential", "gamma", "lognormal", "maxwell", "normal", "orthogonal", "poisson", "randint", "truncated_normal", and "uniform".

See https://jax.readthedocs.io/en/latest/jax.random.html#random-samplers for a full list of available distributions and associated parameters.

Although the JAX random number generator is a deterministic function of the key, this block maintains the key as part of the discrete state, making it a stateful RNG. The block can be seeded for reproducibility by passing an integer seed; if None, a random seed will be generated using numpy.random.

Note that this block should typically not be used as a source of randomness for continuous-time systems, as it generates a discrete-time signal. For continuous systems, use a continuous-time noise source, such as WhiteNoise.

Input ports

None

Output ports

(0) The most recently generated random number.

Parameters:

Name Type Description Default
dt float

The rate at which random numbers are generated.

required
distribution str

The name of the random distribution to sample from.

'normal'
seed int

An integer seed for the random number generator. If None, a random 32-bit seed will be generated.

None
dtype DTypeLike

data type of the random number. If None, the default data type for the specified distribution will be used. Not all distributions support all data types; check the JAX documentation for details.

None
distribution_parameters

A dictionary of additional parameters to pass to the distribution function.

{}
Source code in collimator/library/random.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
class RandomNumber(LeafSystem):
    """Discrete-time random number generator.

    Generates independent, identically distributed random numbers at each time step.
    Dispatches to `jax.random` for the actual random number generation.

    Supported distributions include "ball", "cauchy", "choice", "dirichlet",
    "exponential", "gamma", "lognormal", "maxwell", "normal", "orthogonal",
    "poisson", "randint", "truncated_normal", and "uniform".

    See https://jax.readthedocs.io/en/latest/jax.random.html#random-samplers
    for a full list of available distributions and associated parameters.

    Although the JAX random number generator is a deterministic function of the
    key, this block maintains the key as part of the discrete state, making it a
    stateful RNG.  The block can be seeded for reproducibility by passing an integer
    seed; if None, a random seed will be generated using numpy.random.

    Note that this block should typically not be used as a source of randomness for
    continuous-time systems, as it generates a discrete-time signal. For continuous
    systems, use a continuous-time noise source, such as `WhiteNoise`.

    Input ports:
        None

    Output ports:
        (0) The most recently generated random number.

    Parameters:
        dt: The rate at which random numbers are generated.
        distribution: The name of the random distribution to sample from.
        seed: An integer seed for the random number generator. If None, a random 32-bit
            seed will be generated.
        dtype: data type of the random number.  If None, the default data type for the
            specified distribution will be used.  Not all distributions support all
            data types; check the JAX documentation for details.
        shape: Shape of each generated sample.  It is forwarded to the distribution
            function only when it is not None and not the empty tuple `()`.
        distribution_parameters: A dictionary of additional parameters to pass to the
            distribution function.
    """

    class RNGState(NamedTuple):
        # PRNG key that will be split on the next periodic update.
        key: Array
        # Most recently generated sample; this is what the output port returns.
        val: Array

    @parameters(static=["distribution", "seed", "shape"])
    def __init__(
        self,
        dt: float,
        distribution: str = "normal",  # UI only exposes 'normal' for now
        seed: int = None,
        dtype: DTypeLike = None,
        shape: ShapeLike = (),
        name: str = None,
        ui_id: str = None,
        **distribution_parameters,
    ):
        super().__init__(name=name, ui_id=ui_id)

        # Declare config parameters for serialization
        self.declare_static_parameters(**distribution_parameters)

        # Add to the data type if specified.  Since not all distributions
        # support this parameter (though most do), we don't want to do this
        # unconditionally.
        # NOTE(review): this mutates only the local dict, after the static
        # parameters were already declared above, and "dtype" is not in the
        # `@parameters(static=...)` list.  Whether dtype ever reaches
        # `initialize` (where `self.rng` is built) depends on how `@parameters`
        # forwards arguments.  Confirm it is not silently dropped.
        if dtype is not None:
            distribution_parameters["dtype"] = dtype

        # The output port and the state update share the same period and offset,
        # so the output changes only at update instants.
        self.declare_output_port(
            self._output,
            period=dt,
            offset=0.0,
        )

        self.declare_periodic_update(
            self._update,
            period=dt,
            offset=0.0,
        )

    def initialize(
        self,
        distribution: str = "normal",  # UI only exposes 'normal' for now
        seed: int = None,
        shape: ShapeLike = (),
        **distribution_parameters,
    ):
        """Build the sampler and declare the initial (key, value) discrete state.

        `self.rng` is `getattr(random, distribution)` with `distribution_parameters`
        (and `shape`, if non-empty) bound via `functools.partial`, so it is called
        with a single key argument.
        """
        # Supposedly all distributions support the shape parameter
        if shape is not None and shape != ():
            distribution_parameters["shape"] = shape

        self.rng = partial(getattr(random, distribution), **distribution_parameters)

        # Without an explicit seed, draw one from numpy's global RNG in
        # [0, 2**32).  Unseeded blocks therefore differ from run to run.
        key = random.PRNGKey(
            np.random.randint(0, 2**32, dtype=np.int64) if seed is None else seed
        )

        # The discrete state is a tuple of (key, val) pairs.  Because of the way that
        # JAX maintains RNG state, we need to keep track of the key as well as the
        # most recently generated value.
        key, subkey = random.split(key)
        default_state = self.RNGState(
            key=key,
            val=self.rng(subkey),  # Initial random number with the right data type
        )
        self.declare_discrete_state(default_value=default_state, as_array=False)

    def _output(self, _time, state, *_inputs, **_parameters):
        """Return the most recently generated sample held in the discrete state."""
        return state.discrete_state.val

    def _update(self, _time, state, *_inputs, **_parameters):
        """Advance the RNG: split the stored key and draw a new sample.

        The first half of the split is stored for the next update; the second
        half is consumed to draw the new value.
        """
        key, subkey = random.split(state.discrete_state.key)
        return self.RNGState(
            key=key,
            val=self.rng(subkey),
        )

RateLimiter

Bases: LeafSystem

Limit the time derivative of the block output.

Given an input signal u, the block computes the rate of change of the output signal as:

    y_rate = (u(t) - y(Tprev))/(t - Tprev)

where Tprev is the time of the previous output update. Because the block updates periodically, t - Tprev always equals the update period dt.

When y_rate is greater than the upper_limit, the output is:

    y(t) = (t - Tprev)*upper_limit + y(Tprev)

When y_rate is less than the lower_limit, the output is:

    y(t) = (t - Tprev)*lower_limit + y(Tprev)

If the lower_limit is greater than the upper_limit, and both are being violated, the upper_limit takes precedence.

Optionally, the block can also be configured with "dynamic" limits, which will add input ports for time-varying upper and lower limits.

Presently, the block is constrained to periodic updates.

Input ports

(0) The input signal. (1) The upper limit, if dynamic limits are enabled. (2) The lower limit, if dynamic limits are enabled. (Will be indexed as 1 if dynamic upper limits are not enabled.)

Output ports

(0) The rate limited output signal.

Parameters:

Name Type Description Default
upper_limit

The upper limit on the rate of change of the output signal. Default is np.inf. Ignored when enable_dynamic_upper_limit is True.

inf
enable_dynamic_upper_limit

If True, then the upper limit can be set by an external signal. Default is False.

False
lower_limit

The lower limit on the rate of change of the output signal. Default is -np.inf. Ignored when enable_dynamic_lower_limit is True.

-inf
enable_dynamic_lower_limit

If True, then the lower limit can be set by an external signal. Default is False.

False
Source code in collimator/library/primitives.py
3044
3045
3046
3047
3048
3049
3050
3051
3052
3053
3054
3055
3056
3057
3058
3059
3060
3061
3062
3063
3064
3065
3066
3067
3068
3069
3070
3071
3072
3073
3074
3075
3076
3077
3078
3079
3080
3081
3082
3083
3084
3085
3086
3087
3088
3089
3090
3091
3092
3093
3094
3095
3096
3097
3098
3099
3100
3101
3102
3103
3104
3105
3106
3107
3108
3109
3110
3111
3112
3113
3114
3115
3116
3117
3118
3119
3120
3121
3122
3123
3124
3125
3126
3127
3128
3129
3130
3131
3132
3133
3134
3135
3136
3137
3138
3139
3140
3141
3142
3143
3144
3145
3146
3147
3148
3149
3150
3151
3152
3153
3154
3155
3156
3157
3158
3159
3160
3161
3162
3163
3164
3165
3166
3167
3168
3169
3170
3171
3172
3173
3174
3175
3176
3177
3178
3179
3180
3181
3182
3183
3184
3185
3186
3187
3188
3189
3190
class RateLimiter(LeafSystem):
    """Limit the time derivative of the block output.

    Given an input signal `u`, the block computes the rate of change the output
    would need in order to follow the input exactly:
    ```
        y_rate = (u(t) - y(Tprev))/(t - Tprev)
    ```
    where Tprev is the time of the previous output update.  Because the block
    updates periodically, `t - Tprev` is always the update period `dt`.

    When y_rate is greater than the upper_limit, the output is:
    ```
        y(t) = (t - Tprev)*upper_limit + y(Tprev)
    ```

    When y_rate is less than the lower_limit, the output is:
    ```
        y(t) = (t - Tprev)*lower_limit + y(Tprev)
    ```

    Otherwise the output follows the input: `y(t) = u(t)`.

    If the lower_limit is greater than the upper_limit, and both
    are being violated, the upper_limit takes precedence.

    Optionally, the block can also be configured with "dynamic" limits, which will
    add input ports for time-varying upper and lower limits.

    Presently, the block is constrained to periodic updates.

    Input ports:
        (0) The input signal.
        (1) The upper rate limit, if `enable_dynamic_upper_limit` is True.
        (2) The lower rate limit, if `enable_dynamic_lower_limit` is True. (Will be
            indexed as 1 if the dynamic upper limit is not enabled.)

    Output ports:
        (0) The rate limited output signal.

    Parameters:
        dt:
            The period of the output updates.  It is also used as `t - Tprev` in
            the rate computation.
        upper_limit:
            The upper limit on the rate of change of the output signal.  Default
            is `np.inf`.  Ignored when `enable_dynamic_upper_limit` is True.
        enable_dynamic_upper_limit:
            If True, then the upper limit can be set by an external signal. Default
            is False.
        lower_limit:
            The lower limit on the rate of change of the output signal.  Default
            is `-np.inf`.  Ignored when `enable_dynamic_lower_limit` is True.
        enable_dynamic_lower_limit:
            If True, then the lower limit can be set by an external signal. Default
            is False.
    """

    # NOTE(review): not referenced anywhere in this class; possibly kept for
    # external users.  Confirm before removing.
    class DiscreteStateType(NamedTuple):
        y_prev: Array
        t_prev: float

    @parameters(
        static=["enable_dynamic_upper_limit", "enable_dynamic_lower_limit"],
        dynamic=["upper_limit", "lower_limit"],
    )
    def __init__(
        self,
        dt,
        upper_limit=np.inf,
        enable_dynamic_upper_limit=False,
        lower_limit=-np.inf,
        enable_dynamic_lower_limit=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.primary_input_index = self.declare_input_port()
        self.enable_dynamic_upper_limit = enable_dynamic_upper_limit
        self.enable_dynamic_lower_limit = enable_dynamic_lower_limit
        self.dt = dt

        # Port layout depends on these flags, which is why `initialize` refuses
        # to let them change later.
        if enable_dynamic_upper_limit:
            # If dynamic limit, simply ignore the static limit
            self.upper_limit_index = self.declare_input_port()

        if enable_dynamic_lower_limit:
            # If dynamic limit, simply ignore the static limit
            self.lower_limit_index = self.declare_input_port()

        self.output_index = self.declare_output_port(
            self._output,
            period=dt,
            offset=0.0,
        )

    def initialize(
        self,
        upper_limit=np.inf,
        enable_dynamic_upper_limit=False,
        lower_limit=-np.inf,
        enable_dynamic_lower_limit=False,
    ):
        """Check that the static port-layout flags match those used at construction.

        Raises:
            ValueError: If either `enable_dynamic_*` flag differs from the value
                passed to `__init__`.
        """
        if enable_dynamic_upper_limit != self.enable_dynamic_upper_limit:
            raise ValueError(
                "RateLimiter: enable_dynamic_upper_limit cannot be changed after initialization"
            )
        if enable_dynamic_lower_limit != self.enable_dynamic_lower_limit:
            raise ValueError(
                "RateLimiter: enable_dynamic_lower_limit cannot be changed after initialization"
            )

    def _output(self, time, state, *inputs, **params):
        """Compute the rate-limited output for the current periodic update."""
        # The previous output value is read back from this port's cache entry.
        y_prev = state.cache[self.output_index]

        u = inputs[self.primary_input_index]

        # Periodic updates: the elapsed time since the last output is always dt.
        t_diff = self.dt

        y_rate = (u - y_prev) / t_diff

        ulim = (
            inputs[self.upper_limit_index]
            if self.enable_dynamic_upper_limit
            else params["upper_limit"]
        )
        llim = (
            inputs[self.lower_limit_index]
            if self.enable_dynamic_lower_limit
            else params["lower_limit"]
        )
        y_ulim = t_diff * ulim + y_prev
        y_llim = t_diff * llim + y_prev
        # Apply the lower limit first, then let the upper limit override it, so
        # the upper limit wins when both are violated.
        y_tmp = cnp.where(y_rate < llim, y_llim, u)
        y = cnp.where(y_rate > ulim, y_ulim, y_tmp)

        return y

    def initialize_static_data(self, context):
        """Infer the size and dtype of the internal states"""
        # If building as part of a subsystem, this may not be fully connected yet.
        # That's fine, as long as it is connected by root context creation time.
        # This probably isn't a good long-term solution:
        #   see https://collimator.atlassian.net/browse/WC-51
        try:
            u = self.eval_input(context)
            # Seed the "previous output" read by `_output` with the current input.
            self._default_cache[self.output_index] = u
            # NOTE(review): discrete state is set here although this class does
            # not visibly declare any.  Confirm this is intended.
            local_context = context[self.system_id].with_discrete_state(u)
            local_context = local_context.with_cached_value(self.output_index, u)
            context = context.with_subcontext(self.system_id, local_context)

        except UpstreamEvalError:
            logger.debug(
                "RateLimiter.initialize_static_data: UpstreamEvalError. "
                "Continuing without default value initialization."
            )
        return context

initialize_static_data(context)

Infer the size and dtype of the internal states

Source code in collimator/library/primitives.py
3172
3173
3174
3175
3176
3177
3178
3179
3180
3181
3182
3183
3184
3185
3186
3187
3188
3189
3190
def initialize_static_data(self, context):
    """Seed the cached output (and discrete state) with the current input value.

    This is best-effort.  If the upstream input cannot be evaluated yet (for
    example, while a subsystem is still being wired), the failure is logged at
    debug level and the context is returned unchanged.  Tracked in
    https://collimator.atlassian.net/browse/WC-51.
    """
    try:
        initial_input = self.eval_input(context)
        self._default_cache[self.output_index] = initial_input
        subcontext = (
            context[self.system_id]
            .with_discrete_state(initial_input)
            .with_cached_value(self.output_index, initial_input)
        )
        return context.with_subcontext(self.system_id, subcontext)
    except UpstreamEvalError:
        logger.debug(
            "RateLimiter.initialize_static_data: UpstreamEvalError. "
            "Continuing without default value initialization."
        )
        return context

Reciprocal

Bases: FeedthroughBlock

Compute the reciprocal of the input signal.

Input ports

(0) The input signal.

Output ports

(0) The reciprocal of the input signal: y = 1 / u.

Source code in collimator/library/primitives.py
3193
3194
3195
3196
3197
3198
3199
3200
3201
3202
3203
3204
class Reciprocal(FeedthroughBlock):
    """Output one divided by the input signal.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The reciprocal of the input signal: `y = 1 / u`.
    """

    def __init__(self, *args, **kwargs):
        def _invert(u):
            # True division, so integer inputs yield floating-point results.
            return 1 / u

        super().__init__(_invert, *args, **kwargs)

ReferenceSubdiagram

Source code in collimator/library/reference_subdiagram.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class ReferenceSubdiagram:
    """Registry of reusable ("reference") subdiagram constructors.

    Each entry maps a `ref_id` to a constructor callable and a list of default
    parameter definitions.  `create_diagram` looks up the constructor, layers
    any per-instance parameter overrides on top of the defaults, builds the
    diagram, and declares every parameter as a dynamic parameter of the result.

    Both registries are class-level dicts, so they are shared by all callers in
    the process.
    """

    # ref_id -> constructor that builds the Diagram.
    _registry: dict[str, Callable[[Any], "Diagram"]] = {}
    # ref_id -> default parameter definitions registered with that constructor.
    _parameter_definitions: dict[str, list[Parameter]] = {}  # noqa: F821

    @classmethod
    def create_diagram(
        cls,
        ref_id: str,
        instance_name: str,
        *args,
        instance_parameters: dict[str, Any] = None,
        **kwargs,
    ) -> "Diagram":
        """
        Create a diagram based on the given reference ID and parameters.

        Note that for submodels we evaluate all parameters; there are no
        "pure" string parameters.

        Args:
            ref_id (str): The reference ID of the diagram.
            instance_name (str): Name for the created instance, forwarded to the
                registered constructor as the `instance_name` keyword.
            *args: Positional arguments forwarded to the registered constructor.
            instance_parameters (dict[str, Any], optional): Values overriding the
                registered default parameters for this instance, keyed by
                parameter name. Defaults to None (all defaults are used).
                example: {"gain": 3.0}
            **kwargs: Keyword arguments forwarded to the registered constructor.

        Returns:
            Diagram: The created diagram.  Its `ref_id` attribute is set, its
            `instance_parameters` attribute holds the set of overridden
            parameter names, and every registered parameter is declared as a
            dynamic parameter.

        Raises:
            ValueError: If the reference subdiagram with the given ref_id is not
                found, or if `instance_parameters` contains a name that is not
                among the registered parameter definitions.
        """
        if ref_id not in ReferenceSubdiagram._registry:
            raise ValueError(f"ReferenceSubdiagram with ref_id {ref_id} not found.")

        params_def = ReferenceSubdiagram.get_parameter_definitions(ref_id)
        default_params = {p.name: p for p in params_def}

        # Wrap each override in a Parameter, rejecting names the subdiagram
        # does not define.
        new_instance_parameters = {}
        if instance_parameters:
            for param_name, param in instance_parameters.items():
                if param_name not in default_params:
                    raise ValueError(
                        f"Parameter {param_name} not found in parameter definitions."
                    )
                new_instance_parameters[param_name] = Parameter(
                    name=param_name, value=param
                )

        # Overrides take precedence over the registered defaults.
        all_params = {**default_params, **new_instance_parameters}

        diagram = ReferenceSubdiagram._registry[ref_id](
            *args,
            instance_name=instance_name,
            parameters=all_params,
            **kwargs,
        )

        diagram.ref_id = ref_id
        diagram.instance_parameters = set(new_instance_parameters.keys())

        # Declare in definition order, preferring the instance override if any.
        for param in params_def:
            diagram.declare_dynamic_parameter(
                param.name, new_instance_parameters.get(param.name, param)
            )

        return diagram

    @staticmethod
    def register(
        constructor: ReferenceSubdiagramProtocol,
        # FIXME: rename parameter_definitions to default_parameters
        parameter_definitions: list[Parameter] = None,  # noqa: F821
        ref_id: str | None = None,
    ) -> str:
        """Register a subdiagram constructor together with its default parameters.

        Args:
            constructor: Callable that builds the diagram.  `create_diagram` calls
                it with its forwarded positional/keyword arguments plus the
                `instance_name` and `parameters` keywords.
            parameter_definitions: Default parameter definitions.  Defaults to an
                empty list.
            ref_id: Identifier to register under.  If None, a random UUID4 string
                is generated.

        Returns:
            str: The ref_id under which the constructor was registered.

        Re-registering an existing ref_id replaces the previous entry; only a
        debug message is logged.
        """
        if ref_id is None:
            ref_id = str(uuid4())
        if parameter_definitions is None:
            parameter_definitions = []

        logger.debug("Registering ReferenceSubdiagram with ref_id %s", ref_id)
        if ref_id in ReferenceSubdiagram._registry:
            logger.debug(
                "ReferenceSubdiagram with ref_id %s already registered.",
                ref_id,
            )

        ReferenceSubdiagram._registry[ref_id] = constructor
        ReferenceSubdiagram._parameter_definitions[ref_id] = parameter_definitions

        return ref_id

    @staticmethod
    def get_parameter_definitions(
        ref_id: str,
    ) -> list[Parameter]:  # noqa: F821
        """Return the registered default parameters for `ref_id`.

        Returns a new empty list if no definitions were registered.
        """
        return ReferenceSubdiagram._parameter_definitions.get(ref_id, [])

create_diagram(ref_id, instance_name, *args, instance_parameters=None, **kwargs) classmethod

Create a diagram based on the given reference ID and parameters.

Note that for submodels we evaluate all parameters; there are no "pure" string parameters.

Parameters:

Name Type Description Default
ref_id str

The reference ID of the diagram.

required
*args

Variable length arguments.

()
instance_parameters dict[str, Any]

The instance parameters for the diagram. Defaults to None. example: {"gain": 3.0}

None
**kwargs

Keyword arguments.

{}

Returns:

Name Type Description
Diagram Diagram

The created diagram.

Raises:

Type Description
ValueError

If the reference subdiagram with the given ref_id is not found.

Source code in collimator/library/reference_subdiagram.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@classmethod
def create_diagram(
    cls,
    ref_id: str,
    instance_name: str,
    *args,
    instance_parameters: dict[str, Any] = None,
    **kwargs,
) -> "Diagram":
    """
    Create a diagram based on the given reference ID and parameters.

    Note that for submodels we evaluate all parameters; there are no
    "pure" string parameters.

    Args:
        ref_id (str): The reference ID of the diagram.
        instance_name (str): Name for the created instance, forwarded to the
            registered constructor as the `instance_name` keyword.
        *args: Positional arguments forwarded to the registered constructor.
        instance_parameters (dict[str, Any], optional): Values overriding the
            registered default parameters for this instance, keyed by
            parameter name. Defaults to None (all defaults are used).
            example: {"gain": 3.0}
        **kwargs: Keyword arguments forwarded to the registered constructor.

    Returns:
        Diagram: The created diagram.  Its `ref_id` attribute is set, its
        `instance_parameters` attribute holds the set of overridden parameter
        names, and every registered parameter is declared as a dynamic
        parameter.

    Raises:
        ValueError: If the reference subdiagram with the given ref_id is not
            found, or if `instance_parameters` contains a name that is not
            among the registered parameter definitions.
    """
    if ref_id not in ReferenceSubdiagram._registry:
        raise ValueError(f"ReferenceSubdiagram with ref_id {ref_id} not found.")

    params_def = ReferenceSubdiagram.get_parameter_definitions(ref_id)
    default_params = {p.name: p for p in params_def}

    # Wrap each override in a Parameter, rejecting names the subdiagram
    # does not define.
    new_instance_parameters = {}
    if instance_parameters:
        for param_name, param in instance_parameters.items():
            if param_name not in default_params:
                raise ValueError(
                    f"Parameter {param_name} not found in parameter definitions."
                )
            new_instance_parameters[param_name] = Parameter(
                name=param_name, value=param
            )

    # Overrides take precedence over the registered defaults.
    all_params = {**default_params, **new_instance_parameters}

    diagram = ReferenceSubdiagram._registry[ref_id](
        *args,
        instance_name=instance_name,
        parameters=all_params,
        **kwargs,
    )

    diagram.ref_id = ref_id
    diagram.instance_parameters = set(new_instance_parameters.keys())

    # Declare in definition order, preferring the instance override if any.
    for param in params_def:
        diagram.declare_dynamic_parameter(
            param.name, new_instance_parameters.get(param.name, param)
        )

    return diagram

Relay

Bases: LeafSystem

Simple state machine implementing hysteresis behavior.

The input-output map is as follows:

        output
          |
on_value  |          -------<------<---------------------
          |          |                    |
          |          ⌄                    ^
          |          |                    |
off_value |----------|-------->----->-----|
          |
          |---------------------------------------------- input
                     | off_threshold      | on_threshold

Note that the "time mode" behavior of this block will follow the input signal. That is, if the input signal varies continuously in time, then the zero-crossing event from OFF->ON or vice versa will be localized in time. On the other hand, if the input signal varies only as a result of periodic updates to the discrete state, the relay will only change state at those instants. If the input signal is continuous, the block can be "forced" to this discrete-time periodic behavior by adding a ZeroOrderHold block before the input.

The exception to this is the case where there are no blocks in the system containing either discrete or continuous state. In this case the state changes will only be localized to the resolution of the major step.

Input ports

(0) The input signal.

Output ports

(0) The relay output signal, which is equal to either the on_value or the off_value, depending on the internal state of the relay.

Parameters:

Name Type Description Default
on_threshold

When input rises above this value, the internal state transitions to ON.

required
off_threshold

When input falls below this value, the internal state transitions to OFF.

required
on_value

Value of the output signal when state is ON.

required
off_value

Value of the output signal when state is OFF.

required
initial_state

If equal to on_value, the block will be initialized in the ON state. Otherwise, it will be initialized to the OFF state.

required
Events

There are two zero-crossing events: one to transition from OFF->ON and one for the opposite transition from ON->OFF.

Source code in collimator/library/primitives.py
2877
2878
2879
2880
2881
2882
2883
2884
2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903
2904
2905
2906
2907
2908
2909
2910
2911
2912
2913
2914
2915
2916
2917
2918
2919
2920
2921
2922
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
2936
2937
2938
2939
2940
2941
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965
2966
2967
2968
2969
2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
class Relay(LeafSystem):
    """Two-state relay with hysteresis between an ON and an OFF output level.

    Input-output map:

    ```
            output
              |
    on_value  |          -------<------<---------------------
              |          |                    |
              |          ⌄                    ^
              |          |                    |
    off_value |----------|-------->----->-----|
              |
              |---------------------------------------------- input
                         | off_threshold      | on_threshold
    ```

    The timing behavior of this block is inherited from its input signal.
    When the input evolves continuously in time, the OFF->ON and ON->OFF
    switching instants are localized in time by zero-crossing events.  When
    the input only changes as a result of periodic discrete-state updates,
    the relay can only switch at those update instants.  A continuous input
    can be forced into this periodic behavior by inserting a ZeroOrderHold
    block in front of the relay.

    Exception: if no block in the system has either discrete or continuous
    state, switching is only resolved to the granularity of the major step.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The relay output: `on_value` while the internal state is ON and
            `off_value` while it is OFF.

    Parameters:
        on_threshold:
            When input rises above this value, the internal state transitions to ON.
        off_threshold:
            When input falls below this value, the internal state transitions to OFF.
        on_value:
            Value of the output signal when state is ON.
        off_value:
            Value of the output signal when state is OFF.
        initial_state:
            If equal to on_value, the block will be initialized in the ON state.
            Otherwise, it will be initialized to the OFF state.

    Events:
        Two zero-crossing events: one for the OFF->ON transition and one for
        the ON->OFF transition.
    """

    class State(IntEnum):
        OFF = 0
        ON = 1

    @parameters(
        dynamic=[
            "on_threshold",
            "off_threshold",
            "initial_state",
            "on_value",
            "off_value",
        ],
    )
    def __init__(
        self, on_threshold, off_threshold, on_value, off_value, initial_state, **kwargs
    ):
        super().__init__(**kwargs)

        self.declare_default_mode(self._mode_from(initial_state, on_value))

        self.declare_input_port()
        # The output is a pure function of the mode, so it only needs to be
        # recomputed when the mode changes.
        self.declare_output_port(
            self._output,
            prerequisites_of_calc=[DependencyTicket.mode],
            requires_inputs=False,
        )

        def rising_guard(_time, _state, u, **params):
            # Changes sign as the input passes through the ON threshold.
            return u - params["on_threshold"]

        def falling_guard(_time, _state, u, **params):
            # Changes sign as the input passes through the OFF threshold.
            return u - params["off_threshold"]

        # OFF -> ON: input goes from below on_threshold to at/above it.
        self.declare_zero_crossing(
            guard=rising_guard,
            direction="negative_then_non_negative",
            start_mode=self.State.OFF,
            end_mode=self.State.ON,
        )

        # ON -> OFF: input goes from above off_threshold to at/below it.
        self.declare_zero_crossing(
            guard=falling_guard,
            direction="positive_then_non_positive",
            start_mode=self.State.ON,
            end_mode=self.State.OFF,
        )

    @classmethod
    def _mode_from(cls, initial_state, on_value):
        """Return the starting mode: ON when `initial_state == on_value`, else OFF."""
        if initial_state == on_value:
            return cls.State.ON
        return cls.State.OFF

    def initialize(
        self, on_threshold, off_threshold, on_value, off_value, initial_state
    ):
        self.configure_default_mode(self._mode_from(initial_state, on_value))

    def reset_default_values(self, **dynamic_parameters):
        self.configure_default_mode(
            self._mode_from(
                dynamic_parameters["initial_state"], dynamic_parameters["on_value"]
            )
        )

    def _output(self, _time, state, **params):
        # Select the output level according to the current mode.
        is_on = state.mode == self.State.ON
        return cnp.where(is_on, params["on_value"], params["off_value"])

RigidBody

Bases: LeafSystem

Implements dynamics of a single three-dimensional body.

The block models both translational and rotational degrees of freedom, for a total of 6 degrees of freedom. With second-order equations, the block has 12 state variables, 6 for the position/orientation and 6 for the velocities/rates.

Currently only a roll-pitch-yaw (Euler angle) representation is supported for the orientation.

The full 12-element state vector is x = [pⁱ, Φ, vᵇ, ωᵇ], where pⁱ is the position in an inertial "world" frame i, Φ is the (roll, pitch, and yaw) Euler angle sequence defining the rotation from the inertial "world" frame to the body frame, vᵇ is the translational velocity with respect to body-fixed axes b, and ωᵇ is the angular velocity about the body-fixed axes.

The mass and inertia properties of the block can independently be defined statically as parameters, or dynamically as inputs to the block.

Input ports

(0) force_vector: 3D force vector, defined in the body-fixed coordinate frame. For example, if gravity is acting on the body, the gravity vector should be pre-rotated using CoordinateRotation.

(1) torque_vector: 3D torque vector, defined in the body-fixed coordinate frame.

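(2) mass: If enable_external_mass=True, this input provides the time-varying (scalar) mass.

(2 or 3) inertia_matrix: If enable_external_inertia_matrix=True, this input provides the time-varying body-fixed inertia matrix. It is port 3 when the external mass input is also enabled, otherwise port 2.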

Output ports

(0): The position in the inertial "world" frame pⁱ.

(1): The orientation of the body, represented as a roll-pitch-yaw Euler angle sequence.

(2): The translational velocity with respect to body-fixed axes vᵇ.

(3): The angular velocity about the body-fixed axes ωᵇ.

(4): (if enable_output_state_derivatives=True) The time derivatives of the position variables in the world frame ṗⁱ. Not generally equal to the state vᵇ, defining time derivatives in the body frame.

(5): (if enable_output_state_derivatives=True) The "Euler rates" Φ̇, which are the time derivatives of the Euler angles. Not generally equal to the angular velocity ωᵇ.

(6): (if enable_output_state_derivatives=True) The body-fixed acceleration vector aᵇ.

(7): (if enable_output_state_derivatives=True) The angular acceleration in body-fixed axes ω̇ᵇ.

Parameters:

Name Type Description Default
initial_position Array

The initial position in the inertial frame.

required
initial_orientation Array

The initial orientation of the body, represented as a roll-pitch-yaw Euler angle sequence.

required
initial_velocity Array

The initial translational velocity with respect to body-fixed axes.

required
initial_angular_velocity Array

The initial angular velocity about the body-fixed axes.

required
enable_external_mass bool

If True, the block will have one input port for the mass. Otherwise the mass must be provided as a block parameter.

False
mass float

The constant value for the body mass when enable_external_mass=False. If None, will default to 1.0.

1.0
enable_external_inertia_matrix bool

If True, the block will have one input port for a (3x3) inertia matrix. Otherwise the inertia matrix must be provided as a block parameter.

False
inertia_matrix

The constant value for the body inertia matrix when enable_external_inertia_matrix=False. If None, will default to the 3x3 identity matrix.

eye(3)
enable_output_state_derivatives bool

If True, the block will output the time derivatives of the state variables.

False
gravity_vector Array

The constant gravitational acceleration vector acting on the body, defined in the inertial frame. If None, will default to the zero vector.

zeros(3)
Notes

Assumes that the inertia matrix is computed at the center of mass.

Assumes that the mass and inertia matrix are quasi-steady. This means that if one or both is specified as "dynamic" inputs their time derivative is neglected in the dynamics. For instance, for pure translation (ωᵇ = 0) the approximation to Newton's law is F_net = (d/dt)(m * v) ≈ m * (dv/dt).

Source code in collimator/library/rotations.py
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
class RigidBody(LeafSystem):
    """Implements dynamics of a single three-dimensional body.

    The block models both translational and rotational degrees of freedom, for a
    total of 6 degrees of freedom.  With second-order equations, the block has
    12 state variables, 6 for the position/orientation and 6 for the velocities/rates.

    Currently only a roll-pitch-yaw (Euler angle) representation is supported for
    the orientation.

    The full 12-element state vector is `x = [pⁱ, Φ, vᵇ, ωᵇ]`, where `pⁱ` is the
    position in an inertial "world" frame `i`, `Φ` is the (roll, pitch, and yaw)
    Euler angle sequence defining the rotation from the inertial "world" frame to
    the body frame, `vᵇ` is the translational velocity with respect to body-fixed
    axes `b`, and `ωᵇ` is the angular velocity about the body-fixed axes.

    The mass and inertia properties of the block can independently be defined
    statically as parameters, or dynamically as inputs to the block.

    Input ports:
        (0) force_vector: 3D force vector, defined in the _body-fixed_ coordinate
        frame.  For example, if gravity is acting on the body, the gravity vector
        should be pre-rotated using CoordinateRotation (or supplied through the
        `gravity_vector` parameter, which is given in the inertial frame).

        (1) torque_vector: 3D torque vector, defined in the _body-fixed_
        coordinate frame.

        (2) mass: If `enable_external_mass=True`, this input provides the
        time-varying (scalar) mass.

        (2 or 3) inertia_matrix: If `enable_external_inertia_matrix=True`, this
        input provides the time-varying body-fixed (3x3) inertia matrix.  It is
        port 3 when the external mass input is also enabled, otherwise port 2.

    Output ports:
        (0): The position in the inertial "world" frame `pⁱ`.

        (1): The orientation of the body, represented as a roll-pitch-yaw Euler
        angle sequence.

        (2): The translational velocity with respect to body-fixed axes `vᵇ`.

        (3): The angular velocity about the body-fixed axes `ωᵇ`.

        (4): (if `enable_output_state_derivatives=True`) The time derivatives of the
        position variables in the world frame `ṗⁱ`. Not generally equal to the state
        `vᵇ`, defining time derivatives in the body frame.

        (5): (if `enable_output_state_derivatives=True`) The "Euler rates" `Φ̇`,
        which are the time derivatives of the Euler angles. Not generally equal to
        the angular velocity `ωᵇ`.

        (6): (if `enable_output_state_derivatives=True`) The body-fixed acceleration
        vector `aᵇ`.

        (7): (if `enable_output_state_derivatives=True`) The angular acceleration in
        body-fixed axes `ω̇ᵇ`.

    Parameters:
        initial_position (Array): The initial position in the inertial frame.

        initial_orientation (Array): The initial orientation of the body, represented
            as a roll-pitch-yaw Euler angle sequence.

        initial_velocity (Array): The initial translational velocity with respect to
            body-fixed axes.

        initial_angular_velocity (Array): The initial angular velocity about the
            body-fixed axes.

        enable_external_mass (bool, optional): If `True`, the block will have one
            input port for the mass. Otherwise the mass must be provided as a block
            parameter.

        mass (float, optional): The constant value for the body mass when
            `enable_external_mass=False`. If `None`, will default to 1.0.

        enable_external_inertia_matrix (bool, optional):  If `True`, the block will
            have one input port for a (3x3) inertia matrix. Otherwise the inertia
            matrix must be provided as a block parameter.

        inertia_matrix: The constant value for the body inertia matrix when
            `enable_external_inertia_matrix=False`. If `None`, will default to
            the 3x3 identity matrix.

        enable_output_state_derivatives (bool, optional): If `True`, the block will
            output the time derivatives of the state variables.

        gravity_vector (Array, optional): The constant gravitational acceleration vector
            acting on the body, defined in the _inertial_ frame. If `None`, will default
            to the zero vector.

    Notes:
        Assumes that the inertia matrix is computed at the center of mass.

        Assumes that the mass and inertia matrix are quasi-steady.  This means that
        if one or both is specified as "dynamic" inputs their time derivative is
        neglected in the dynamics.  For instance, for pure translation (`ωᵇ = 0`) the
        approximation to Newton's law is `F_net = (d/dt)(m * v) ≈ m * (dv/dt)`.
    """

    class RigidBodyState(NamedTuple):
        position: Array
        orientation: Array
        velocity: Array
        angular_velocity: Array

        def asarray(self):
            """Concatenate the fields, in declaration order, into one flat array."""
            return cnp.concatenate(
                [self.position, self.orientation, self.velocity, self.angular_velocity]
            )

    @parameters(
        static=[
            "initial_position",
            "initial_orientation",
            "initial_velocity",
            "initial_angular_velocity",
            "enable_external_mass",
            "enable_external_inertia_matrix",
            "enable_output_state_derivatives",
        ],
        dynamic=["mass", "inertia_matrix", "gravity_vector"],
    )
    def __init__(
        self,
        initial_position,
        initial_orientation,
        initial_velocity,
        initial_angular_velocity,
        enable_external_mass=False,
        mass=1.0,
        enable_external_inertia_matrix=False,
        inertia_matrix=cnp.eye(3),
        enable_output_state_derivatives=False,
        gravity_vector=cnp.zeros(3),
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Port-layout flags are fixed at construction; `initialize` refuses to
        # change them because the declared ports depend on them.
        self._enable_external_mass = enable_external_mass
        self._enable_external_inertia_matrix = enable_external_inertia_matrix
        self._enable_output_state_derivatives = enable_output_state_derivatives

        initial_state = self._make_initial_state(
            initial_position,
            initial_orientation,
            initial_velocity,
            initial_angular_velocity,
        )

        self._continuous_state_idx = self.declare_continuous_state(
            default_value=initial_state,
            as_array=False,
            ode=self._state_derivative,
        )

        self._configure_ports(
            initial_state,
            enable_external_mass,
            enable_external_inertia_matrix,
            enable_output_state_derivatives,
        )

    def initialize(
        self,
        initial_position,
        initial_orientation,
        initial_velocity,
        initial_angular_velocity,
        enable_external_mass,
        enable_external_inertia_matrix,
        enable_output_state_derivatives,
        mass,
        inertia_matrix,
        gravity_vector,
    ):
        """Re-validate parameters and refresh the initial state and output defaults.

        Raises:
            ValueError: if one of the port-layout flags differs from the value
                used at construction.
            BlockParameterError: if `gravity_vector` or any initial-state array
                does not have shape (3,).
        """
        if enable_external_mass != self._enable_external_mass:
            raise ValueError("Cannot change external mass definition.")
        if enable_external_inertia_matrix != self._enable_external_inertia_matrix:
            raise ValueError("Cannot change external inertia matrix definition.")
        if enable_output_state_derivatives != self._enable_output_state_derivatives:
            raise ValueError("Cannot change output state derivatives definition.")

        # Only the shape check matters here; the value is read from the
        # parameters at evaluation time.
        self._validate_vector3(gravity_vector, "Gravity vector", "gravity_vector")

        initial_state = self._make_initial_state(
            initial_position,
            initial_orientation,
            initial_velocity,
            initial_angular_velocity,
        )

        self.configure_continuous_state(
            self._continuous_state_idx,
            default_value=initial_state,
            as_array=False,
            ode=self._state_derivative,
        )

        # The four state outputs share one configuration; only the callback
        # and the default value differ.
        for port_index, callback, default_value in (
            (self.pos_output_index, self._pos_output, initial_state.position),
            (
                self.orientation_output_index,
                self._orientation_output,
                initial_state.orientation,
            ),
            (self.vel_output_index, self._vel_output, initial_state.velocity),
            (
                self.ang_vel_output_index,
                self._ang_vel_output,
                initial_state.angular_velocity,
            ),
        ):
            self.configure_output_port(
                port_index,
                callback,
                prerequisites_of_calc=[DependencyTicket.xc],
                requires_inputs=False,
                default_value=default_value,
            )

    def _validate_vector3(self, value, label, parameter_name):
        """Convert `value` to an array and require it to have shape (3,).

        Args:
            value: Array-like parameter value.
            label: Human-readable name used at the start of the error message.
            parameter_name: Parameter name reported in the raised error.

        Returns:
            The value converted with `cnp.asarray`.

        Raises:
            BlockParameterError: if the converted array's shape is not (3,).
        """
        value = cnp.asarray(value)
        if value.shape != (3,):
            message = f"{label} must have shape (3,), but has shape {value.shape}."
            raise BlockParameterError(
                message=message, system=self, parameter_name=parameter_name
            )
        return value

    def _make_initial_state(
        self,
        initial_position,
        initial_orientation,
        initial_velocity,
        initial_angular_velocity,
    ):
        """Validate the initial-state arrays and bundle them into a RigidBodyState.

        Each array is checked in order (position, orientation, velocity, angular
        velocity); the first one with a shape other than (3,) raises
        BlockParameterError.
        """
        return self.RigidBodyState(
            position=self._validate_vector3(
                initial_position, "Initial position", "initial_position"
            ),
            orientation=self._validate_vector3(
                initial_orientation, "Initial orientation", "initial_orientation"
            ),
            velocity=self._validate_vector3(
                initial_velocity, "Initial velocity", "initial_velocity"
            ),
            angular_velocity=self._validate_vector3(
                initial_angular_velocity,
                "Initial angular velocity",
                "initial_angular_velocity",
            ),
        )

    @property
    def force_input(self):
        return self.input_ports[self.force_index]

    @property
    def torque_input(self):
        return self.input_ports[self.torque_index]

    @property
    def mass_input(self):
        """The external mass input port, or None when the mass is a parameter."""
        if self.mass_index is None:
            return None
        return self.input_ports[self.mass_index]

    @property
    def inertia_input(self):
        """The external inertia input port, or None when the inertia is a parameter."""
        if self.inertia_index is None:
            return None
        return self.input_ports[self.inertia_index]

    @property
    def position_output(self):
        return self.output_ports[self.pos_output_index]

    @property
    def orientation_output(self):
        return self.output_ports[self.orientation_output_index]

    @property
    def velocity_output(self):
        return self.output_ports[self.vel_output_index]

    @property
    def angular_velocity_output(self):
        return self.output_ports[self.ang_vel_output_index]

    def _configure_ports(
        self,
        initial_state,
        enable_external_mass,
        enable_external_inertia_matrix,
        enable_output_state_derivatives,
    ):
        """Declare the input/output ports and record their indices on `self`."""
        # External force vector input
        self.force_index = self.declare_input_port(name="force_vector")

        # External torque vector input
        self.torque_index = self.declare_input_port(name="torque_vector")

        # External mass input
        self.mass_index = None
        if enable_external_mass:
            self.mass_index = self.declare_input_port(name="mass")

        # External inertia matrix input
        self.inertia_index = None
        if enable_external_inertia_matrix:
            self.inertia_index = self.declare_input_port(name="inertia_matrix")

        # Position output
        self.pos_output_index = self.declare_output_port(
            self._pos_output,
            prerequisites_of_calc=[DependencyTicket.xc],
            requires_inputs=False,
            default_value=initial_state.position,
            name=f"{self.name}:position",
        )

        # Orientation output
        self.orientation_output_index = self.declare_output_port(
            self._orientation_output,
            prerequisites_of_calc=[DependencyTicket.xc],
            requires_inputs=False,
            default_value=initial_state.orientation,
            name=f"{self.name}:orientation",
        )

        # Velocity output
        self.vel_output_index = self.declare_output_port(
            self._vel_output,
            prerequisites_of_calc=[DependencyTicket.xc],
            requires_inputs=False,
            default_value=initial_state.velocity,
            name=f"{self.name}:velocity",
        )

        # Angular velocity output
        self.ang_vel_output_index = self.declare_output_port(
            self._ang_vel_output,
            prerequisites_of_calc=[DependencyTicket.xc],
            requires_inputs=False,
            default_value=initial_state.angular_velocity,
            name=f"{self.name}:angular_velocity",
        )

        if enable_output_state_derivatives:
            self.pos_deriv_output_index = self.declare_output_port(
                self._pos_derivative,
                prerequisites_of_calc=[DependencyTicket.xc],
                requires_inputs=False,
                default_value=cnp.zeros(3),
                name=f"{self.name}:position_dot",
            )

            self.orientation_deriv_output_index = self.declare_output_port(
                self._orientation_derivative,
                prerequisites_of_calc=[DependencyTicket.xc],
                requires_inputs=False,
                default_value=cnp.zeros(3),
                name=f"{self.name}:orientation_dot",
            )

            # `_vel_derivative` reads the force input and, when enabled, the
            # external mass input, so both must be declared as dependencies
            # for the feedthrough to be tracked.
            vel_deriv_prereqs = [
                self.input_ports[self.force_index].ticket,
                DependencyTicket.xc,
            ]
            if self.mass_index is not None:
                vel_deriv_prereqs.append(self.input_ports[self.mass_index].ticket)
            self.vel_deriv_output_index = self.declare_output_port(
                self._vel_derivative,
                prerequisites_of_calc=vel_deriv_prereqs,
                requires_inputs=True,
                default_value=cnp.zeros(3),
                name=f"{self.name}:velocity_dot",
            )

            # Likewise `_ang_vel_derivative` reads the torque input and, when
            # enabled, the external inertia matrix input.
            ang_vel_deriv_prereqs = [
                self.input_ports[self.torque_index].ticket,
                DependencyTicket.xc,
            ]
            if self.inertia_index is not None:
                ang_vel_deriv_prereqs.append(
                    self.input_ports[self.inertia_index].ticket
                )
            self.ang_vel_deriv_output_index = self.declare_output_port(
                self._ang_vel_derivative,
                prerequisites_of_calc=ang_vel_deriv_prereqs,
                requires_inputs=True,
                default_value=cnp.zeros(3),
                name=f"{self.name}:angular_velocity_dot",
            )

    def _pos_output(self, time, state, *inputs, **parameters):
        xc = state.continuous_state
        return xc.position

    def _orientation_output(self, time, state, *inputs, **parameters):
        xc = state.continuous_state
        return xc.orientation

    def _vel_output(self, time, state, *inputs, **parameters):
        xc = state.continuous_state
        return xc.velocity

    def _ang_vel_output(self, time, state, *inputs, **parameters):
        xc = state.continuous_state
        return xc.angular_velocity

    def _pos_derivative(self, time, state, *inputs, **parameters):
        """Inertial-frame position rate: ṗⁱ = C_BIᵀ vᵇ."""
        # `euler_to_dcm` produces the inertial -> body rotation.  We want to
        # rotate the body-fixed velocity into the inertial frame, so we use the
        # transpose of this rotation matrix.
        xc = state.continuous_state
        C_BI = euler_to_dcm(xc.orientation)
        return C_BI.T @ xc.velocity

    def _orientation_derivative(self, time, state, *inputs, **parameters):
        """Euler rates: Φ̇ = H(Φ) ωᵇ."""
        # Matrix mapping angular velocity in the body-fixed frame to Euler rates
        xc = state.continuous_state
        H = euler_kinematics(xc.orientation)
        return H @ xc.angular_velocity

    def _vel_derivative(self, time, state, *inputs, **parameters):
        """Body-frame translational acceleration: v̇ᵇ = Fᵇ/m + C_BI gⁱ − ωᵇ × vᵇ."""
        xc = state.continuous_state

        if self.mass_index is not None:
            m = inputs[self.mass_index]
        else:
            m = parameters["mass"]

        # Gravity vector in the inertial frame
        g_I = parameters["gravity_vector"]

        # Acceleration in body-fixed frame (gravity rotated into body axes)
        F_B = inputs[self.force_index]
        C_BI = euler_to_dcm(xc.orientation)
        a_B = F_B / m + C_BI @ g_I

        # Body-fixed acceleration is the inertial plus Coriolis terms
        return a_B - cnp.cross(xc.angular_velocity, xc.velocity)

    def _ang_vel_derivative(self, time, state, *inputs, **parameters):
        """Angular acceleration from Euler's equations: J ω̇ = τ − ω × (J ω)."""
        xc = state.continuous_state

        if self.inertia_index is not None:
            J_B = inputs[self.inertia_index]
        else:
            J_B = parameters["inertia_matrix"]

        # Torque in body-fixed frame
        tau_B = inputs[self.torque_index]

        # Gyroscopic term; solve rather than invert J for numerical robustness.
        wJw = cnp.cross(xc.angular_velocity, J_B @ xc.angular_velocity)
        return cnp.linalg.solve(J_B, tau_B - wJw)

    def _state_derivative(self, time, state, *inputs, **parameters):
        """Full 6-DOF state derivative, assembled from the four component rates."""
        # See Eq. (1.7-18) in Lewis, Johnson, Stevens
        args = (time, state, *inputs)
        return self.RigidBodyState(
            position=self._pos_derivative(*args, **parameters),
            orientation=self._orientation_derivative(*args, **parameters),
            velocity=self._vel_derivative(*args, **parameters),
            angular_velocity=self._ang_vel_derivative(*args, **parameters),
        )

    def check_types(
        self,
        context,
        error_collector: ErrorCollector = None,
    ):
        """Check input shapes: force/torque (3,), mass (), inertia matrix (3, 3)."""
        force = self.input_ports[self.force_index].eval(context)
        torque = self.input_ports[self.torque_index].eval(context)

        with ErrorCollector.context(error_collector):
            if force.shape != (3,):
                raise ShapeMismatchError(
                    system=self,
                    expected_shape=(3,),
                    actual_shape=force.shape,
                )

            if torque.shape != (3,):
                raise ShapeMismatchError(
                    system=self,
                    expected_shape=(3,),
                    actual_shape=torque.shape,
                )

        if self.mass_index is not None:
            mass = self.input_ports[self.mass_index].eval(context)

            with ErrorCollector.context(error_collector):
                if mass.shape != ():
                    raise ShapeMismatchError(
                        system=self,
                        expected_shape=(),
                        actual_shape=mass.shape,
                    )

        if self.inertia_index is not None:
            inertia = self.input_ports[self.inertia_index].eval(context)

            with ErrorCollector.context(error_collector):
                if inertia.shape != (3, 3):
                    raise ShapeMismatchError(
                        system=self,
                        expected_shape=(3, 3),
                        actual_shape=inertia.shape,
                    )

Ros2Publisher

Bases: LeafSystem

Ros2Publisher block can emit signals to a ROS2 topic, based on input signal data.

Source code in collimator/library/ros2.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
class Ros2Publisher(LeafSystem):
    """
    Ros2Publisher block can emit signals to a ROS2 topic, based on input signal data.

    One input port is declared per entry of `fields`. Every `dt` the current
    input values are written into a fresh `msg_type` message, which is then
    published on `topic`.
    """

    @parameters(static=["topic", "msg_type", "fields"])
    def __init__(
        self,
        dt: float,
        topic: str,
        msg_type: type,
        fields: dict[str, type],
        **kwargs,
    ):
        """
        Publish messages to a ROS2 topic.

        Args:
            dt: Period of the system, in both sim and real (ros2) time.
            topic: ROS2 topic to publish to. Eg. `/turtle1/cmd_vel`.
            msg_type: ROS2 message type, e.g. `Twist` from `geometry_msgs.msg`.
                      Unlike the corresponding UI parameter, this must be a Python
                      type object.
            fields: Ordered dictionary of the message fields to set from the
                    input ports. The keys are the full attribute paths (with
                    dots) in the outgoing message, and the values are the data
                    types used to convert each input value before it is
                    assigned. One input port is created per entry, in
                    dictionary order. Use Python or Numpy data types, not JAX.

                    For instance, for a `geometry_msgs.msg.Twist` message, the
                    `fields` could be `{"linear.x": float, "angular.z": float}`.
        """

        super().__init__(**kwargs)
        self.logger = logger.getChild("Ros2Publisher:" + self.name_path_str)

        self.node: Node = None
        self.publisher: Publisher = None

        self.dt = dt
        self.topic = topic
        self.msg_type = msg_type
        self.fields = {field: _fixup_dtype(dtype) for field, dtype in fields.items()}

        self.declare_periodic_update(self._update, period=dt, offset=0.0)

        # Extract type & full attribute path from fields. Note that this
        # relies on the fact that the Python (3.7+) dict is ordered; The
        # order must match that of the input ports. Works well with JSON
        # because our I/O ports are ordered arrays.
        # This could likely be simplified / replaced with a Bus signal type.
        self.input_types = []  # [float, float]
        self.input_attr_path = []  # ["linear.x", "angular.z"]
        for msg_field_name, msg_field_type in self.fields.items():
            input_name = _attr2name(msg_field_name)
            self.declare_input_port(name=input_name)
            self.input_types.append(msg_field_type)
            self.input_attr_path.append(msg_field_name)

        self.pre_simulation_initialize()

    def __del__(self):
        self.post_simulation_finalize()

    def pre_simulation_initialize(self):
        """Create the ROS2 node and publisher used by this block.

        This is already called from `__init__`. If the node from an earlier call
        is still live, it is reused instead of creating (and leaking) another.

        Raises:
            RuntimeError: If `_ros2_init()` reports failure.
        """
        if self.node is not None:
            return

        if not _ros2_init():
            raise RuntimeError("ROS2 init failed")

        node_name = _NODE_NAME_REGEX.sub("_", self.name_path_str)
        # Random prefix, presumably to keep node names distinct across
        # instances/runs with the same block path.
        rnd = np.random.randint(0, 1000)
        self.node = rclpy.create_node(f"collimator_{rnd}_" + node_name)
        self.publisher = self.node.create_publisher(
            self.msg_type, self.topic, qos_profile=10
        )

        self.logger.debug(
            "ROS2 publisher %s initialized with node: %s and publisher: %s",
            self.name_path_str,
            self.node,
            self.publisher,
        )

    def post_simulation_finalize(self) -> None:
        """Destroy the publisher and node, then call `_ros2_shutdown()`.

        Does nothing when no node is live. `getattr` is used so that `__del__`
        does not raise if `__init__` failed before `self.node` was assigned.
        """
        if not getattr(self, "node", None):
            return

        self.logger.debug("ROS2 publisher %s clean up", self.name_path_str)
        self.node.destroy_publisher(self.publisher)
        self.publisher = None
        self.node.destroy_node()
        self.node = None
        _ros2_shutdown()

    def _update(self, time, state, *inputs, **params):
        # Publishing is a host-side side effect, so it goes through io_callback
        # with a `None` result specification.
        return io_callback(self._publish_message, None, *inputs)

    def _publish_message(self, *inputs):
        """Build a `msg_type` message from the input values and publish it."""
        msg = self.msg_type()

        # Input ports were declared in `fields` order, so they line up with
        # input_attr_path / input_types.
        for attr_path, input_type, input_value in zip(
            self.input_attr_path, self.input_types, inputs
        ):
            _setattr_path(msg, attr_path, input_type(input_value))

        self.logger.debug("Publishing message to topic %s: %s", self.topic, msg)
        self.publisher.publish(msg)

        # Spin rclpy loop to ensure the message is sent. Also, sync the clocks
        # using dt. This is a bit of a hack for now until we have proper clock
        # synchronization.
        rclpy.spin_once(self.node, timeout_sec=self.dt)

__init__(dt, topic, msg_type, fields, **kwargs)

Publish messages to a ROS2 topic.

Parameters:

Name Type Description Default
dt float

Period of the system, in both sim and real (ros2) time.

required
topic str

ROS2 topic to publish to. Eg. /turtle1/cmd_vel.

required
msg_type type

ROS2 message type, e.g. Twist from geometry_msgs.msg. Unlike the corresponding UI parameter, this must be a Python type object.

required
fields dict[str, type]

Ordered dictionary of the message fields to set from the input ports. The keys are the full attribute paths (with dots) in the outgoing message, and the values are the data types used to convert each input value before it is assigned. One input port is created per entry, in dictionary order. Use Python or Numpy data types, not JAX.

For instance, for a geometry_msgs.msg.Twist message, the fields could be {"linear.x": float, "angular.z": float}.
required
Source code in collimator/library/ros2.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
@parameters(static=["topic", "msg_type", "fields"])
def __init__(
    self,
    dt: float,
    topic: str,
    msg_type: type,
    fields: dict[str, type],
    **kwargs,
):
    """
    Publish messages to a ROS2 topic.

    Args:
        dt: Period of the system, in both sim and real (ros2) time.
        topic: ROS2 topic to publish to. Eg. `/turtle1/cmd_vel`.
        msg_type: ROS2 message type, e.g. `Twist` from `geometry_msgs.msg`.
                  Unlike the corresponding UI parameter, this must be a Python
                  type object.
        fields: Ordered dictionary of the message fields to set from the
                input ports. The keys are the full attribute paths (with
                dots) in the outgoing message, and the values are the data
                types used to convert each input value before it is
                assigned. One input port is created per entry, in
                dictionary order. Use Python or Numpy data types, not JAX.

                For instance, for a `geometry_msgs.msg.Twist` message, the
                `fields` could be `{"linear.x": float, "angular.z": float}`.
    """

    super().__init__(**kwargs)
    self.logger = logger.getChild("Ros2Publisher:" + self.name_path_str)

    self.node: Node = None
    self.publisher: Publisher = None

    self.dt = dt
    self.topic = topic
    self.msg_type = msg_type
    self.fields = {field: _fixup_dtype(dtype) for field, dtype in fields.items()}

    self.declare_periodic_update(self._update, period=dt, offset=0.0)

    # Extract type & full attribute path from fields. Note that this
    # relies on the fact that the Python (3.7+) dict is ordered; The
    # order must match that of the input ports. Works well with JSON
    # because our I/O ports are ordered arrays.
    # This could likely be simplified / replaced with a Bus signal type.
    self.input_types = []  # [float, float]
    self.input_attr_path = []  # ["linear.x", "angular.z"]
    for msg_field_name, msg_field_type in self.fields.items():
        input_name = _attr2name(msg_field_name)
        self.declare_input_port(name=input_name)
        self.input_types.append(msg_field_type)
        self.input_attr_path.append(msg_field_name)

    self.pre_simulation_initialize()

Ros2Subscriber

Bases: LeafSystem

Ros2Subscriber block listens to messages over a ROS2 topic and outputs them as signals in collimator.

Source code in collimator/library/ros2.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
class Ros2Subscriber(LeafSystem):
    """
    Ros2Subscriber block listens to messages over a ROS2 topic and outputs them as
    signals in collimator.

    One output port is declared per entry of `fields`. Each port outputs the
    corresponding attribute of the most recently received message, or the
    field's default value (`dtype()`) until a message has arrived.
    """

    @parameters(static=["topic", "msg_type", "fields", "read_before_start"])
    def __init__(
        self,
        dt,
        topic: str,
        msg_type: type,
        fields: dict[str, type],
        read_before_start=True,
        **kwargs,
    ):
        """Subscribe to a ROS2 topic and extract message values to output ports.

        Args:
            dt: Period of the system, in both sim and real (ros2) time.
            topic: ROS2 topic to subscribe to. Eg. `/turtle1/pose`.
            msg_type: ROS2 message type, e.g. `Pose` from `turtlesim.msg`.
                      Unlike the corresponding UI parameter, this must be a Python
                      type object.
            fields: Ordered dictionary of the fields to extract from the
                    received message. The keys are the full attribute paths
                    (with dots) in the message, and the values are the data
                    types of each field. Each type converts the extracted value,
                    and calling it with no arguments gives the port's default
                    value. One output port is created per entry, in dictionary
                    order. Use Python or Numpy data types, not JAX.

                    For instance, for a `geometry_msgs.msg.Twist` message, the
                    `fields` could be `{"linear.x": float, "angular.z": float}`.
            read_before_start: If True, the subscriber will read the first message
                    before the simulation starts. Otherwise, the initial outputs
                    will be the default value of each field type (e.g. 0 for
                    numeric types).
        """

        super().__init__(**kwargs)
        self.logger = logger.getChild("Ros2Subscriber:" + self.name_path_str)

        self.node: Node = None
        self.subscription: Subscription = None
        self._last_msg = None

        if not _ros2_init():
            raise RuntimeError("ROS2 init failed")

        self.dt = dt
        self.msg_type = msg_type
        self.topic = topic
        self.fields = {field: _fixup_dtype(dtype) for field, dtype in fields.items()}
        self.read_before_start = read_before_start

        # Note: Not 100% sure this is absolutely valid, but it worked with JAX.
        # If somehow we aren't getting updates, we may need to create a cache index,
        # see custom.py. See _callback().
        self.declare_periodic_update(self._update, period=dt, offset=0.0)
        self.default_values = {field: dtype() for field, dtype in self.fields.items()}

        def _make_output_cb(field_name: str, dtype: type):
            def _output():
                last_msg = self._last_msg
                if last_msg is None:
                    # No message yet: default_values is keyed by the full
                    # dotted path, so index it directly rather than passing
                    # the dict to _getattr_path.
                    value = self.default_values[field_name]
                else:
                    value = _getattr_path(last_msg, field_name)
                return dtype(value)

            def _io_cb(time, state, *inputs, **params):
                # The eager _output() call supplies the result shape/dtype
                # that io_callback needs.
                return io_callback(_output, cnp.asarray(_output()))

            return _io_cb

        for field, dtype in self.fields.items():
            self.declare_output_port(
                callback=_make_output_cb(field, dtype),
                name=_attr2name(field),
                prerequisites_of_calc=[],
                requires_inputs=False,
                period=dt,
                offset=0.0,
                default_value=self.default_values[field],
            )

        self.pre_simulation_initialize()

    def __del__(self):
        self.post_simulation_finalize()

    def pre_simulation_initialize(self):
        """Create the ROS2 node and subscription, and optionally read a message.

        This is already called from `__init__`. If the node from an earlier call
        is still live, it is reused instead of creating (and leaking) another.
        When `read_before_start` is set, the node is spun once either way.

        Raises:
            RuntimeError: If `_ros2_init()` reports failure.
        """
        if self.node is None:
            if not _ros2_init():
                raise RuntimeError("ROS2 init failed")

            node_name = _NODE_NAME_REGEX.sub("_", self.name_path_str)
            # Random prefix, presumably to keep node names distinct across
            # instances/runs with the same block path.
            rnd = np.random.randint(0, 1000)
            self.node = rclpy.create_node(f"collimator_{rnd}_" + node_name)
            self.subscription = self.node.create_subscription(
                self.msg_type, self.topic, self._callback, qos_profile=10
            )
            self.logger.debug(
                "ROS2 subscriber %s initialized, listening on topic %s msg_type=%s",
                self.name_path_str,
                self.topic,
                self.msg_type,
            )

        if self.read_before_start:
            self._update_cb()

    def post_simulation_finalize(self) -> None:
        """Destroy the subscription and node, then call `_ros2_shutdown()`.

        Does nothing when no node is live. `getattr` is used so that `__del__`
        does not raise if `__init__` failed before `self.node` was assigned.
        """
        if not getattr(self, "node", None):
            return

        self.logger.debug("ROS2 subscriber %s clean up", self.name_path_str)
        self.node.destroy_subscription(self.subscription)
        self.subscription = None
        self.node.destroy_node()
        self.node = None
        _ros2_shutdown()

    def _update(self, time, state, *inputs, **params):
        return io_callback(self._update_cb, None)

    def _update_cb(self):
        # This timeout does not seem to block the call
        rclpy.spin_once(self.node, timeout_sec=2.0)

    def _callback(self, msg):
        self.logger.debug("Received message on topic %s: %s", self.topic, msg)

        # This may be wrong because we're not cleanly using the cache
        # like in custom.py. But it works.
        self._last_msg = msg

__init__(dt, topic, msg_type, fields, read_before_start=True, **kwargs)

Subscribe to a ROS2 topic and extract message values to output ports.

Parameters:

Name Type Description Default
dt

Period of the system, in both sim and real (ros2) time.

required
topic str

ROS2 topic to subscribe to. Eg. /turtle1/pose.

required
msg_type type

ROS2 message type, e.g. Pose from turtlesim.msg. Unlike the corresponding UI parameter, this must be a Python type object.

required
fields dict[str, type]

Ordered dictionary of the fields to extract from the received message. The keys are the full attribute paths (with dots) in the message, and the values are the data types of each field. Each type converts the extracted value, and calling it with no arguments gives the output port's default value. One output port is created per entry, in dictionary order. Use Python or Numpy data types, not JAX.

For instance, for a geometry_msgs.msg.Twist message, the fields could be {"linear.x": float, "angular.z": float}.
required
read_before_start

If True, the subscriber will read the first message before the simulation starts. Otherwise, the initial outputs will be the default value of each field type (e.g. 0 for numeric types).

True
Source code in collimator/library/ros2.py
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
@parameters(static=["topic", "msg_type", "fields", "read_before_start"])
def __init__(
    self,
    dt,
    topic: str,
    msg_type: type,
    fields: dict[str, type],
    read_before_start=True,
    **kwargs,
):
    """Subscribe to a ROS2 topic and extract message values to output ports.

    Args:
        dt: Period of the system, in both sim and real (ros2) time.
        topic: ROS2 topic to subscribe to. Eg. `/turtle1/pose`.
        msg_type: ROS2 message type, e.g. `Pose` from `turtlesim.msg`.
                  Unlike the corresponding UI parameter, this must be a Python
                  type object.
        fields: Ordered dictionary of the fields to extract from the
                received message. The keys are the full attribute paths
                (with dots) in the message, and the values are the data
                types of each field. Each type converts the extracted value,
                and calling it with no arguments gives the port's default
                value. One output port is created per entry, in dictionary
                order. Use Python or Numpy data types, not JAX.

                For instance, for a `geometry_msgs.msg.Twist` message, the
                `fields` could be `{"linear.x": float, "angular.z": float}`.
        read_before_start: If True, the subscriber will read the first message
                before the simulation starts. Otherwise, the initial outputs
                will be the default value of each field type (e.g. 0 for
                numeric types).
    """

    super().__init__(**kwargs)
    self.logger = logger.getChild("Ros2Subscriber:" + self.name_path_str)

    self.node: Node = None
    self.subscription: Subscription = None
    self._last_msg = None

    if not _ros2_init():
        raise RuntimeError("ROS2 init failed")

    self.dt = dt
    self.msg_type = msg_type
    self.topic = topic
    self.fields = {field: _fixup_dtype(dtype) for field, dtype in fields.items()}
    self.read_before_start = read_before_start

    # Note: Not 100% sure this is absolutely valid, but it worked with JAX.
    # If somehow we aren't getting updates, we may need to create a cache index,
    # see custom.py. See _callback().
    self.declare_periodic_update(self._update, period=dt, offset=0.0)
    self.default_values = {field: dtype() for field, dtype in self.fields.items()}

    def _make_output_cb(field_name: str, dtype: type):
        def _output():
            last_msg = self._last_msg
            if last_msg is None:
                # No message yet: default_values is keyed by the full
                # dotted path, so index it directly rather than passing
                # the dict to _getattr_path.
                value = self.default_values[field_name]
            else:
                value = _getattr_path(last_msg, field_name)
            return dtype(value)

        def _io_cb(time, state, *inputs, **params):
            # The eager _output() call supplies the result shape/dtype
            # that io_callback needs.
            return io_callback(_output, cnp.asarray(_output()))

        return _io_cb

    for field, dtype in self.fields.items():
        self.declare_output_port(
            callback=_make_output_cb(field, dtype),
            name=_attr2name(field),
            prerequisites_of_calc=[],
            requires_inputs=False,
            period=dt,
            offset=0.0,
            default_value=self.default_values[field],
        )

    self.pre_simulation_initialize()

Saturate

Bases: LeafSystem

Clip the input signal to a specified range.

Given an input signal u and upper and lower limits ulim and llim, the output signal is:

    y = max(llim, min(ulim, u))

where max and min are the element-wise maximum and minimum functions. This is equivalent to y = clip(u, llim, ulim).

Optionally, the block can also be configured with "dynamic" limits, which will add input ports for time-varying upper and lower limits.

Input ports

(0) The input signal. (1) The upper limit, if dynamic limits are enabled. (2) The lower limit, if dynamic limits are enabled. (Will be indexed as 1 if dynamic upper limits are not enabled.)

Output ports

(0) The clipped output signal.

Parameters:

Name Type Description Default
upper_limit

The upper limit of the input signal. Default is np.inf.

None
enable_dynamic_upper_limit

If True, then the upper limit can be set by an external signal. Default is False.

False
lower_limit

The lower limit of the input signal. Default is -np.inf.

None
enable_dynamic_lower_limit

If True, then the lower limit can be set by an external signal. Default is False.

False
Events

The block will trigger an event when the input signal crosses either the upper or lower limit. For example, if the block is configured with static upper and lower limits and the input signal crosses the upper limit, then a zero-crossing event will be triggered.

Source code in collimator/library/primitives.py
3207
3208
3209
3210
3211
3212
3213
3214
3215
3216
3217
3218
3219
3220
3221
3222
3223
3224
3225
3226
3227
3228
3229
3230
3231
3232
3233
3234
3235
3236
3237
3238
3239
3240
3241
3242
3243
3244
3245
3246
3247
3248
3249
3250
3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
3261
3262
3263
3264
3265
3266
3267
3268
3269
3270
3271
3272
3273
3274
3275
3276
3277
3278
3279
3280
3281
3282
3283
3284
3285
3286
3287
3288
3289
3290
3291
3292
3293
3294
3295
3296
3297
3298
3299
3300
3301
3302
3303
3304
3305
3306
3307
3308
3309
3310
3311
3312
3313
3314
3315
3316
3317
3318
3319
3320
3321
3322
3323
3324
3325
3326
3327
3328
3329
3330
3331
3332
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
3346
3347
3348
3349
3350
3351
3352
3353
3354
3355
3356
class Saturate(LeafSystem):
    """Clip the input signal to a specified range.

    Given an input signal `u` and upper and lower limits `ulim` and `llim`,
    the output signal is:
    ```
        y = max(llim, min(ulim, u))
    ```
    where `max` and `min` are the element-wise maximum and minimum functions.
    This is equivalent to `y = clip(u, llim, ulim)`.

    Optionally, the block can also be configured with "dynamic" limits, which will
    add input ports for time-varying upper and lower limits.

    Input ports:
        (0) The input signal.
        (1) The upper limit, if dynamic limits are enabled.
        (2) The lower limit, if dynamic limits are enabled. (Will be indexed as 1 if
            dynamic upper limits are not enabled.)

    Output ports:
        (0) The clipped output signal.

    Parameters:
        upper_limit:
            The upper limit of the input signal.  Default is `np.inf`.
        enable_dynamic_upper_limit:
            If True, then the upper limit can be set by an external signal. Default
            is False.
        lower_limit:
            The lower limit of the input signal.  Default is `-np.inf`.
        enable_dynamic_lower_limit:
            If True, then the lower limit can be set by an external signal. Default
            is False.

    Events:
        The block will trigger an event when the input signal crosses either the upper
        or lower limit.  For example, if the block is configured with static upper and
        lower limits and the input signal crosses the upper limit, then a zero-crossing
        event will be triggered.
    """

    @parameters(
        static=["enable_dynamic_upper_limit", "enable_dynamic_lower_limit"],
        dynamic=["upper_limit", "lower_limit"],
    )
    def __init__(
        self,
        upper_limit=None,
        enable_dynamic_upper_limit=False,
        lower_limit=None,
        enable_dynamic_lower_limit=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.primary_input_index = self.declare_input_port()
        self.enable_dynamic_upper_limit = enable_dynamic_upper_limit
        self.enable_dynamic_lower_limit = enable_dynamic_lower_limit

        prerequisites_of_calc = [self.input_ports[self.primary_input_index].ticket]

        # NOTE(review): the `upper_limit = np.inf` / `lower_limit = -np.inf`
        # assignments below only rebind local names; whether they reach the
        # stored parameters depends on the @parameters decorator — confirm.
        # The zero-crossing helpers further down substitute the same
        # infinities for a None static limit, so they do not depend on this.
        if enable_dynamic_upper_limit:
            # If dynamic limit, simply ignore the static limit
            self.upper_limit_index = self.declare_input_port()
            prerequisites_of_calc.append(
                self.input_ports[self.upper_limit_index].ticket
            )
        else:
            if upper_limit is None:
                upper_limit = np.inf

        if enable_dynamic_lower_limit:
            # If dynamic limit, simply ignore the static limit
            self.lower_limit_index = self.declare_input_port()
            prerequisites_of_calc.append(
                self.input_ports[self.lower_limit_index].ticket
            )
        else:
            if lower_limit is None:
                lower_limit = -np.inf

        self.declare_output_port(
            self._compute_output, prerequisites_of_calc=prerequisites_of_calc
        )

    def initialize(
        self,
        upper_limit=None,
        enable_dynamic_upper_limit=False,
        lower_limit=None,
        enable_dynamic_lower_limit=False,
    ):
        """Validate re-initialization.

        The dynamic-limit flags decide which input ports exist (declared in
        `__init__`), so they cannot change afterwards.

        Raises:
            ValueError: If either flag differs from the constructor value.
        """
        if enable_dynamic_lower_limit != self.enable_dynamic_lower_limit:
            raise ValueError(
                "enable_dynamic_lower_limit must be the same as the value passed to the constructor"
            )
        if enable_dynamic_upper_limit != self.enable_dynamic_upper_limit:
            raise ValueError(
                "enable_dynamic_upper_limit must be the same as the value passed to the constructor"
            )

    def _event_lower_limit(self, inputs, params):
        """Lower limit for the zero-crossing function.

        Returns the dynamic input if enabled, else the static parameter. An
        unset (None) static limit is treated as -inf so that `u - lim` stays
        numeric and never crosses zero.
        """
        if self.enable_dynamic_lower_limit:
            return inputs[self.lower_limit_index]
        lim = params["lower_limit"]
        return -np.inf if lim is None else lim

    def _event_upper_limit(self, inputs, params):
        """Upper limit for the zero-crossing function; an unset static limit is +inf."""
        if self.enable_dynamic_upper_limit:
            return inputs[self.upper_limit_index]
        lim = params["upper_limit"]
        return np.inf if lim is None else lim

    def _lower_limit_event_value(self, _time, _state, *inputs, **params):
        u = inputs[self.primary_input_index]
        return u - self._event_lower_limit(inputs, params)

    def _upper_limit_event_value(self, _time, _state, *inputs, **params):
        u = inputs[self.primary_input_index]
        return u - self._event_upper_limit(inputs, params)

    def _compute_output(self, _time, _state, *inputs, **params):
        u = inputs[self.primary_input_index]

        # Limits are passed to clip as-is (no inf substitution) so that the
        # output dtype follows clip's own promotion rules for the given limits.
        ulim = (
            inputs[self.upper_limit_index]
            if self.enable_dynamic_upper_limit
            else params["upper_limit"]
        )
        llim = (
            inputs[self.lower_limit_index]
            if self.enable_dynamic_lower_limit
            else params["lower_limit"]
        )

        return cnp.clip(u, llim, ulim)

    def initialize_static_data(self, context):
        # Add zero-crossing events so ODE solvers can't try to integrate
        # through a discontinuity. For efficiency, only do this if the output
        # is fed to an ODE block
        if not self.has_zero_crossing_events and is_discontinuity(self.output_ports[0]):
            self.declare_zero_crossing(
                self._lower_limit_event_value,
                direction="positive_then_non_positive",
                name="llim",
            )
            self.declare_zero_crossing(
                self._upper_limit_event_value,
                direction="negative_then_non_negative",
                name="ulim",
            )

        return super().initialize_static_data(context)

Sawtooth

Bases: SourceBlock

Produces a modulated linear sawtooth signal.

The signal is similar to: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sawtooth.html

Given amplitude a, period p, and phase delay phi, the output signal is:

    y(t) = a * ((t - phi) % p)

where % is the modulo operator.

Input ports

None

Output ports

(0) The sawtooth signal.

Source code in collimator/library/primitives.py
3359
3360
3361
3362
3363
3364
3365
3366
3367
3368
3369
3370
3371
3372
3373
3374
3375
3376
3377
3378
3379
3380
3381
3382
3383
3384
3385
3386
3387
3388
3389
3390
3391
3392
3393
3394
3395
3396
3397
3398
3399
3400
3401
3402
3403
3404
3405
3406
3407
3408
3409
3410
3411
3412
3413
3414
3415
class Sawtooth(SourceBlock):
    """Generate a periodic linear ramp (sawtooth) signal.

    Closely related to `scipy.signal.sawtooth`:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sawtooth.html

    With amplitude `a`, period `p` (the reciprocal of `frequency`), and phase
    delay `phi`, the block outputs:
    ```
        y(t) = a * ((t - phi) % p)
    ```
    where `%` denotes the modulo operation.

    Input ports:
        None

    Output ports:
        (0) The sawtooth signal.
    """

    # `frequency` must be static: initialize() uses it to reconfigure the
    # periodic update, and optimization/ensemble runs reuse the existing context
    # without calling initialize() again, so a changed `frequency` would never
    # reach the periodic update.
    @parameters(dynamic=["amplitude", "phase_delay"], static=["frequency"])
    def __init__(self, amplitude=1.0, frequency=0.5, phase_delay=1.0, **kwargs):
        super().__init__(self._func, **kwargs)

        # Slack added to the time argument of the modulo. It starts at zero and
        # is replaced in initialize_static_data() once the dtype of the
        # simulation time is known.
        self.eps = 0.0
        self._periodic_update_idx = self.declare_periodic_update()

    def initialize(self, amplitude, frequency, phase_delay):
        # A discrete state plus a periodic update scheduled at every wrap-around
        # (t = phase_delay + k * period), so the ODE solver does not integrate
        # across the discontinuity.
        self.declare_discrete_state(default_value=False)

        self.period = 1 / frequency

        def _always_true(*args, **kwargs):
            return True

        self.configure_periodic_update(
            self._periodic_update_idx,
            _always_true,
            period=self.period,
            offset=phase_delay,
        )

    def _func(self, time, **parameters):
        # Equivalent to np.mod(t - phase_delay, 1.0 / frequency) * amplitude,
        # with eps added to the time, presumably so that rounding exactly at a
        # wrap-around does not yield a value just below the peak.
        shifted_time = time - parameters["phase_delay"] + self.eps
        ramp = cnp.mod(shifted_time, self.period)
        return ramp * parameters["amplitude"]

    def initialize_static_data(self, context):
        # Scale eps to twice the machine epsilon of the time variable's dtype.
        time_dtype = cnp.result_type(context.time)
        self.eps = 2 * cnp.finfo(time_dtype).eps
        return super().initialize_static_data(context)

ScalarBroadcast

Bases: FeedthroughBlock

Broadcast a scalar to a vector or matrix.

Given a scalar input u and dimensions m and n, this block will return a vector or matrix of shape (m, n) with all elements equal to u.

Input ports

(0) The scalar input signal.

Output ports

(0) The broadcasted output signal.

Parameters:

Name Type Description Default
m

The number of rows in the output matrix. If m is None, then the output will be a vector with shape (n,). To get a row vector of size (1,n), set m=1 explicitly.

required
n

The number of columns in the output matrix. If n is None, then the output will be a vector with shape (m,). To get a column vector of size (m,1), set n=1 explicitly.

required
Source code in collimator/library/primitives.py
3418
3419
3420
3421
3422
3423
3424
3425
3426
3427
3428
3429
3430
3431
3432
3433
3434
3435
3436
3437
3438
3439
3440
3441
3442
3443
3444
3445
3446
3447
3448
3449
3450
3451
3452
3453
3454
3455
3456
3457
3458
3459
3460
3461
3462
3463
3464
3465
class ScalarBroadcast(FeedthroughBlock):
    """Broadcast a scalar to a vector or matrix.

    For a scalar input `u` and dimensions `m` and `n`, the block outputs an
    array of shape `(m, n)` whose entries all equal `u`.

    Input ports:
        (0) The scalar input signal.

    Output ports:
        (0) The broadcasted output signal.

    Parameters:
        m:
            The number of rows in the output matrix. If `m` is None (or 0), the
            output is a vector of shape `(n,)`. To get a row vector of shape
            `(1, n)`, set `m=1` explicitly.
        n:
            The number of columns in the output matrix. If `n` is None (or 0), the
            output is a vector of shape `(m,)`. To get a column vector of shape
            `(m, 1)`, set `n=1` explicitly.

    At least one of `m` and `n` must be a positive integer.
    """

    @parameters(static=["m", "n"])
    def __init__(self, m, n, **kwargs):
        # The broadcasting op depends on m and n, so it is installed in
        # initialize() rather than here.
        super().__init__(None, **kwargs)

    def initialize(self, m, n):
        rows = 0 if m is None else int(m)
        cols = 0 if n is None else int(n)

        if rows > 0 and cols > 0:
            shape = (rows, cols)
        elif rows > 0:
            shape = (rows,)
        elif cols > 0:
            shape = (cols,)
        else:
            raise BlockParameterError(
                message=f"ScalarBroadcast block {self.name} at least m or n must not be None or Zero"
            )

        template = cnp.ones(shape)
        self.replace_op(lambda x: template * x)

SignalDatatypeConversion

Bases: FeedthroughBlock

Convert the input signal to a different data type.

Input ports: (0) The input signal.

Output ports: (0) The input signal converted to the specified data type.

Parameters: convert_to_type: The data type to which the input signal is converted. Must be a valid NumPy data type, e.g. "float32", "int64", etc.

Source code in collimator/library/primitives.py
3882
3883
3884
3885
3886
3887
3888
3889
3890
3891
3892
3893
3894
3895
3896
3897
3898
3899
3900
3901
3902
3903
3904
3905
3906
3907
3908
3909
3910
3911
3912
3913
3914
class SignalDatatypeConversion(FeedthroughBlock):
    """Convert the input signal to a different data type.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The input signal converted to the specified data type.

    Parameters:
        convert_to_type:
            The data type to which the input signal is converted.  Must be a valid
            NumPy data type, e.g. "float32", "int64", etc.
    """

    def _op(self, dtype, x):
        """Cast `x` to `dtype`.

        Array inputs are converted with `cnp.astype`; anything else goes
        through `cnp.array(x, dtype)`.

        Raises:
            ValueError: on the numpy backend, if `x` is a list or tuple.
        """
        # This check makes the numpy backend strict like jax
        if cnp.active_backend == "numpy" and isinstance(x, (list, tuple)):
            raise ValueError(
                "SignalDatatypeConversion block does not support list or tuple inputs."
            )

        return cond(
            isinstance(x, cnp.ndarray),
            lambda x: cnp.astype(x, dtype),
            lambda x: cnp.array(x, dtype),
            x,
        )

    @parameters(static=["convert_to_type"])
    def __init__(self, convert_to_type, *args, **kwargs):
        super().__init__(partial(self._op, np.dtype(convert_to_type)), *args, **kwargs)

    def initialize(self, convert_to_type):
        # Resolve the dtype once and reuse it for both the attribute and the op.
        self.dtype = np.dtype(convert_to_type)
        self.replace_op(partial(self._op, self.dtype))

Sindy

Bases: LeafSystem

This class implements the SINDy (Sparse Identification of Nonlinear Dynamics) system identification algorithm with or without control inputs for continuous-time and discrete-time systems.

The learned continuous-time dynamical system model will be of the form:

dx/dt = f(x, u)

where x is the state vector and u is the optional control input vector. The block will output the full state vector x of the system.

The learned discrete-time dynamical system model will be of the form:

x_{k+1} = f(x_k, u_k)

where x_k is the state vector at time step k and u_k is the optional control vector. The block will update the output to x_k at an interval provided by the parameter discrete_time_update_interval.

Input ports

(0) u: control vector for the system. This port is only available if the Sindy model is trained with control inputs, i.e. control_input_columns is not None during training.

Output ports

(0) x: full state output of the system.

Parameters:

Name Type Description Default
file_name str

Path to the CSV file containing training data.

None
header_as_first_row bool

If True, the first row of the CSV file is treated as the header.

False
state_columns int | str | list[int] | list[str]

For training, either one of the following for CSV columns representing state variables x: - a string or integer (for a single column) - a list of strings or integers (for multiple columns) - a string representing a slice of columns, e.g. '0:3'

1
control_input_columns int | str | list[int] | list[str]

For training, either one of the following for CSV columns representing control inputs u: - a string or integer (for a single column) - a list of strings or integers (for multiple columns) - a string representing a slice of columns, e.g. '0:3' If None, then the SINDy model will be trained without control inputs.

None
dt float

Fixed value of dt if rows of the CSV file represent equidistant time steps.

None
time_column (str, int)

Column name (str) or column index (int) for time data t. If time_column is provided, then fixed dt above will be ignored. If neither dt nor time_column is provided, then the SINDy model will use a fixed default time step of dt=1.

None
state_derivatives_columns int | str | list[int] | list[str]

For training, either one of the following for csv columns representing state derivatives x_dot: - a string or integer (for a single column) - a list of strings or integers (for multiple columns) - a string representing a slice of columns, e.g. '0:3' This field is optional. If provided, the SINDy model will directly use these state derivatives for training. If not provided, the SINDy model will approximate the state derivatives dot_x = dx/dt from x by using the specified differentiation_method.

None
discrete_time bool

If True, the SINDy model will be trained for discrete-time systems. In this case, the dynamical system is treated as a map. Rather than predicting derivatives, the right hand side functions step the system forward by one time step. If False, dynamical system is assumed to be a flow (right-hand side functions predict continuous time derivatives). See documentation for pysindy.

False
differentiation_method str

Method to use for differentiating the state data x to obtain state derivatives dot_x = dx/dt. Available options are: 'centered difference' (default)

'centered difference'
threshold float

Threshold for the Sequentially thresholded least squares (STLSQ) algorithm used for training SINDy model.

0.1
alpha float

Regularization strength for the STLSQ algorithm.

0.05
max_iter int

Maximum number of iterations for the STLSQ algorithm.

20
normalize_columns bool

If True, normalize the columns of the data matrix before regression.

False
poly_order int

Degree of polynomial features. Set to None to omit this library.

2
fourier_n_frequencies int

Number of Fourier frequencies. Set to None to omit this library.

None
custom_basis_functions list of functions

A list of custom basis functions to use for training the SINDy model. For example, to include f(x) = 1/x and g(x) = exp(-x), provide [lambda x: 1.0/(x + 1e-06), lambda x: jnp.exp(-x)]

Currently only supported for the pycollimator interface. Calls from the UI and pretrained model loading do not support custom basis functions.

None
pretrained bool

If True, use a pretrained model specified by the pretrained_file_path argument.

False
pretrained_file_path str

Path to the pretrained model file.

None
initial_state ndarray
Initial state of the system for propagating the continuous-time
or discrete-time system forward during simulation.
None
discrete_time_update_interval float

Interval at which the discrete-time model should be updated. Default is 1.0.

1.0
equations list of strings

(For internal UI use only) The identified system equations.

None
base_feature_names list of strings

(For internal UI use only) Features x_i and u_i.

None
feature_names list of strings

(For internal UI use only) Composed features with basis libraries.

None
coefficients ndarray

(For internal UI use only) Coefficients of the identified model.

None
has_control_input bool

(For internal UI use only) If True, the model was trained with control. For standard training from CSV file, this is inferred from the parameter control_input_columns.

True
Source code in collimator/library/sindy.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
class Sindy(LeafSystem):
    """
    This class implements the SINDy (Sparse Identification of Nonlinear Dynamics)
    system identification algorithm with or without control inputs for
    continuous-time and discrete-time systems.

    The learned continuous-time dynamical system model will be of the form:

    ```
    dx/dt = f(x, u)
    ```
    where `x` is the state vector and `u` is the optional control input vector. The
    block will output the full state vector `x` of the system.

    The learned discrete-time dynamical system model will be of the form:

    ```
    x_{k+1} = f(x_k, u_k)
    ```

    where `x_k` is the state vector at time step `k` and `u_k` is the optional control
    vector. The block will update the output to `x_k` at an interval provided by the
    parameter `discrete_time_update_interval`.

    Input ports:
        (0) u: control vector for the system. This port is only available if the Sindy
            model is trained with control inputs, i.e. `control_input_columns` is not
            `None` during training.

    Output ports:
        (0) x: full state output of the system.

    Parameters:
        file_name (str):
            Path to the CSV file containing training data.

        header_as_first_row (bool):
            If True, the first row of the CSV file is treated as the header.

        state_columns (int | str | list[int] | list[str]):
            For training, either one of the following for CSV columns representing
            state variables `x`:
                - a string or integer (for a single column)
                - a list of strings or integers (for multiple columns)
                - a string representing a slice of columns, e.g. '0:3'


        control_input_columns (int | str | list[int] | list[str]):
            For training, either one of the following for CSV columns representing
            control inputs `u`:
                - a string or integer (for a single column)
                - a list of strings or integers (for multiple columns)
                - a string representing a slice of columns, e.g. '0:3'
            If None, then the SINDy model will be trained without control inputs.

        dt (float):
            Fixed value of dt if rows of the CSV file represent equidistant time steps.

        time_column (str, int):
            Column name (str) or column index (int) for time data `t`.
            If `time_column` is provided, then fixed `dt` above will be ignored.
            If neither `dt` nor `time_column` is provided, then the SINDy model will
            use a fixed default time step of `dt=1`.

        state_derivatives_columns (int | str | list[int] | list[str]):
            For training, either one of the following for csv columns representing
            state derivatives `x_dot`:
                - a string or integer (for a single column)
                - a list of strings or integers (for multiple columns)
                - a string representing a slice of columns, e.g. '0:3'
            This field is optional. If provided, the SINDy model will directly use
            these state derivatives for training. If not provided, the SINDy model
            will approximate the state derivatives `dot_x = dx/dt` from `x` by using
            the specified `differentiation_method`.

        discrete_time (bool):
            If True, the SINDy model will be trained for discrete-time systems. In
            this case, the dynamical system is treated as a map. Rather than
            predicting derivatives, the right hand side functions step the system
            forward by one time step. If False, dynamical system is assumed to be a
            flow (right-hand side functions predict continuous time derivatives).
            See documentation for `pysindy`. This flag is not stored by
            `serialize`, so it must also be set appropriately when loading a
            pretrained model.

        differentiation_method (str):
            Method to use for differentiating the state data `x` to obtain state
            derivatives `dot_x = dx/dt`. Available options are:
                'centered difference' (default)

        threshold (float):
            Threshold for the Sequentially thresholded least squares (STLSQ) algorithm
            used for training SINDy model.

        alpha (float):
            Regularization strength for the STLSQ algorithm.

        max_iter (int):
            Maximum number of iterations for the STLSQ algorithm.

        normalize_columns (bool):
            If True, normalize the columns of the data matrix before regression.

        poly_order (int):
            Degree of polynomial features. Set to `None` to omit this library.

        fourier_n_frequencies (int):
            Number of Fourier frequencies. Set to `None` to omit this library.

        custom_basis_functions (list of functions):
            A list of custom basis functions to use for training the SINDy model.
            For example, to include `f(x) = 1/x` and `g(x) = exp(-x)`,
            provide [lambda x: 1.0/(x + 1e-06), lambda x: jnp.exp(-x)]

            Currently only supported for the pycollimator interface. Calls from the
            UI and pretrained model loading do not support custom basis functions.

        pretrained (bool):
            If True, use a pretrained model specified by the `pretrained_file_path`
            argument.

        pretrained_file_path (str, optional): Path to the pretrained model file.

        initial_state (ndarray):
            Initial state of the system for propagating the continuous-time
            or discrete-time system forward during simulation. Defaults to zeros.

        discrete_time_update_interval (float):
            Interval at which the discrete-time model should be updated. Default
            is 1.0.

        equations (list of strings):
            (For internal UI use only) The identified system equations.

        base_feature_names (list of strings):
            (For internal UI use only) Features x_i and u_i.

        feature_names (list of strings):
            (For internal UI use only) Composed features with basis libraries.

        coefficients (ndarray):
            (For internal UI use only) Coefficients of the identified model.

        has_control_input (bool):
            (For internal UI use only) If True, the model was trained with control.
            For standard training from CSV file, this is inferred from the
            parameter `control_input_columns`.
    """

    @parameters(
        static=[
            "file_name",
            "header_as_first_row",
            "state_columns",
            "control_input_columns",
            "discrete_time",
            "dt",
            "time_column",
            "state_derivatives_columns",
            "differentiation_method",
            "threshold",
            "alpha",
            "max_iter",
            "normalize_columns",
            "poly_order",
            "fourier_n_frequencies",
            "pretrained",
            "equations",
            "discrete_time_update_interval",
            "pretrained_file_path",
            "coefficients",
            "base_feature_names",
            "feature_names",
            "has_control_input",
            "initial_state",
        ],
    )
    def __init__(
        self,
        file_name=None,
        header_as_first_row=False,
        state_columns=1,
        control_input_columns=None,
        dt=None,
        time_column=None,
        state_derivatives_columns=None,
        discrete_time=False,
        differentiation_method="centered difference",
        # optimizer parameters
        threshold=0.1,
        alpha=0.05,
        max_iter=20,
        normalize_columns=False,
        # Library parameters
        poly_order=2,
        fourier_n_frequencies=None,
        custom_basis_functions=None,
        pretrained=False,
        pretrained_file_path=None,
        # for parameters obtained from UI training
        equations=None,
        base_feature_names=None,
        feature_names=None,
        coefficients=None,
        has_control_input=True,
        # Simulation parameters
        initial_state=None,
        discrete_time_update_interval=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        _validate_leafsystem_inputs(
            pretrained,
            pretrained_file_path,
            dt,
            time_column,
            poly_order,
            fourier_n_frequencies,
        )

        ui_is_providing_pretrained_data = _validate_ui_pretrained_data(
            coefficients, feature_names, base_feature_names, self.name
        )

        # The model comes from one of three sources: data supplied by the UI,
        # a JSON file written by `serialize`, or training on a CSV file.
        if ui_is_providing_pretrained_data:
            self.equations = equations
            self.base_feature_names = base_feature_names.tolist()
            self.feature_names = feature_names.tolist()
            self.coefficients = cnp.array(coefficients)
            self.has_control_input = has_control_input
            self.custom_basis_functions = None

        elif pretrained:
            # JSON is UTF-8 by specification; don't depend on the locale encoding.
            with open(pretrained_file_path, "r", encoding="utf-8") as f:
                deserialized_model = json.load(f)

            self.equations = deserialized_model["equations"]
            self.base_feature_names = deserialized_model["base_feature_names"]
            self.feature_names = deserialized_model["feature_names"]
            self.coefficients = cnp.array(deserialized_model["coefficients"])
            self.has_control_input = deserialized_model["has_control_input"]
            self.custom_basis_functions = None

        else:
            (
                self.equations,
                self.base_feature_names,
                self.feature_names,
                self.coefficients,
                self.has_control_input,
            ) = train_from_csv(
                file_name,
                header_as_first_row=header_as_first_row,
                state_columns=state_columns,
                control_input_columns=control_input_columns,
                dt=dt,
                time_column=time_column,
                state_derivatives_columns=state_derivatives_columns,
                discrete_time=discrete_time,
                differentiation_method=differentiation_method,
                threshold=threshold,
                alpha=alpha,
                max_iter=max_iter,
                normalize_columns=normalize_columns,
                poly_order=poly_order,
                custom_basis_functions=custom_basis_functions,
                fourier_n_frequencies=fourier_n_frequencies,
            )
            self.custom_basis_functions = custom_basis_functions

        # One row of coefficients per state variable.
        self.nx, _ = self.coefficients.shape

        if cnp.all(self.coefficients == 0):
            warnings.warn(
                "No features were selected for the SINDy model. "
                "Please check the training data and the feature selection "
                "parameters."
            )

        if initial_state is not None:
            if len(initial_state) != self.nx:
                raise ValueError(
                    f"Provided initial state has {len(initial_state)} elements. "
                    f"Expected {self.nx} elements."
                )
        else:
            initial_state = cnp.zeros(self.nx)

        if self.has_control_input:
            self.declare_input_port()  # one vector valued input port for u

        if discrete_time:
            self.declare_discrete_state(
                shape=(self.nx,),
                default_value=initial_state,
                as_array=True,
            )
            # Without a control port there are no inputs, so pass an empty
            # tuple in place of `u`.
            self.declare_periodic_update(
                (
                    self._discrete_update
                    if self.has_control_input
                    else lambda time, state, **params: self._discrete_update(
                        time, state, (), **params
                    )
                ),
                period=discrete_time_update_interval,
                offset=0.0,
            )
            self.declare_output_port(
                self._full_discrete_state_output,
                period=discrete_time_update_interval,
                offset=0.0,
                default_value=initial_state,
                requires_inputs=False,
            )

        else:
            self.declare_continuous_state(
                ode=(
                    self._ode
                    if self.has_control_input
                    else lambda time, state, **params: self._ode(
                        time, state, (), **params
                    )
                ),
                shape=(self.nx,),
                default_value=cnp.array(initial_state),
            )
            self.declare_continuous_state_output()  # output of the state in ODE

        # SymPy parsing to compute $f(x,u)$
        # For continuous-time systems $\dot{x} = f(x,u)$
        # For discrete-time systems $x_{k+1} = f(x_k, u_k)$
        sympy_base_features = sp.symbols(self.base_feature_names)

        # Convert feature names to sympy expressions
        sympy_feature_expressions = []
        for name in self.feature_names:
            # Replace spaces with multiplication
            name = name.replace(" ", "*")
            expr = sp.sympify(name)
            sympy_feature_expressions.append(expr)

        x_and_u_vec = sp.Matrix(sympy_base_features)
        # Custom basis functions are referenced as f0, f1, ... in feature names.
        custom_functions_dict = (
            {f"f{idx}": func for idx, func in enumerate(self.custom_basis_functions)}
            if self.custom_basis_functions
            else None
        )
        self.features_func = sp.lambdify(
            (x_and_u_vec,),
            sympy_feature_expressions,
            modules=[custom_functions_dict, "jax"] if custom_functions_dict else "jax",
        )  # feature functions

    def _ode(self, _time, state, inputs, **_params):
        """
        The ODE system RHS. The RHS is given by `coefficients @ features`
        """
        x = state.continuous_state
        u = inputs
        x_and_u = cnp.hstack([x, u])
        features_evaluated = self.features_func(x_and_u)
        x_dot = cnp.matmul(self.coefficients, cnp.atleast_1d(features_evaluated))
        return x_dot

    def _discrete_update(self, _time, state, inputs, **_params):
        """
        Update map is given by `coefficients @ features`
        """
        x = state.discrete_state
        u = inputs
        x_and_u = cnp.hstack([x, u])
        features_evaluated = self.features_func(x_and_u)
        x_plus = cnp.matmul(self.coefficients, cnp.atleast_1d(features_evaluated))
        return x_plus

    def _full_discrete_state_output(self, _time, state, *_inputs, **_params):
        return state.discrete_state

    def serialize(self, filename):
        """
        Save the relevant class attributes post training
        so that model state can be restored
        """
        sindy_data = {
            "equations": self.equations,
            "base_feature_names": self.base_feature_names,
            "feature_names": self.feature_names,
            "coefficients": self.coefficients.tolist(),  # Can't serialize numpy arrays
            "has_control_input": self.has_control_input,
        }
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(sindy_data, f)

    @staticmethod
    def serialize_trained_pysindy_model(model, filename):
        """
        Serialize a PySindy model trained outside of Collimator.
        The saved file can be used as a pretrained model in Collimator.

        Args:
            model: A fitted `pysindy.SINDy` model.
            filename: Path of the JSON file to write.
        """

        feature_names, coefficients = _reduce(
            model.feature_names, model.get_feature_names(), model.coefficients()
        )

        # NOTE(review): confirm `n_input_features_` is the attribute that
        # distinguishes control inputs for the supported pysindy version.
        has_control_input = model.model.n_input_features_ > 0

        # `pysindy.SINDy.equations` is a method; storing the bound method itself
        # would make `json.dump` fail. A plain list is accepted as well.
        equations = model.equations
        if callable(equations):
            equations = equations()

        sindy_data = {
            "equations": equations,
            "base_feature_names": model.feature_names,
            "feature_names": feature_names,
            "coefficients": coefficients.tolist(),  # Can't serialize numpy arrays
            "has_control_input": has_control_input,
        }

        with open(filename, "w", encoding="utf-8") as f:
            json.dump(sindy_data, f)

serialize(filename)

Save the relevant class attributes post training so that model state can be restored

Source code in collimator/library/sindy.py
621
622
623
624
625
626
627
628
629
630
631
632
633
634
def serialize(self, filename):
    """
    Write the trained model's attributes to `filename` as JSON so that the
    model can later be restored as a pretrained model.
    """
    # JSON cannot encode arrays directly, so coefficients become nested lists.
    payload = dict(
        equations=self.equations,
        base_feature_names=self.base_feature_names,
        feature_names=self.feature_names,
        coefficients=self.coefficients.tolist(),
        has_control_input=self.has_control_input,
    )
    with open(filename, "w") as out_file:
        json.dump(payload, out_file)

serialize_trained_pysindy_model(model, filename) staticmethod

Serialize a PySindy model trained outside of Collimator. The saved file can be used as a pretrained model in Collimator.

Source code in collimator/library/sindy.py
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
@staticmethod
def serialize_trained_pysindy_model(model, filename):
    """
    Serialize a PySindy model trained outside of Collimator.
    The saved file can be used as a pretrained model in Collimator.

    Args:
        model: A fitted `pysindy.SINDy` model.
        filename: Path of the JSON file to write.
    """

    feature_names, coefficients = _reduce(
        model.feature_names, model.get_feature_names(), model.coefficients()
    )

    # NOTE(review): confirm `n_input_features_` is the attribute that
    # distinguishes control inputs for the supported pysindy version.
    has_control_input = model.model.n_input_features_ > 0

    # `pysindy.SINDy.equations` is a method; storing the bound method itself
    # would make `json.dump` fail. A plain list is accepted as well.
    equations = model.equations
    if callable(equations):
        equations = equations()

    sindy_data = {
        "equations": equations,
        "base_feature_names": model.feature_names,
        "feature_names": feature_names,
        "coefficients": coefficients.tolist(),  # Can't serialize numpy arrays
        "has_control_input": has_control_input,
    }

    with open(filename, "w", encoding="utf-8") as f:
        json.dump(sindy_data, f)

Sine

Bases: SourceBlock

Generates a sinusoidal signal.

Given amplitude a, frequency f, phase phi, and bias b, the output signal is:

    y(t) = a * sin(f * t + phi) + b
Input ports

None

Output ports

(0) The sinusoidal signal.

Parameters:

Name Type Description Default
amplitude

The amplitude of the sinusoidal signal.

1.0
frequency

The frequency of the sinusoidal signal.

1.0
phase

The phase of the sinusoidal signal.

0.0
bias

The bias of the sinusoidal signal.

0.0
Source code in collimator/library/primitives.py
3468
3469
3470
3471
3472
3473
3474
3475
3476
3477
3478
3479
3480
3481
3482
3483
3484
3485
3486
3487
3488
3489
3490
3491
3492
3493
3494
3495
3496
3497
3498
3499
3500
3501
3502
3503
3504
3505
class Sine(SourceBlock):
    """Generates a sinusoidal signal.

    With amplitude `a`, frequency `f`, phase `phi`, and bias `b`, the output is:
    ```
        y(t) = a * sin(f * t + phi) + b
    ```
    Because `f` multiplies `t` directly inside `sin`, it acts as an angular
    frequency (radians per unit time) rather than cycles per unit time.

    Input ports:
        None

    Output ports:
        (0) The sinusoidal signal.

    Parameters:
        amplitude:
            The amplitude of the sinusoidal signal.
        frequency:
            The angular frequency of the sinusoidal signal (radians per unit time).
        phase:
            The phase offset of the sinusoidal signal, in radians.
        bias:
            The constant offset added to the sinusoidal signal.
    """

    @parameters(dynamic=["amplitude", "frequency", "phase", "bias"])
    def __init__(self, amplitude=1.0, frequency=1.0, phase=0.0, bias=0.0, **kwargs):
        super().__init__(self._eval, **kwargs)

    def initialize(self, amplitude=1.0, frequency=1.0, phase=0.0, bias=0.0):
        # Every parameter is dynamic and read at evaluation time in `_eval`,
        # so there is nothing to precompute.
        return None

    def _eval(self, t, **parameters):
        amplitude = parameters["amplitude"]
        angle = parameters["frequency"] * t + parameters["phase"]
        return amplitude * cnp.sin(angle) + parameters["bias"]

Slice

Bases: FeedthroughBlock

Slice the input signal using Python indexing rules.

Input ports

(0) The input signal.

Output ports

(0) The sliced output signal.

Parameters:

Name Type Description Default
slice_

The slice operator to apply to the input signal. Must be specified as a string input, e.g. the output u[1:3] would be created with the block Slice("1:3").

required
Notes

Currently only up to 3-dimensional slices are supported.

Source code in collimator/library/primitives.py
3508
3509
3510
3511
3512
3513
3514
3515
3516
3517
3518
3519
3520
3521
3522
3523
3524
3525
3526
3527
3528
3529
3530
3531
3532
3533
3534
3535
3536
3537
3538
3539
3540
3541
3542
3543
3544
3545
3546
3547
3548
3549
3550
3551
3552
3553
3554
3555
3556
class Slice(FeedthroughBlock):
    """Slice the input signal using Python indexing rules.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The sliced output signal.

    Parameters:
        slice_:
            The slice operator to apply to the input signal.  Must be specified as a
            string input, e.g. the output `u[1:3]` would be created with the block
            `Slice("1:3")`. The surrounding `[]` and an `np.s_` prefix are
            optional. Each comma-separated entry is either an integer index or a
            `start:stop:step` slice whose parts may be omitted; negative integers
            are allowed, as in Python.

    Notes:
        Currently only up to 3-dimensional slices are supported.
    """

    # One integer index: an optional minus sign followed by ASCII digits.
    _INDEX_PATTERN = re.compile(r"-?[0-9]+")

    @parameters(static=["slice_"])
    def __init__(self, slice_, *args, **kwargs):
        super().__init__(None, *args, **kwargs)

    @classmethod
    def _parse_component(cls, component):
        """Parse one comma-separated entry into an `int` or a `slice`.

        Raises:
            ValueError: if the entry is empty, has more than two colons, or
                contains anything other than (optionally negative) integers.
        """
        parts = [part.strip() for part in component.split(":")]
        if len(parts) > 3:
            raise ValueError(f"too many ':' in {component!r}")

        bounds = []
        for part in parts:
            if not part:
                bounds.append(None)
            elif cls._INDEX_PATTERN.fullmatch(part):
                bounds.append(int(part))
            else:
                raise ValueError(f"invalid index {part!r}")

        if len(bounds) == 1:
            # A bare integer index. An empty entry (e.g. in "1,,2") is a
            # syntax error in Python indexing, so reject it here too.
            if bounds[0] is None:
                raise ValueError("empty index")
            return bounds[0]
        return slice(*bounds)

    @classmethod
    def _parse_slice(cls, text):
        """Parse bracket-free slice text into an index object, like `np.s_[text]`.

        A single entry yields an `int` or `slice`; any comma (including a single
        trailing one) yields a tuple.
        """
        components = text.split(",")
        is_tuple = len(components) > 1
        if is_tuple and not components[-1].strip():
            # Python allows one trailing comma, e.g. `x[1,]` means `x[(1,)]`.
            components = components[:-1]
        indices = tuple(cls._parse_component(c) for c in components)
        return indices if is_tuple else indices[0]

    def initialize(self, slice_):
        spec = str(slice_).strip()
        # if slice was provided as numpy slice object, remove this before parsing.
        if spec.startswith("np.s_"):
            spec = spec[len("np.s_") :]
        # The surrounding [] are optional; remove them (startswith/endswith
        # also handle the empty string safely).
        if spec.startswith("["):
            spec = spec[1:]
        if spec.endswith("]"):
            spec = spec[:-1]

        # Parse the spec directly instead of eval-ing it: no code is executed,
        # and malformed specs raise a parameter error instead of SyntaxError.
        try:
            np_slice = self._parse_slice(spec)
        except ValueError as err:
            raise BlockParameterError(
                message=f"Slice block {self.name} detected invalid slice operator {spec}. [] are optional. Valid examples: '1:3,4', '[:,4:10]'",
                parameter_name="slice_",
            ) from err

        def _func(inp):
            return cnp.array(inp)[np_slice]

        self.replace_op(_func)

SourceBlock

Bases: LeafSystem

Simple blocks with a single time-dependent output

Source code in collimator/library/generic.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class SourceBlock(LeafSystem):
    """Simple blocks with a single time-dependent output"""

    def __init__(self, func: Callable, **kwargs):
        """Create a source block whose single output depends on time.

        Args:
            func (Callable):
                A function of time and parameters that returns a single value.
                Signature should be `func(time, **parameters) -> Array`.
        """
        super().__init__(**kwargs)
        # Declare the port without a callback first; `replace_op` attaches it.
        port_index = self.declare_output_port(
            None,
            name="out_0",
            prerequisites_of_calc=[DependencyTicket.time],
            requires_inputs=False,
        )
        self._output_port_idx = port_index
        self.replace_op(func)

    def replace_op(self, func):
        """Make the output port compute `func(time, **parameters)`."""

        def _evaluate_at_time(time, state, *inputs, **parameters):
            # State and inputs are ignored: only time and parameters matter.
            return func(time, **parameters)

        self.configure_output_port(
            self._output_port_idx,
            _evaluate_at_time,
            prerequisites_of_calc=[DependencyTicket.time],
            requires_inputs=False,
        )

__init__(func, **kwargs)

Create a source block with a time-dependent output.

Parameters:

Name Type Description Default
func Callable

A function of time and parameters that returns a single value. Signature should be func(time, **parameters) -> Array.

required
Source code in collimator/library/generic.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def __init__(self, func: Callable, **kwargs):
    """Build a source block whose single output depends only on time.

    Args:
        func (Callable):
            A function of time and parameters that returns a single value.
            Signature should be `func(time, **parameters) -> Array`.
    """
    super().__init__(**kwargs)
    time_dependency = [DependencyTicket.time]
    # Declare the port without a callback; `replace_op` attaches `func`.
    self._output_port_idx = self.declare_output_port(
        None,
        name="out_0",
        requires_inputs=False,
        prerequisites_of_calc=time_dependency,
    )
    self.replace_op(func)

SquareRoot

Bases: FeedthroughBlock

Compute the square root of the input signal.

Dispatches to jax.numpy.sqrt, so see the JAX docs for details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.sqrt.html

Input ports

(0) The input signal.

Output ports

(0) The square root of the input signal.

Source code in collimator/library/primitives.py
3559
3560
3561
3562
3563
3564
3565
3566
3567
3568
3569
3570
3571
3572
3573
class SquareRoot(FeedthroughBlock):
    """Element-wise square root of the input signal.

    Computed with `cnp.sqrt`, which dispatches to `jax.numpy.sqrt`; see the
    JAX docs for details:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.sqrt.html

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The square root of the input signal.
    """

    def __init__(self, *args, **kwargs):
        sqrt_op = cnp.sqrt
        super().__init__(sqrt_op, *args, **kwargs)

Stack

Bases: ReduceBlock

Stack the input signals into a single output signal along a new axis.

Dispatches to jax.numpy.stack, so see the JAX docs for details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.stack.html

Input ports

(0..n_in-1) The input signals.

Output ports

(0) The stacked output signal.

Parameters:

Name Type Description Default
axis

The axis along which the input signals are stacked. Default is 0.

0
Source code in collimator/library/primitives.py
3576
3577
3578
3579
3580
3581
3582
3583
3584
3585
3586
3587
3588
3589
3590
3591
3592
3593
3594
3595
3596
3597
3598
class Stack(ReduceBlock):
    """Stack the input signals into a single output signal along a new axis.

    Dispatches to `jax.numpy.stack`, so see the JAX docs for details:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.stack.html

    Input ports:
        (0..n_in-1) The input signals.

    Output ports:
        (0) The stacked output signal.

    Parameters:
        axis:
            The axis along which the input signals are stacked.  Default is 0.
    """

    @parameters(static=["axis"])
    def __init__(self, n_in, axis=0, **kwargs):
        # No op yet: the stacking function depends on `axis` and is installed
        # via `replace_op` in `initialize`.
        super().__init__(n_in, None, **kwargs)

    def initialize(self, axis):
        stack_on_axis = partial(cnp.stack, axis=int(axis))
        self.replace_op(stack_on_axis)

StateMachine

Bases: LeafSystem

Finite State Machine similar to Mealy Machine. https://en.wikipedia.org/wiki/Mealy_machine

The state machine can be executed either periodically or by zero_crossings.

Each state has 0 or more exit transitions. These are prioritized such that when 2 exits are simultaneously valid, the higher priority is executed. It is not allowed for a state to have more than one exit transition with no guard. Guardless exits only make sense in the periodic case.

Each transition may have 0 or more actions. Each action is a python statement that modifies the value of an output. When a transition is executed (i.e. its guard evaluates to true), its actions are then processed.

If 'time' is needed for guards or actions, pass 'time' in from a Clock block.

Whether executed periodically or by zero_crossings, the states are constant between transition executions. In the zero_crossing case, all guards for transitions exiting the current state are continuously checked, and if any 'triggers', then the earliest point in time that any guard becomes true is determined, the actions of the earliest (and highest priority if multiple trigger simultaneously) guard are executed at that time, and the simulation continues afterwards.

Input ports

User specified.

Output ports

User specified.

Parameters:

Name Type Description Default
dt

Either Float or None. The period of the periodic update used when time_mode="discrete" (required in that mode). In the default time_mode="agnostic", the transitions are monitored by zero_crossing events and dt is not used.

None
accelerate_with_jax bool

Bool. When True, the actions and guards are JIT-compiled with JAX. Default is False.

False
Source code in collimator/library/state_machine.py
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
@parameters(static=["accelerate_with_jax"])
class StateMachine(LeafSystem):
    """Finite State Machine similar to Mealy Machine.
    https://en.wikipedia.org/wiki/Mealy_machine

    The state machine can be executed either periodically or by zero_crossings.

    Each state has 0 or more exit transitions. These are prioritized such that
    when 2 exits are simultaneously valid, the higher priority is executed.
    It is not allowed for a state to have more than one exit transition with no
    guard. Guardless exits only make sense in the periodic case.

    Each transition may have 0 or more actions. Each action is a python
    statement that modifies the value of an output. When a transition is executed
    (i.e. its guard evaluates to true), its actions are then processed.

    If 'time' is needed for guards or actions, pass 'time' in from a Clock block.

    Whether executed periodically or by zero_crossings, the states are constant
    between transition executions.
    In the zero_crossing case, all guards for transitions exiting the current
    state are continuously checked, and if any 'triggers', then the earliest point
    in time that any guard becomes true is determined, the actions of the earliest
    (and highest priority if multiple trigger simultaneously) guard are executed at
    that time, and the simulation continues afterwards.

    Input ports:
        User specified.

    Output ports:
        User specified.

    Parameters:
        dt:
            Either Float or None. Period of the periodic update used when
            `time_mode="discrete"` (required in that mode). In the default
            `time_mode="agnostic"` the transitions are monitored by
            zero_crossing events and `dt` is not used.
        time_mode:
            Either "discrete" (periodic execution with period `dt`) or
            "agnostic" (zero-crossing driven execution). Default is "agnostic".
        accelerate_with_jax:
            Bool. When True, the actions and guards are JIT-compiled with JAX.
            Default is False.
    """

    def __init__(
        self,
        sm_data: StateMachineData,
        inputs: List[str] = None,  # [name]
        outputs: List[str] = None,  # [name]
        dt=None,
        time_mode: str = "agnostic",
        name: str = None,
        ui_id: str = None,
        accelerate_with_jax: bool = False,
        **kwargs,
    ):
        super().__init__(name=name, ui_id=ui_id)

        if time_mode not in ["discrete", "agnostic"]:
            raise BlockInitializationError(
                f"Invalid time mode '{time_mode}' for StateMachine block", system=self
            )

        if time_mode == "discrete" and dt is None:
            raise BlockInitializationError(
                "When in discrete time mode, dt is required for block", system=self
            )

        if cnp.active_backend == "numpy" and accelerate_with_jax:
            raise BlockInitializationError(
                "Must use JAX numerical backend when accelerate_with_jax=True",
                system=self,
            )

        try:
            sm_data = _validate_sm_data(sm_data)
        except ValueError as e:
            raise StaticError(message=str(e), system=self) from e

        self._accelerate_with_jax = accelerate_with_jax

        if accelerate_with_jax:
            # inputs to many jax functions are expected to be jnp.arrays of
            # same shape, so we pad the arrays to the same shapes.
            self._sm = sm_data.to_padded_arrays()

            # Lookup tables indexed as [state_idx][transition_idx].
            self._guards = cnp.array(
                [
                    [t.guard_id for t in self._sm.states[idx].transitions]
                    for idx in self._sm.states.keys()
                ]
            )

            self._dst = cnp.array(
                [
                    [t.dst for t in self._sm.states[idx].transitions]
                    for idx in self._sm.states.keys()
                ]
            )

            self._actions = cnp.array(
                [
                    [t.action_ids for t in self._sm.states[idx].transitions]
                    for idx in self._sm.states.keys()
                ]
            )
        else:
            self._sm = sm_data

        self.time_mode = time_mode
        _is_periodic = time_mode == "discrete"

        if inputs is None:
            inputs = []
        if outputs is None:
            outputs = []
        elif isinstance(outputs, dict):
            outputs = list(outputs.keys())

        # declare inputs
        self._input_names = inputs
        for name in inputs:
            self.declare_input_port(name)

        self._output_names = outputs

        # Create the default discrete state values
        self._create_discrete_state_type(include_state_idx=_is_periodic)
        default_values = self._create_initial_discrete_state(
            include_state_idx=_is_periodic
        )
        self.declare_discrete_state(default_value=default_values, as_array=False)

        # Declare output ports for each state variable
        def _make_output_callback(o_port_name):
            def _output(time, state, *inputs, **parameters):
                return getattr(state.discrete_state, o_port_name)

            return _output

        for o_port_name in outputs:
            self.declare_output_port(
                _make_output_callback(o_port_name),
                name=o_port_name,
                prerequisites_of_calc=[DependencyTicket.xd],
                requires_inputs=False,
            )

        if _is_periodic:
            # declare the periodic update event
            self.declare_periodic_update(
                self._discrete_update,
                period=dt,
                offset=dt,
            )
        else:
            # wrap the callback generation so that they do not get overwritten
            # in subsequent calls to declare_zero_crossing()
            def _make_guard_callback(t):
                def _guard(_time, state, *inputs, **parameters):
                    # Inputs are in order of port declaration, so they match `self._input_names`
                    inputs = dict(zip(self._input_names, inputs))
                    # get the values of the outputs as they are presently.
                    outputs = state.discrete_state._asdict()
                    # we do this so that when a guard goes False-True,
                    # it creates a zero-crossing that can be localized in time.
                    g = cnp.where(
                        self._sm.registry.guards[t.guard_id](**inputs, **outputs),
                        1.0,
                        -1.0,
                    )
                    return g

                return _guard

            def _make_reset_callback(t):
                def _reset(_time, state, *inputs, **p):
                    # Inputs are in order of port declaration, so they match `self._input_names`
                    inputs = dict(zip(self._input_names, inputs))
                    # get the values of the outputs as they are presently.
                    outputs = state.discrete_state._asdict()
                    if self._accelerate_with_jax:
                        updated_outputs = self._exec_actions_jax(
                            t.action_ids, inputs, outputs
                        )
                    else:
                        updated_outputs = self._exec_actions(
                            t.action_ids, inputs, outputs
                        )
                    return state.with_discrete_state(
                        value=self.DiscreteStateType(**updated_outputs)
                    )

                return _reset

            # declare zero-crossing driven events and mode
            self.declare_default_mode(self._sm.initial_state)
            self.declare_mode_output()
            for st_idx, st in self._sm.states.items():
                for t in st.transitions:
                    self.declare_zero_crossing(
                        guard=_make_guard_callback(t),
                        reset_map=_make_reset_callback(t),
                        direction="negative_then_non_negative",  # we only care when the guard transitions False->True
                        start_mode=st_idx,
                        end_mode=t.dst,
                    )

    def _create_discrete_state_type(self, include_state_idx=True):
        """Define `self.DiscreteStateType`, a namedtuple of the output names.

        When `include_state_idx` is True (periodic mode) the active state index
        is stored as an extra leading field named `active_state_index`.
        """
        if include_state_idx:
            # unique identifier for the state machine state variable.
            # FIXME: is there a better way? was not allowed to use leading underscore
            # in data class.
            st_name = "active_state_index"

            if st_name in self._input_names or st_name in self._output_names:
                msg = f"StateMachine {self.name} has port with same name as state {st_name}, this is not allowed."
                raise StaticError(message=msg, system=self)

            self._st_name = st_name

            attribs = [st_name] + self._output_names
        else:
            attribs = self._output_names
        # declare the discrete_state as a namedtuple
        self.DiscreteStateType = namedtuple("DiscreteStateType", attribs)

    def _create_initial_discrete_state(self, include_state_idx=True):
        """Run the entry-point actions and build the initial discrete state.

        Raises:
            BlockInitializationError: if an initial output is NaN or if any
                output is not assigned by the entry-point actions.
        """
        # execute the entry point actions
        inputs = {n: None for n in self._input_names}  # FIXME: get the inputs
        outputs = {n: None for n in self._output_names}

        initial_outputs = self._exec_actions(self._sm.initial_actions, inputs, outputs)

        # check if any initial_outputs is NaN
        for v in initial_outputs.values():
            if np.any(np.isnan(v)):
                msg = (
                    "StateMachine has NaN values in the initial outputs. "
                    "Inputs can't be used in initial actions."
                )
                raise BlockInitializationError(message=msg, system=self)

        # enforce that all outputs have been initialized
        initialized_output_names = set(initial_outputs.keys())
        all_output_names = set(self._output_names)
        uninitialized_output_names = all_output_names.difference(
            initialized_output_names
        )
        if uninitialized_output_names:
            msg = f"StateMachine does not initialize the following output values in the entry point actions: {uninitialized_output_names}"
            raise BlockInitializationError(message=msg, system=self)

        # get and save the output dtype,shape for use in creating the jax.pure_callback
        self.output_port_params = {
            o_port_name: {"dtype": jnp.array(val).dtype, "shape": jnp.array(val).shape}
            for o_port_name, val in initial_outputs.items()
        }

        # prepare the initial state
        if include_state_idx:
            return self.DiscreteStateType(
                active_state_index=self._sm.initial_state,
                **initial_outputs,
            )

        return self.DiscreteStateType(**initial_outputs)

    def _filter_locals(self, local_env):
        # remove any bindings from locals that are not outputs.
        filtered_locals = {}
        for key, value in local_env.items():
            if key in self._output_names:
                filtered_locals[key] = value
        return filtered_locals

    def _exec_actions(self, action_ids, inputs, outputs):
        # execute actions, in context with inputs values, when done
        # all actions, filter out any variable bindings that do
        # not correspond to outputs, then repack as dict of jnp.arrays
        # The argument lists do not change between actions, so build them once.
        # NOTE(review): every action therefore sees the pre-transition outputs,
        # not values assigned by earlier actions in the same list — confirm intended.
        input_args = [inputs[k] for k in self._input_names]
        output_args = [outputs[k] for k in self._output_names]
        updated_outputs = {}
        for action_id in action_ids:
            if action_id == -1:  # padded actions are -1
                continue
            output = self._sm.registry.actions[action_id](*input_args, *output_args)
            updated_outputs.update(output)

        updated_outputs = self._filter_locals(updated_outputs)
        updated_outputs = {k: jnp.array(v) for k, v in updated_outputs.items()}
        return updated_outputs

    def _exec_actions_jax(self, action_ids, inputs, outputs):
        """Execute actions in a JAX-compatible way."""

        def _exec_action(action_id):
            input_args = [inputs[k] for k in self._input_names]
            output_args = [outputs[k] for k in self._output_names]
            return cnp.cond(
                action_id == -1,  # padded actions are -1
                lambda: ({k: v for k, v in outputs.items()}, True),
                lambda: (
                    cnp.switch(
                        action_id,
                        self._sm.registry.actions,
                        *input_args,
                        *output_args,
                    ),
                    False,
                ),
            )

        # TODO: cnp.vmap (implement numpy version)
        action_outputs = jax.vmap(_exec_action)(action_ids)

        def _accumulate_outputs(carry, outputs):
            output, is_pad = outputs
            update = cnp.cond(
                is_pad, lambda: carry, lambda: {k: v for k, v in output.items()}
            )
            carry.update(update)
            return carry, carry

        init = {**outputs}
        updated_outputs, _ = cnp.scan(_accumulate_outputs, init, action_outputs)

        updated_outputs = self._filter_locals(updated_outputs)

        return updated_outputs

    def _numpy_callback(self, present_state_index, inputs, outputs):
        """
        The concept here is to evaluate all possible exit transitions from
        the active state, and then just return the updated (state,output values)
        for the successful transition. In the case no transitions are successful,
        we just return the present state and present outputs. Since we have ordered
        the possible transitions in order of priority, executing the lowest index
        successful transition is 'correct' behavior.

        jax-compatible version of this function is `_jax_callback`.
        """
        # get the active state index, and the possible exit transitions
        present_state_index = int(present_state_index)
        actv_trns = self._sm.states[present_state_index].transitions

        # evaluate the guard for each possible exit transition.
        evaluated_guards = [
            self._sm.registry.guards[transition.guard_id](**inputs, **outputs)
            for transition in actv_trns
        ]

        if np.any(evaluated_guards):
            # Pick the first (highest-priority) guard that is truthy. Unlike
            # `list.index(True)`, this also works for truthy non-bool results.
            first_idx = next(i for i, g in enumerate(evaluated_guards) if g)
            actv_trn = actv_trns[first_idx]
            new_state = actv_trn.dst
            updated_outputs = self._exec_actions(actv_trn.action_ids, inputs, outputs)
            new_outputs = []
            for k in self._output_names:
                output = updated_outputs[k] if k in updated_outputs else outputs[k]
                new_outputs.append(np.array(output))
            retval = [np.array(new_state), new_outputs]
        else:
            outputs = [np.array(outputs[k]) for k in self._output_names]
            retval = [np.array(present_state_index), outputs]

        return retval

    def _jax_callback(self, present_state_index, inputs, outputs):
        """
        The concept here is to evaluate all possible exit transitions from
        the active state, and then just return the updated (state,output values)
        for the successful transition. In the case no transitions are successful,
        we just return the present state and present_outputs. Since we have ordered
        the possible transitions in order of priority, executing the lowest index
        successful transition is 'correct' behavior.
        """

        active_guards = _choose(present_state_index, self._guards)

        input_args = [inputs[k] for k in self._input_names]
        output_args = [outputs[k] for k in self._output_names]
        # evaluate the guard for each possible exit transition.
        evaluated_guards = cnp.array(
            [
                cnp.switch(
                    guard_id,
                    self._sm.registry.guards,
                    *input_args,
                    *output_args,
                )
                for guard_id in active_guards
            ]
        )

        def on_true():
            # Find the first active transition where the guard is True
            active_dst = _choose(present_state_index, self._dst)
            active_actions = _choose(present_state_index, self._actions)

            if np.size(evaluated_guards) == 0:
                # no guards are True, so we return the present state and outputs
                # note that adding jnp.size(evaluated_guards) > 0 to the cnp.cond
                # still evaluates the true branch, so we need to check the size
                # here.
                new_outputs = [jnp.array(outputs[k]) for k in self._output_names]
                return cnp.array(present_state_index), new_outputs

            # argmax returns the first occurrence of the maximum, i.e. the lowest
            # index (highest priority) transition whose guard is True.
            idx = cnp.argmax(evaluated_guards)

            new_state = _choose(idx, active_dst)
            action_ids = _choose(idx, active_actions)

            updated_outputs = self._exec_actions_jax(action_ids, inputs, outputs)
            new_outputs = []
            for k in self._output_names:
                output = updated_outputs[k] if k in updated_outputs else outputs[k]
                new_outputs.append(cnp.array(output))
            return new_state.squeeze(), new_outputs

        def on_false():
            new_outputs = [jnp.array(outputs[k]) for k in self._output_names]
            return cnp.array(present_state_index), new_outputs

        return cnp.cond(
            cnp.any(evaluated_guards),
            on_true,
            on_false,
        )

    def _discrete_update(self, _time, state: LeafState, *inputs, **params):
        """Periodic update: evaluate the active state's exit transitions once.

        Returns a new `DiscreteStateType` with the (possibly unchanged) active
        state index and output values.
        """
        # present state index
        actv_state = state.discrete_state.active_state_index

        # Inputs are in order of port declaration, so they match `self._input_names`
        inputs = dict(zip(self._input_names, inputs))

        # get the values of the outputs as they are presently.
        outputs = {
            key: value
            for key, value in state.discrete_state._asdict().items()
            if key not in {self._st_name}
        }

        if self._accelerate_with_jax:
            new_state, new_outputs = self._jax_callback(actv_state, inputs, outputs)
        else:
            # build jax.pure_callback result_shape_dtypes
            # its a nested list like: [actv_st, [outp0, outp1, ... outpN]]
            result_shape_dtypes = [jax.ShapeDtypeStruct((), jnp.int64)]  # actv_state
            result_shape_dtypes_outps = []
            for var in self._output_names:
                port = self.output_port_params[var]
                result_shape_dtypes_outps.append(
                    jax.ShapeDtypeStruct(port["shape"], np.dtype(port["dtype"]))
                )
            result_shape_dtypes.append(result_shape_dtypes_outps)

            if cnp.active_backend == "numpy":
                new_state, new_outputs = self._numpy_callback(
                    actv_state,
                    inputs,
                    outputs,
                )
            else:
                # TODO: implement jax.custom_jvp to raise useful error when trying
                # to differentiate when accelerate_with_jax is False
                new_state, new_outputs = jax.pure_callback(
                    self._numpy_callback,
                    result_shape_dtypes,
                    actv_state,
                    inputs,
                    outputs,
                )

        outputs_dict = {k: v for k, v in zip(self._output_names, new_outputs)}

        return self.DiscreteStateType(
            active_state_index=new_state,
            **outputs_dict,
        )

Step

Bases: SourceBlock

A step signal.

Given start value y0, end value y1, and step time t0, the output signal is:

    y(t) = y0 if t < t0 else y1
Input ports

None

Output ports

(0) The step signal.

Parameters:

Name Type Description Default
start_value

The value of the output signal before the step time.

0.0
end_value

The value of the output signal after the step time.

1.0
step_time

The time at which the step occurs.

1.0
Source code in collimator/library/primitives.py
3601
3602
3603
3604
3605
3606
3607
3608
3609
3610
3611
3612
3613
3614
3615
3616
3617
3618
3619
3620
3621
3622
3623
3624
3625
3626
3627
3628
3629
3630
3631
3632
3633
3634
3635
3636
3637
3638
3639
3640
3641
3642
3643
3644
3645
3646
3647
class Step(SourceBlock):
    """A step signal.

    Given start value `y0`, end value `y1`, and step time `t0`, the
    output signal is:
    ```
        y(t) = y0 if t < t0 else y1
    ```

    Input ports:
        None

    Output ports:
        (0) The step signal.

    Parameters:
        start_value:
            The value of the output signal before the step time.
        end_value:
            The value of the output signal after the step time.
        step_time:
            The time at which the step occurs.
    """

    @parameters(dynamic=["start_value", "end_value"], static=["step_time"])
    def __init__(self, start_value=0.0, end_value=1.0, step_time=1.0, **kwargs):
        super().__init__(self._func, **kwargs)
        # Reserve a periodic-update slot; its timing is configured in `initialize`.
        self._periodic_update_idx = self.declare_periodic_update()

    def initialize(self, start_value, end_value, step_time):
        # Schedule a no-op event exactly at the step time (infinite period, so
        # it fires once) so the ODE solver does not integrate across the jump.
        self._step_time = step_time
        self.declare_discrete_state(default_value=False)

        def _noop_update(*args, **kwargs):
            return True

        self.configure_periodic_update(
            self._periodic_update_idx,
            _noop_update,
            period=np.inf,
            offset=step_time,
        )

    def _func(self, time, **parameters):
        stepped = time >= self._step_time
        after, before = parameters["end_value"], parameters["start_value"]
        return cnp.where(stepped, after, before)

Stop

Bases: LeafSystem

Stop the simulation early as soon as the input signal becomes True.

If the input signal changes as a result of a discrete update, the simulation will terminate the major step early (before advancing continuous time).

Input ports

(0): the boolean- or binary-valued termination signal

Output ports

None

Source code in collimator/library/primitives.py
3650
3651
3652
3653
3654
3655
3656
3657
3658
3659
3660
3661
3662
3663
3664
3665
3666
3667
3668
3669
3670
3671
3672
3673
3674
3675
class Stop(LeafSystem):
    """Stop the simulation early as soon as the input signal becomes True.

    If the input signal changes as a result of a discrete update, the simulation
    will terminate the major step early (before advancing continuous time).

    Input ports:
        (0): the boolean- or binary-valued termination signal

    Output ports:
        None
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.declare_input_port()

        # A terminal zero-crossing that fires when the guard goes from
        # negative (input False) to non-negative (input True).
        self.declare_zero_crossing(
            terminal=True,
            guard=self._guard,
            direction="negative_then_non_negative",
        )

    def _guard(self, time, state, u, **p):
        # Encode the boolean input as +1 / -1 so a False->True edge is a sign change.
        when_true, when_false = 1.0, -1.0
        return cnp.where(u, when_true, when_false)

SumOfElements

Bases: FeedthroughBlock

Compute the sum of the elements of the input signal.

Dispatches to jax.numpy.sum, so see the JAX docs for details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.sum.html

Input ports

(0) The input signal.

Output ports

(0) The sum of the elements of the input signal.

Source code in collimator/library/primitives.py
3678
3679
3680
3681
3682
3683
3684
3685
3686
3687
3688
3689
3690
3691
3692
class SumOfElements(FeedthroughBlock):
    """Sum all of the elements of the input signal.

    The reduction is delegated to `jax.numpy.sum`; refer to the JAX docs for
    details: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.sum.html

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The sum of the elements of the input signal.
    """

    def __init__(self, *args, **kwargs):
        # The feedthrough operation is the full reduction over all elements.
        reduce_op = cnp.sum
        super().__init__(reduce_op, *args, **kwargs)

TensorFlow

Bases: LeafSystem

Block to perform inference with a pre-trained TensorFlow SavedModel.

The input to the block should be of compatible type and shape expected by the TensorFlow model. For example, if the TensorFlow SavedModel model expects a tf.float32 tensor of shape (3, 224, 224), the input to the block should be a jax.numpy array of shape (3, 224, 224) of dtype jnp.float32.

For output types, if no casting is specified through the cast_outputs_to_dtype parameter, the output of the block will have the same dtype as the TensorFlow model output, but expressed as jax.numpy types. For example, if the TensorFlow model outputs a tf.float32 tensor, the output of the block will be a jax.numpy array of dtype jnp.float32.

If casting is specified through the cast_outputs_to_dtype parameter, all the outputs of the block will be cast to this specific jax.numpy dtype.

Input ports

(i) The ith input to the model.

Output ports

(j) The jth output of the model.

Parameters:

Name Type Description Default
file_name str

Path to the model file. This should be a .zip containing the SavedModel.

required
cast_outputs_to_dtype str

The dtype to cast all the outputs of the block to. Must correspond to a jax.numpy datatype. For example, "float32", "float64", "int32", "int64".

None
add_batch_dim_to_inputs bool

Whether to add a new first dimension to the inputs before evaluating the TorchScript or TensorFlow model. This is useful when the model expects a batch dimension.

False
Source code in collimator/library/predictor.py
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
class TensorFlow(LeafSystem):
    """
    Block to perform inference with a pre-trained TensorFlow SavedModel.

    The input to the block should be of compatible type and shape expected by
    the TensorFlow model. For example, if the TensorFlow SavedModel expects a
    `tf.float32` tensor of shape `(3, 224, 224)`, the input to the block should be a
    `jax.numpy` array of shape (3, 224, 224) of dtype `jnp.float32`.

    For output types, if no casting is specified through the `cast_outputs_to_dtype`
    parameter, the output of the block will have the same dtype as the
    TensorFlow model output, but expressed as `jax.numpy` types. For example, if the
    TensorFlow model outputs a `tf.float32` tensor, the output of the block will be
    a `jax.numpy` array of dtype `jnp.float32`.

    If casting is specified through the `cast_outputs_to_dtype` parameter, all the
    outputs of the block will be cast to this specific `jax.numpy` dtype.

    Input ports:
        (i) The ith input to the model.

    Output ports:
        (j) The jth output of the model.

    Parameters:
        file_name (str):
            Path to the model file. This should be a `.zip` containing the SavedModel.

        cast_outputs_to_dtype (str):
            The dtype to cast all the outputs of the block to. Must correspond to a
            `jax.numpy` datatype. For example, "float32", "float64", "int32", "int64".

        add_batch_dim_to_inputs (bool):
            Whether to add a new first dimension to the inputs before evaluating the
            TensorFlow model. This is useful when the model expects a batch
            dimension.
    """

    @parameters(
        static=[
            "file_name",
            "cast_outputs_to_dtype",
            "add_batch_dim_to_inputs",
        ]
    )
    def __init__(
        self,
        file_name,
        cast_outputs_to_dtype=None,
        add_batch_dim_to_inputs=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Only the port counts are needed here; `initialize` loads the model
        # again and keeps the callable and its signature.
        _, num_inputs, num_outputs, _, _ = self._load_model(file_name)

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        for _ in range(self.num_inputs):
            self.declare_input_port()

        def _make_output_callback(output_index):
            # Factory so each port's callback binds its own `output_index`
            # (avoids the late-binding closure pitfall of a loop variable).
            def _output_callback(time, state, *inputs, **params):
                outputs = self._evaluate_output(time, state, *inputs, **params)
                return outputs[output_index]

            return _output_callback

        for output_index in range(self.num_outputs):
            self.declare_output_port(
                _make_output_callback(output_index),
                requires_inputs=True,
            )

    def _load_model(self, file_name):
        """Load the `serving_default` signature of a zipped SavedModel.

        Args:
            file_name: Path to a `.zip` archive containing the SavedModel.

        Returns:
            Tuple `(model, num_inputs, num_outputs, num_args, kwargs_signature)`,
            where `model` is the `serving_default` signature callable.

        Raises:
            ValueError: If `file_name` does not have a `.zip` extension.
        """
        _, ext = os.path.splitext(file_name)
        if ext != ".zip":
            raise ValueError(f"Expected extension of file is `.zip`, but found {ext}")

        with tempfile.TemporaryDirectory() as model_dir:
            with zipfile.ZipFile(file_name, "r") as zip_ref:
                zip_ref.extractall(model_dir)

            model = tf.saved_model.load(model_dir)

        model = model.signatures["serving_default"]

        num_args = len(model.structured_input_signature[0])
        kwargs_signature = model.structured_input_signature[1]
        num_kwargs = len(kwargs_signature)

        num_inputs = num_args + num_kwargs
        num_outputs = len(model.structured_outputs)

        return model, num_inputs, num_outputs, num_args, kwargs_signature

    def initialize(
        self,
        file_name,
        cast_outputs_to_dtype=None,
        add_batch_dim_to_inputs=False,
    ):
        """(Re)load the model and cache its signature and output casting options.

        Raises:
            ValueError: If the model's number of inputs or outputs differs from
                the ports declared in `__init__`.
        """
        self.dtype_output = (
            getattr(jnp, cast_outputs_to_dtype)
            if cast_outputs_to_dtype is not None
            else None
        )

        self.add_batch_dim_to_inputs = add_batch_dim_to_inputs

        model, num_inputs, num_outputs, num_args, kwargs_signature = self._load_model(
            file_name
        )

        # Ports were declared in __init__; their number cannot change now.
        if self.num_inputs != num_inputs:
            raise ValueError("num_inputs can't be changed after initialization")
        if self.num_outputs != num_outputs:
            raise ValueError("num_outputs can't be changed after initialization")

        self.model = model
        self.num_args = num_args
        self.kwargs_signature = kwargs_signature
        self.num_kwargs = len(self.kwargs_signature)

    def initialize_static_data(self, context):
        """Infer the output shapes and dtypes of the ML model."""
        # If building as part of a subsystem, this may not be fully connected yet.
        # That's fine, as long as it is connected by root context creation time.
        # This probably isn't a good long-term solution:
        #   see https://collimator.atlassian.net/browse/WC-51
        try:
            inputs = self.collect_inputs(context)

            outputs_jax = self._pure_callback(*inputs)

            self.pure_callback_result_type = [
                jax.ShapeDtypeStruct(x.shape, x.dtype) for x in outputs_jax
            ]
        except UpstreamEvalError:
            logger.debug(
                "Predictor.initialize_static_data: UpstreamEvalError. "
                "Continuing without default value initialization."
            )
        return super().initialize_static_data(context)

    def _evaluate_output(self, time, state, *inputs, **params):
        # Run the TensorFlow model outside of JAX tracing; the result structure
        # was recorded by `initialize_static_data`.
        return jax.pure_callback(
            self._pure_callback,
            self.pure_callback_result_type,
            *inputs,
        )

    def _input_specs(self):
        """Return `(kwarg_names, specs)` ordered to match the input ports.

        Ports `0..num_args-1` feed the positional arguments. The remaining ports
        feed the keyword arguments. The keyword arguments (and outputs) appear
        in reverse order in the model signature, so the keys are reversed for
        alignment. `specs[i]` is the signature entry for input port `i`.
        """
        kwarg_names = list(reversed(self.kwargs_signature.keys()))
        arg_specs = list(self.model.structured_input_signature[0])
        kwarg_specs = [self.kwargs_signature[name] for name in kwarg_names]
        return kwarg_names, arg_specs + kwarg_specs

    def _pure_callback(self, *inputs):
        """Evaluate the model on host values and return its outputs as jax arrays."""
        kwarg_names, input_specs = self._input_specs()

        # Cast every input to the dtype of the exact signature slot it feeds.
        # Each port's dtype and keyword come from the same ordering, so
        # multi-input models with mixed dtypes are handled correctly.
        inputs_casted = [
            tf.convert_to_tensor(np.array(item), dtype=spec.dtype)
            for item, spec in zip(inputs, input_specs)
        ]
        args_casted = inputs_casted[: self.num_args]
        kwargs_casted = dict(zip(kwarg_names, inputs_casted[self.num_args :]))

        if self.add_batch_dim_to_inputs:
            args_casted = [tf.expand_dims(x, axis=0) for x in args_casted]
            kwargs_casted = {
                key: tf.expand_dims(value, axis=0)
                for key, value in kwargs_casted.items()
            }

        # An empty `args_casted` unpacks to no positional arguments.
        outputs_dict = self.model(*args_casted, **kwargs_casted)

        # Outputs are reversed in the signature as well; `dtype=None` keeps
        # the model's own dtype.
        return [
            jnp.array(x, dtype=self.dtype_output)
            for x in reversed(outputs_dict.values())
        ]

initialize_static_data(context)

Infer the output shapes and dtypes of the ML model.

Source code in collimator/library/predictor.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
def initialize_static_data(self, context):
    """Infer the output shapes and dtypes of the ML model.

    Runs the model once on the currently connected inputs. The resulting
    shape/dtype structure is stored as `pure_callback_result_type`.
    """
    # During subsystem construction the block may not be fully connected yet;
    # that is acceptable as long as it is connected by root context creation.
    # This is likely not a good long-term solution:
    #   see https://collimator.atlassian.net/browse/WC-51
    try:
        model_outputs = self._pure_callback(*self.collect_inputs(context))
        self.pure_callback_result_type = [
            jax.ShapeDtypeStruct(out.shape, out.dtype) for out in model_outputs
        ]
    except UpstreamEvalError:
        logger.debug(
            "Predictor.initialize_static_data: UpstreamEvalError. "
            "Continuing without default value initialization."
        )
    return super().initialize_static_data(context)

TransferFunction

Bases: LTISystem

Continuous-time LTI system specified as a transfer function.

The transfer function is converted to state-space form using scipy.signal.tf2ss. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.tf2ss.html

The resulting system will be in canonical controller form with matrices (A, B, C, D), which are then used to create an LTISystem. Note that this only supports single-input, single-output systems.

Input ports

(0) u: Input vector (scalar)

Output ports

(0) y: Output vector (scalar). Note that this is feedthrough from the input port iff D is nonzero.

Parameters:

Name Type Description Default
num

Numerator polynomial coefficients, in descending powers of s

required
den

Denominator polynomial coefficients, in descending powers of s

required
Source code in collimator/library/linear_system.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
class TransferFunction(LTISystem):
    """Continuous-time LTI system specified as a transfer function.

    The transfer function is converted to state-space form using `scipy.signal.tf2ss`.
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.tf2ss.html

    The resulting system will be in canonical controller form with matrices
    (A, B, C, D), which are then used to create an LTISystem.  Note that this only
    supports single-input, single-output systems.

    Input ports:
        (0) u: Input vector (scalar)

    Output ports:
        (0) y: Output vector (scalar).  Note that this is feedthrough from the input
            port iff D is nonzero.

    Parameters:
        num: Numerator polynomial coefficients, in descending powers of s
        den: Denominator polynomial coefficients, in descending powers of s
    """

    # tf2ss is not implemented in jax.scipy.signal so num and den can't be
    # dynamic parameters.
    @parameters(static=["num", "den"])
    def __init__(self, num, den, *args, **kwargs):
        A, B, C, D = signal.tf2ss(num, den)
        self._num = num
        self._den = den
        super().__init__(A, B, C, D, *args, **kwargs)

    def _eval_output(self, time, state, *inputs, **params):
        # Output matrices are rebuilt from the cached coefficients.
        _, _, self.C, self.D = signal.tf2ss(self._num, self._den)
        return self._eval_output_base(self.C, self.D, state, *inputs)

    def ode(self, time, state, u, **params):
        # State-evolution matrices are rebuilt from the cached coefficients.
        self.A, self.B, _, _ = signal.tf2ss(self._num, self._den)
        return super().ode(time, state, u, A=self.A, B=self.B)

    def initialize(self, num, den, **kwargs):
        # Refresh the cached coefficients. `_eval_output` and `ode` rebuild
        # the matrices from `_num`/`_den`; without this update they would keep
        # using the coefficients passed to `__init__` after a parameter change.
        self._num = num
        self._den = den
        A, B, C, D = signal.tf2ss(num, den)
        self._init_state(A, B, C, D)

TransferFunctionDiscrete

Bases: LTISystemDiscrete

Implements a Discrete Time Transfer Function.

https://en.wikipedia.org/wiki/Z-transform#Transfer_function

The resulting system will be in canonical controller form with matrices (A, B, C, D), which are then used to create an LTISystem. Note that this only supports single-input, single-output systems.

Input ports

(0) u[k]: Input vector (scalar)

Output ports

(0) y[k]: Output vector (scalar). Note that this is feedthrough from the input port if and only if D is nonzero.

Parameters:

Name Type Description Default
dt

Sampling period of the discrete system.

required
num

Numerator polynomial coefficients, in descending powers of z

required
den

Denominator polynomial coefficients, in descending powers of z

required
initialize_states

Initial state vector (default: 0)

None
Source code in collimator/library/linear_system.py
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
class TransferFunctionDiscrete(LTISystemDiscrete):
    """Discrete-time transfer function block.

    https://en.wikipedia.org/wiki/Z-transform#Transfer_function

    The transfer function is converted to canonical controller form with
    matrices (A, B, C, D), which are then used to create an LTISystem.  Only
    single-input, single-output systems are supported.

    Input ports:
        (0) u[k]: Input vector (scalar)

    Output ports:
        (0) y[k]: Output vector (scalar). Note that this is feedthrough from the input
            port if and only if D is nonzero.

    Parameters:
        dt:
            Sampling period of the discrete system.
        num:
            Numerator polynomial coefficients, in descending powers of z
        den:
            Denominator polynomial coefficients, in descending powers of z
        initialize_states:
            Initial state vector (default: 0)
    """

    # tf2ss is not implemented in jax.scipy.signal so num and den can't be
    # dynamic parameters.
    @parameters(static=["num", "den"])
    def __init__(self, dt, num, den, initialize_states=None, *args, **kwargs):
        # tf2ss returns the (A, B, C, D) tuple in the order the base expects.
        state_space = signal.tf2ss(num, den)
        super().__init__(*state_space, dt, initialize_states, *args, **kwargs)

    def _eval_output(self, time, state, *inputs, **params):
        # Forward the full set of stored system matrices to the base output.
        matrices = dict(A=self.A, B=self.B, C=self.C, D=self.D)
        return super()._eval_output(time, state, *inputs, **matrices)

    def _update(self, time, state, u, **params):
        # The state update only needs the state-transition and input matrices.
        dynamics = dict(A=self.A, B=self.B)
        return super()._update(time, state, u, **dynamics)

    def initialize(self, num, den, **kwargs):
        self._init_state(*signal.tf2ss(num, den))

Trigonometric

Bases: FeedthroughBlock

Apply a trigonometric function to the input signal.

Available functions are

sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh

Dispatches to jax.numpy.sin, jax.numpy.cos, etc, so see the JAX docs for details.

Input ports

(0) The input signal.

Output ports

(0) The trigonometric function applied to the input signal.

Parameters:

Name Type Description Default
function

The trigonometric function to apply to the input signal. Must be one of "sin", "cos", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "asinh", "acosh", "atanh".

required
Source code in collimator/library/primitives.py
3695
3696
3697
3698
3699
3700
3701
3702
3703
3704
3705
3706
3707
3708
3709
3710
3711
3712
3713
3714
3715
3716
3717
3718
3719
3720
3721
3722
3723
3724
3725
3726
3727
3728
3729
3730
3731
3732
3733
3734
3735
3736
3737
3738
3739
3740
3741
class Trigonometric(FeedthroughBlock):
    """Apply a trigonometric function elementwise to the input signal.

    Available functions are:
        sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh

    Dispatches to `jax.numpy.sin`, `jax.numpy.cos`, etc, so see the JAX docs for details.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The trigonometric function applied to the input signal.

    Parameters:
        function:
            The trigonometric function to apply to the input signal.  Must be one of
            "sin", "cos", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh",
            "asinh", "acosh", "atanh".
    """

    @parameters(static=["function"])
    def __init__(self, function, **kwargs):
        # No operation yet; `initialize` installs the selected function.
        super().__init__(None, **kwargs)

    def initialize(self, function):
        # The table is built on every call so `cnp` attributes are resolved at
        # call time rather than once at import.
        ops = dict(
            sin=cnp.sin,
            cos=cnp.cos,
            tan=cnp.tan,
            asin=cnp.arcsin,
            acos=cnp.arccos,
            atan=cnp.arctan,
            sinh=cnp.sinh,
            cosh=cnp.cosh,
            tanh=cnp.tanh,
            asinh=cnp.arcsinh,
            acosh=cnp.arccosh,
            atanh=cnp.arctanh,
        )
        selected = ops.get(function)
        if selected is None:
            valid_options = ", ".join(ops)
            raise BlockParameterError(
                message=f"Trigonometric block {self.name} has invalid selection {function} for 'function'. Valid options: "
                + valid_options,
                parameter_name="function",
            )
        self.replace_op(selected)

UnitDelay

Bases: LeafSystem

Hold and delay the input signal by one time step.

This block implements a "unit delay" with the following difference equation for internal state x, input signal u, and output signal y:

    x[k+1] = u[k]
    y[k] = x[k]

Or, in a hybrid context, the discrete update advances the internal state from the "pre" or "minus" value x⁻ to the "post" or "plus" value x⁺ at time tₖ = t0 + k * dt. According to the discrete update rules, this calculation happens using the input values computed during the update step (i.e. by computing upstream outputs before evaluating the inputs to this block). That is, the update rule can be written x⁺(tₖ) = f(tₖ, x⁻(tₖ), u(tₖ)). The values of u are not distinguished as "pre" or "post" because there is only one value at the update time. In the difference equation notation, x⁺(tₖ) ≡ x[k+1], x⁻(tₖ) ≡ x[k], and u(tₖ) ≡ u[k]. The hybrid update rule is then:

    x⁺(tₖ) = u(tₖ)
    y(t) = x⁻(tₖ),       between tₖ⁺ and (tₖ+dt)⁻

The output signal "seen" by all other blocks on the time interval (tₖ, tₖ+dt) is then the value of the input signal u(tₖ) at the previous update. Therefore, all downstream discrete-time blocks updating at the same time tₖ will still see the value of x⁻(tₖ), the value of the internal state prior to the update.

Input ports

(0) The input signal.

Output ports

(0) The input signal delayed by one time step

Parameters:

Name Type Description Default
dt

The time step of the discrete update.

required
initial_state

The initial state of the block. Default is 0.0.

required
Source code in collimator/library/primitives.py
3744
3745
3746
3747
3748
3749
3750
3751
3752
3753
3754
3755
3756
3757
3758
3759
3760
3761
3762
3763
3764
3765
3766
3767
3768
3769
3770
3771
3772
3773
3774
3775
3776
3777
3778
3779
3780
3781
3782
3783
3784
3785
3786
3787
3788
3789
3790
3791
3792
3793
3794
3795
3796
3797
3798
3799
3800
3801
3802
3803
3804
3805
3806
3807
3808
3809
3810
3811
3812
3813
3814
3815
3816
3817
3818
3819
3820
3821
3822
3823
3824
3825
3826
3827
3828
3829
3830
3831
class UnitDelay(LeafSystem):
    """Hold and delay the input signal by one time step.

    This block implements a "unit delay" with the following difference equation
    for internal state `x`, input signal `u`, and output signal `y`:
    ```
        x[k+1] = u[k]
        y[k] = x[k]
    ```
    Or, in a hybrid context, the discrete update advances the internal state from
    the "pre" or "minus" value x⁻ to the "post" or "plus" value x⁺ at time
    `tₖ = t0 + k * dt`.  According to the discrete update rules, this calculation
    happens using the input values computed during the update step (i.e. by computing
    upstream outputs before evaluating the inputs to this block). That is, the update
    rule can be written `x⁺(tₖ) = f(tₖ, x⁻(tₖ), u(tₖ))`.  The values of `u` are not
    distinguished as "pre" or "post" because there is only one value at the update
    time.  In the difference equation notation, `x⁺(tₖ) ≡ x[k+1]`, `x⁻(tₖ) ≡ x[k]`,
    and `u(tₖ) ≡ u[k]`.  The hybrid update rule is then:
    ```
        x⁺(tₖ) = u(tₖ)
        y(t) = x⁻(tₖ),       between tₖ⁺ and (tₖ+dt)⁻
    ```

    The output signal "seen" by all other blocks on the time interval (tₖ, tₖ+dt)
    is then the value of the input signal u(tₖ) at the previous update. Therefore, all
    downstream discrete-time blocks updating at the same time tₖ will still see the
    value of x⁻(tₖ), the value of the internal state prior to the update.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The input signal delayed by one time step

    Parameters:
        dt:
            The time step of the discrete update.
        initial_state:
            The initial state of the block.  Default is 0.0.
    """

    @parameters(dynamic=["initial_state"])
    def __init__(self, dt, initial_state, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dt = dt
        # Only the slots are declared here; they are configured in `initialize`.
        # NOTE(review): declaration order presumably determines the port/update
        # indices, so keep it stable.
        self.declare_input_port()
        self._periodic_update_idx = self.declare_periodic_update()
        self._output_port_idx = self.declare_output_port()

    def initialize(self, initial_state):
        """Configure the periodic state update and the sampled output port."""
        # State update every `dt`; the first update is at t = dt (offset=dt).
        self.configure_periodic_update(
            self._periodic_update_idx, self._update, period=self.dt, offset=self.dt
        )

        # Output sampled every `dt` from t = 0. It is computed from the discrete
        # state only (requires_inputs=False, prerequisite is the xd ticket), so
        # there is no direct feedthrough from the input to the output.
        self.configure_output_port(
            self._output_port_idx,
            self._output,
            period=self.dt,
            offset=0.0,
            requires_inputs=False,
            prerequisites_of_calc=[DependencyTicket.xd],
            default_value=initial_state,
        )

    def reset_default_values(self, initial_state):
        """Use `initial_state` as both the discrete state and the output default."""
        self.declare_discrete_state(default_value=initial_state)
        self.configure_output_port_default_value(self._output_port_idx, initial_state)

    def _update(self, _time, _state, u, **_params):
        # Every dt seconds, update the state to the current input value
        return u

    def _output(self, _time, state, **parameters):
        """Return the held discrete state (the input latched at the last update)."""
        return state.discrete_state

    def check_types(
        self,
        context,
        error_collector: ErrorCollector = None,
    ):
        """Check that the input signal is type-compatible with the discrete state.

        Delegates to `check_state_type`; presumably problems are reported
        through `error_collector` when one is given (verify in that helper).
        """
        inp_data = self.eval_input(context)
        xd = context[self.system_id].discrete_state
        check_state_type(
            self,
            inp_data=inp_data,
            state_data=xd,
            error_collector=error_collector,
        )

UnscentedKalmanFilter

Bases: KalmanFilterBase

Unscented Kalman Filter (UKF) for the following system:

```
x[n+1] = f(x[n], u[n]) + G(t[n]) w[n]
y[n]   = g(x[n], u[n]) + v[n]

E(w[n]) = E(v[n]) = 0
E(w[n]w'[n]) = Q(t[n], x[n], u[n])
E(v[n]v'[n]) = R(t[n])
E(w[n]v'[n]) = N(t[n]) = 0
```

f and g are discrete-time functions of state x[n] and control u[n], while R and G are discrete-time functions of time t[n]. Q is a discrete-time function of t[n], x[n], u[n]. This last aspect is included for zero-order-hold discretization of a continuous-time system.

Input ports

(0) u[n] : control vector at timestep n
(1) y[n] : measurement vector at timestep n

Output ports

(1) x_hat[n] : state vector estimate at timestep n

Parameters:

Name Type Description Default
dt

float Time step of the discrete-time system

required
forward

Callable A function with signature f(x[n], u[n]) -> x[n+1] that represents f in the above equations.

required
observation

Callable A function with signature g(x[n], u[n]) -> y[n] that represents g in the above equations.

required
G_func

Callable A function with signature G(t[n]) -> G[n] that represents G in the above equations.

required
Q_func

Callable A function with signature Q(t[n], x[n], u[n]) -> Q[n] that represents Q in the above equations.

required
R_func

Callable A function with signature R(t[n]) -> R[n] that represents R in the above equations.

required
x_hat_0

ndarray Initial state estimate

required
P_hat_0

ndarray Initial state covariance matrix estimate

required
alpha

float Sigma point spread to control the amount of nonlinearities taken into account. Usually set to a value (1e-04<= alpha <= 1.0). Default is 1.0.

1.0
beta

float Scaling constant to include prior information about the distribution of the state. Default is 0.0.

0.0
kappa

float Relatively non-critical parameter to control the kurtosis of sigma point distribution. Default is 0.0.

0.0
Source code in collimator/library/state_estimators/unscented_kalman_filter.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
class UnscentedKalmanFilter(KalmanFilterBase):
    """
    Unscented Kalman Filter (UKF) for the following system:

        ```
        x[n+1] = f(x[n], u[n]) + G(t[n]) w[n]
        y[n]   = g(x[n], u[n]) + v[n]

        E(w[n]) = E(v[n]) = 0
        E(w[n]w'[n]) = Q(t[n], x[n], u[n])
        E(v[n]v'[n]) = R(t[n])
        E(w[n]v'[n]) = N(t[n]) = 0
        ```

    `f` and `g` are discrete-time functions of state `x[n]` and control `u[n]`,
    while `R` and `G` are discrete-time functions of time `t[n]`. `Q` is a
    discrete-time function of `t[n], x[n], u[n]`. This last aspect is included for
    zero-order-hold discretization of a continuous-time system.

    Input ports:
        (0) u[n] : control vector at timestep n
        (1) y[n] : measurement vector at timestep n

    Output ports:
        (1) x_hat[n] : state vector estimate at timestep n

    Parameters:
        dt: float
            Time step of the discrete-time system
        forward: Callable
            A function with signature f(x[n], u[n]) -> x[n+1] that represents `f` in
            the above equations.
        observation: Callable
            A function with signature g(x[n], u[n]) -> y[n] that represents `g` in
            the above equations.
        G_func: Callable
            A function with signature G(t[n]) -> G[n] that represents `G` in
            the above equations.
        Q_func: Callable
            A function with signature Q(t[n], x[n], u[n]) -> Q[n] that represents `Q`
            in the above equations.
        R_func: Callable
            A function with signature R(t[n]) -> R[n] that represents `R` in
            the above equations.
        x_hat_0: ndarray
            Initial state estimate
        P_hat_0: ndarray
            Initial state covariance matrix estimate
        alpha: float
            Sigma point spread to control the amount of nonlinearities taken into
            account. Usually set to a value (1e-04<= alpha <= 1.0). Default is 1.0.
        beta: float
            Scaling constant to include prior information about the distribution of
            the state. Default is 0.0.
        kappa: float
            Relatively non-critical parameter to control the kurtosis of sigma point
            distribution. Default is 0.0.
    """

    @parameters(
        static=[
            "dt",
            "forward",
            "observation",
            "G_func",
            "Q_func",
            "R_func",
            "x_hat_0",
            "P_hat_0",
            "alpha",
            "beta",
            "kappa",
        ],
    )
    def __init__(
        self,
        dt,
        forward,
        observation,
        G_func,
        Q_func,
        R_func,
        x_hat_0,
        P_hat_0,
        alpha=1.0,
        beta=0.0,
        kappa=0.0,
        is_feedthrough=True,  # TODO: determine automatically?
        name=None,
        **kwargs,
    ):
        super().__init__(dt, x_hat_0, P_hat_0, is_feedthrough, name, **kwargs)

    def initialize(
        self,
        dt,
        forward,
        observation,
        G_func,
        Q_func,
        R_func,
        x_hat_0,
        P_hat_0,
        alpha=1.0,
        beta=0.0,
        kappa=0.0,
    ):
        """Store the model callables and precompute the UKF sigma-point weights.

        Sets `nx` (state size, from `x_hat_0.size`), `ny` (measurement size, from
        the leading dimension of `R_func(0.0)`), the scaling parameter
        `lamb = alpha**2 * (nx + kappa) - nx`, and the mean/covariance weight
        vectors of length `2 * nx + 1`.
        """
        self.G_func = G_func
        self.Q_func = Q_func
        self.R_func = R_func

        self.nx = x_hat_0.size
        # The measurement dimension is inferred from the measurement noise covariance.
        self.ny = self.R_func(0.0).shape[0]

        self.alpha = alpha
        self.beta = beta
        self.kappa = kappa

        self.forward = forward
        self.observation = observation

        # Evaluate the user operators over the leading (sigma point) axis; the
        # control input `u` is shared by every sigma point.
        self.forward_sigma_points = jax.vmap(forward, in_axes=(0, None))
        self.observation_sigma_points = jax.vmap(observation, in_axes=(0, None))

        self.num_sigma_points = 2 * self.nx + 1
        self.lamb = (self.alpha**2.0) * (self.nx + self.kappa) - self.nx
        self.lamb_plus_nx = self.lamb + self.nx

        # All non-central sigma points share the same weight 1 / (2 (lamb + nx)).
        w_rest = 0.5 / self.lamb_plus_nx
        w0_mean = self.lamb / self.lamb_plus_nx

        self.weights_mean = jnp.full(self.num_sigma_points, w_rest)
        self.weights_mean = self.weights_mean.at[0].set(w0_mean)

        # The central covariance weight carries the extra (1 - alpha^2 + beta) term.
        self.weights_cov = jnp.full(self.num_sigma_points, w_rest)
        self.weights_cov = self.weights_cov.at[0].set(
            w0_mean + (1.0 - alpha**2 + beta)
        )

    def _gen_sigma_points(self, mean, cov):
        """Return the `2 * nx + 1` sigma points for (mean, cov), stacked row-wise.

        Row 0 is `mean`. Rows 1..nx are `mean` plus each row of the upper Cholesky
        factor of `(lamb + nx) * cov`. Rows nx+1..2nx are `mean` minus those rows.
        """
        # jsp.linalg.cholesky returns the upper-triangular U with U^T U = A by
        # default, so the rows of U are the columns of the lower factor L = U^T.
        chol_cov = jsp.linalg.cholesky(self.lamb_plus_nx * cov)

        # `mean` broadcasts across the rows of `chol_cov`
        # (assumes `mean` has shape (nx,) — TODO confirm against callers).
        sigma_points_plus = mean + chol_cov
        sigma_points_minus = mean - chol_cov

        sigma_points = jnp.vstack([mean, sigma_points_plus, sigma_points_minus])

        return sigma_points

    def _get_weighted_mean_and_cov_from_sigma_points(self, sigma_points):
        """Return the weighted mean and covariance of row-stacked sigma points."""
        mean = jnp.dot(self.weights_mean, sigma_points)
        delta_sigma_points = sigma_points - mean
        # (delta^T * w) scales column i of delta^T by w_i. This equals
        # delta^T @ diag(w) without materialising the dense diagonal matrix.
        cov = (delta_sigma_points.T * self.weights_cov) @ delta_sigma_points

        return mean, cov

    def _get_weighted_cross_covariance_from_sigma_points(
        self, sigma_points_x, sigma_points_y
    ):
        """Return the weighted cross-covariance Pxy between two sigma-point sets."""
        mean_x = jnp.dot(self.weights_mean, sigma_points_x)
        delta_sigma_points_x = sigma_points_x - mean_x

        mean_y = jnp.dot(self.weights_mean, sigma_points_y)
        delta_sigma_points_y = sigma_points_y - mean_y

        # Column scaling by the weights, equivalent to delta_x^T @ diag(w) @ delta_y.
        cov_xy = (delta_sigma_points_x.T * self.weights_cov) @ delta_sigma_points_y

        return cov_xy

    def _correct(self, time, x_hat_minus, P_hat_minus, *inputs):
        """Measurement update: fuse y[n] into the predicted estimate.

        Args:
            time: Current time t[n], passed to `R_func`.
            x_hat_minus: Predicted state estimate (n|n-1).
            P_hat_minus: Predicted state covariance (n|n-1).
            *inputs: `(u, y)` — control and measurement at step n.

        Returns:
            Tuple `(x_hat_plus, P_hat_plus)`, the corrected estimate (n|n).
        """
        u, y = inputs
        u = jnp.atleast_1d(u)
        y = jnp.atleast_1d(y)

        sigma_points_x_minus = self._gen_sigma_points(x_hat_minus, P_hat_minus).reshape(
            (self.num_sigma_points, self.nx)
        )

        sigma_points_y_minus = self.observation_sigma_points(
            sigma_points_x_minus, u
        ).reshape((self.num_sigma_points, self.ny))

        y_mean, Py = self._get_weighted_mean_and_cov_from_sigma_points(
            sigma_points_y_minus
        )

        Pxy = self._get_weighted_cross_covariance_from_sigma_points(
            sigma_points_x_minus,
            sigma_points_y_minus,
        )

        R = self.R_func(time)
        S = Py + R

        # K = Pxy S^{-1}. S is symmetric, so K^T = S^{-1} Pxy^T. Solving this
        # linear system avoids forming the explicit (less accurate) inverse of S.
        K = jnp.linalg.solve(S, Pxy.T).T

        x_hat_plus = x_hat_minus + jnp.dot(K, y - y_mean)  # n|n
        P_hat_plus = P_hat_minus - K @ S @ K.T  # n|n

        return x_hat_plus, P_hat_plus

    def _propagate(self, time, x_hat_plus, P_hat_plus, *inputs):
        """Time update: push the corrected estimate through `forward`.

        The x_hat_plus of the current step is propagated to become the x_hat_minus
        of the next step, i.e. n+1|n in this step is n|n-1 for the next step.

        Returns:
            Tuple `(x_hat_minus, P_hat_minus)` for step n+1.
        """
        u, _ = inputs  # the measurement is not used in the prediction step
        u = jnp.atleast_1d(u)

        G = self.G_func(time)
        Q = self.Q_func(time, x_hat_plus, u)
        GQGT = G @ Q @ G.T

        sigma_points_x_plus = self._gen_sigma_points(x_hat_plus, P_hat_plus).reshape(
            (self.num_sigma_points, self.nx)
        )

        sigma_points_x_minus = self.forward_sigma_points(
            sigma_points_x_plus, u
        ).reshape((self.num_sigma_points, self.nx))

        x_hat_minus, Px = self._get_weighted_mean_and_cov_from_sigma_points(
            sigma_points_x_minus
        )  # n+1|n

        P_hat_minus = Px + GQGT  # n+1|n

        return x_hat_minus, P_hat_minus

    #######################################
    # Make filter for a continuous plant  #
    #######################################

    @staticmethod
    @with_resolved_parameters
    def for_continuous_plant(
        plant,
        dt,
        G_func,
        Q_func,
        R_func,
        x_hat_0,
        P_hat_0,
        discretization_method="euler",
        discretized_noise=False,
        alpha=1.0,
        beta=0.0,
        kappa=0.0,
        name=None,
        ui_id=None,
    ):
        """
        Unscented Kalman Filter system for a continuous-time plant.

        The input plant contains the deterministic forms of the forward and observation
        operators:

        ```
            dx/dt = f(x,u)
            y = g(x,u)
        ```

        Note: (i) Only plants with one vector-valued input and one vector-valued output
        are currently supported. Furthermore, the plant LeafSystem/Diagram should have
        only one vector-valued integrator; (ii) the user may pass a plant with
        disturbances (not recommended) as the input plant. In this case, the forward
        and observation evaluations will be corrupted by noise.

        A plant with disturbances of the following form is then considered:

        ```
            dx/dt = f(x,u) + G(t) w         -- (C1)
            y = g(x,u) + v                  -- (C2)
        ```

        where:

            `w` represents the process noise,
            `v` represents the measurement noise,

        and

        ```
            E(w) = E(v) = 0
            E(ww') = Q(t)
            E(vv') = R(t)
            E(wv') = N(t) = 0
        ```

        This plant is discretized to obtain the following form:

        ```
            x[n+1] = fd(x[n], u[n]) + Gd w[n]  -- (D1)
            y[n]   = gd(x[n], u[n]) + v[n]     -- (D2)

            E(w[n]) = E(v[n]) = 0
            E(w[n]w'[n]) = Qd
            E(v[n]v'[n]) = Rd
            E(w[n]v'[n]) = Nd = 0
        ```

        The above discretization is performed either via the `euler` or the `zoh`
        method, and an Unscented Kalman Filter estimator for the system of equations
        (D1) and (D2) is returned.

        Note: If `discretized_noise` is True, then it is assumed that the user is
        directly providing Gd, Qd and Rd. If False, then Qd and Rd are computed from
        continuous-time Q, R, and G, and Gd is set to an Identity matrix.

        The returned system will have:

        Input ports:
            (0) u[n] : control vector at timestep n
            (1) y[n] : measurement vector at timestep n

        Output ports:
            (1) x_hat[n] : state vector estimate at timestep n

        Parameters:
            plant : a `Plant` object which can be a LeafSystem or a Diagram.
            dt: float
                Time step for the discretization.
            G_func: Callable
                A function with signature G(t) -> G that represents `G` in
                the continuous-time equations (C1) and (C2).
            Q_func: Callable
                A function with signature Q(t) -> Q that represents `Q` in
                the continuous-time equations (C1) and (C2).
            R_func: Callable
                A function with signature R(t) -> R that represents `R` in
                the continuous-time equations (C1) and (C2).
            x_hat_0: ndarray
                Initial state estimate
            P_hat_0: ndarray
                Initial state covariance matrix estimate. If `None`, an Identity
                matrix is assumed.
            discretization_method: str ("euler" or "zoh")
                Method to discretize the continuous-time plant. Default is "euler".
            discretized_noise: bool
                Whether the user is directly providing Gd, Qd and Rd. Default is False.
                If True, `G_func`, `Q_func`, and `R_func` provide Gd(t), Qd(t), and
                Rd(t), respectively.
            alpha: float
                Sigma point spread to control the amount of nonlinearities taken into
                account. Usually set to a value (1e-04<= alpha <= 1.0). Default is 1.0.
            beta: float
                Scaling constant to include prior information about the distribution of
                the state. Default is 0.0.
            kappa: float
                Relatively non-critical parameter to control the kurtosis of sigma
                point distribution. Default is 0.0.
        """

        (
            forward,
            observation,
            Gd_func,
            Qd_func,
            Rd_func,
        ) = prepare_continuous_plant_for_nonlinear_kalman_filter(
            plant,
            dt,
            G_func,
            Q_func,
            R_func,
            x_hat_0,
            discretization_method,
            discretized_noise,
        )

        nx = x_hat_0.size
        # P_hat_0 may be omitted (None); fall back to an identity prior covariance.
        if P_hat_0 is None:
            P_hat_0 = jnp.eye(nx)

        # TODO: If Gd_func is None, compute Gd automatically with u = u + w

        ukf = UnscentedKalmanFilter(
            dt,
            forward,
            observation,
            Gd_func,
            Qd_func,
            Rd_func,
            x_hat_0,
            P_hat_0,
            alpha=alpha,
            beta=beta,
            kappa=kappa,
            name=name,
            ui_id=ui_id,
        )

        return ukf

    ####################################################################################
    # Make filter from direct specification of forward/observation operators and noise #
    ####################################################################################

    @staticmethod
    @with_resolved_parameters
    def from_operators(
        dt,
        forward,
        observation,
        G_func,
        Q_func,
        R_func,
        x_hat_0,
        P_hat_0,
        alpha=1.0,
        beta=0.0,
        kappa=0.0,
        name=None,
        ui_id=None,
    ):
        """
        Unscented Kalman Filter (UKF) for the following system:

        ```
            x[n+1] = f(x[n], u[n]) + G(t[n]) w[n]
            y[n]   = g(x[n], u[n]) + v[n]

            E(w[n]) = E(v[n]) = 0
            E(w[n]w'[n]) = Q(t[n])
            E(v[n]v'[n]) = R(t[n])
            E(w[n]v'[n]) = N(t[n]) = 0
        ```

        `f` and `g` are discrete-time functions of state `x[n]` and control `u[n]`,
        while `Q` and `R` and `G` are discrete-time functions of time `t[n]`.

        Input ports:
            (0) u[n] : control vector at timestep n
            (1) y[n] : measurement vector at timestep n

        Output ports:
            (1) x_hat[n] : state vector estimate at timestep n

        Parameters:
            dt: float
                Time step of the discrete-time system
            forward: Callable
                A function with signature f(x[n], u[n]) -> x[n+1] that represents `f`
                in the above equations.
            observation: Callable
                A function with signature g(x[n], u[n]) -> y[n] that represents `g` in
                the above equations.
            G_func: Callable
                A function with signature G(t[n]) -> G[n] that represents `G` in
                the above equations.
            Q_func: Callable
                A function with signature Q(t[n]) -> Q[n] that represents
                `Q` in the above equations.
            R_func: Callable
                A function with signature R(t[n]) -> R[n] that represents `R` in
                the above equations.
            x_hat_0: ndarray
                Initial state estimate
            P_hat_0: ndarray
                Initial state covariance matrix estimate
            alpha: float
                Sigma point spread to control the amount of nonlinearities taken into
                account. Usually set to a value (1e-04<= alpha <= 1.0). Default is 1.0.
            beta: float
                Scaling constant to include prior information about the distribution of
                the state. Default is 0.0.
            kappa: float
                Relatively non-critical parameter to control the kurtosis of sigma
                point distribution. Default is 0.0.
        """

        # The filter calls Q_func(t, x, u); adapt the time-only user function to
        # that signature by ignoring the state and control arguments.
        def Q_func_aug(t, x_k, u_k):
            return Q_func(t)

        ukf = UnscentedKalmanFilter(
            dt,
            forward,
            observation,
            G_func,
            Q_func_aug,
            R_func,
            x_hat_0,
            P_hat_0,
            alpha=alpha,
            beta=beta,
            kappa=kappa,
            name=name,
            ui_id=ui_id,
        )

        return ukf

for_continuous_plant(plant, dt, G_func, Q_func, R_func, x_hat_0, P_hat_0, discretization_method='euler', discretized_noise=False, alpha=1.0, beta=0.0, kappa=0.0, name=None, ui_id=None) staticmethod

Unscented Kalman Filter system for a continuous-time plant.

The input plant contains the deterministic forms of the forward and observation operators:

    dx/dt = f(x,u)
    y = g(x,u)

Note: (i) Only plants with one vector-valued input and one vector-valued output are currently supported. Furthermore, the plant LeafSystem/Diagram should have only one vector-valued integrator; (ii) the user may pass a plant with disturbances (not recommended) as the input plant. In this case, the forward and observation evaluations will be corrupted by noise.

A plant with disturbances of the following form is then considered:

    dx/dt = f(x,u) + G(t) w         -- (C1)
    y = g(x,u) +  v                 -- (C2)

where:

`w` represents the process noise,
`v` represents the measurement noise,

and

    E(w) = E(v) = 0
    E(ww') = Q(t)
    E(vv') = R(t)
    E(wv') = N(t) = 0

This plant is discretized to obtain the following form:

    x[n+1] = fd(x[n], u[n]) + Gd w[n]  -- (D1)
    y[n]   = gd(x[n], u[n]) + v[n]     -- (D2)

    E(w[n]) = E(v[n]) = 0
    E(w[n]w'[n]) = Qd
    E(v[n]v'[n]) = Rd
    E(w[n]v'[n]) = Nd = 0

The above discretization is performed either via the euler or the zoh method, and an Unscented Kalman Filter estimator for the system of equations (D1) and (D2) is returned.

Note: If discretized_noise is True, then it is assumed that the user is directly providing Gd, Qd and Rd. If False, then Qd and Rd are computed from continuous-time Q, R, and G, and Gd is set to an Identity matrix.

The returned system will have:

Input ports

(0) u[n] : control vector at timestep n
(1) y[n] : measurement vector at timestep n

Output ports

(1) x_hat[n] : state vector estimate at timestep n

Parameters:

Name Type Description Default
plant

a Plant object which can be a LeafSystem or a Diagram.

required
dt

float Time step for the discretization.

required
G_func

Callable A function with signature G(t) -> G that represents G in the continuous-time equations (C1) and (C2).

required
Q_func

Callable A function with signature Q(t) -> Q that represents Q in the continuous-time equations (C1) and (C2).

required
R_func

Callable A function with signature R(t) -> R that represents R in the continuous-time equations (C1) and (C2).

required
x_hat_0

ndarray Initial state estimate

required
P_hat_0

ndarray Initial state covariance matrix estimate. If None, an Identity matrix is assumed.

required
discretization_method

str ("euler" or "zoh") Method to discretize the continuous-time plant. Default is "euler".

'euler'
discretized_noise

bool Whether the user is directly providing Gd, Qd and Rd. Default is False. If True, G_func, Q_func, and R_func provide Gd(t), Qd(t), and Rd(t), respectively.

False
alpha

float Sigma point spread to control the amount of nonlinearities taken into account. Usually set to a value (1e-04<= alpha <= 1.0). Default is 1.0.

1.0
beta

float Scaling constant to include prior information about the distribution of the state. Default is 0.0.

0.0
kappa

float Relatively non-critical parameter to control the kurtosis of sigma point distribution. Default is 0.0.

0.0
Source code in collimator/library/state_estimators/unscented_kalman_filter.py
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
@staticmethod
@with_resolved_parameters
def for_continuous_plant(
    plant,
    dt,
    G_func,
    Q_func,
    R_func,
    x_hat_0,
    P_hat_0,
    discretization_method="euler",
    discretized_noise=False,
    alpha=1.0,
    beta=0.0,
    kappa=0.0,
    name=None,
    ui_id=None,
):
    """
    Unscented Kalman Filter system for a continuous-time plant.

    The input plant contains the deterministic forms of the forward and observation
    operators:

    ```
        dx/dt = f(x,u)
        y = g(x,u)
    ```

    Note: (i) Only plants with one vector-valued input and one vector-valued output
    are currently supported. Furthermore, the plant LeafSystem/Diagram should have
    only one vector-valued integrator; (ii) the user may pass a plant with
    disturbances (not recommended) as the input plant. In this case, the forward
    and observation evaluations will be corrupted by noise.

    A plant with disturbances of the following form is then considered:

    ```
        dx/dt = f(x,u) + G(t) w         -- (C1)
        y = g(x,u) + v                  -- (C2)
    ```

    where:

        `w` represents the process noise,
        `v` represents the measurement noise,

    and

    ```
        E(w) = E(v) = 0
        E(ww') = Q(t)
        E(vv') = R(t)
        E(wv') = N(t) = 0
    ```

    This plant is discretized to obtain the following form:

    ```
        x[n+1] = fd(x[n], u[n]) + Gd w[n]  -- (D1)
        y[n]   = gd(x[n], u[n]) + v[n]     -- (D2)

        E(w[n]) = E(v[n]) = 0
        E(w[n]w'[n]) = Qd
        E(v[n]v'[n]) = Rd
        E(w[n]v'[n]) = Nd = 0
    ```

    The above discretization is performed either via the `euler` or the `zoh`
    method, and an Unscented Kalman Filter estimator for the system of equations
    (D1) and (D2) is returned.

    Note: If `discretized_noise` is True, then it is assumed that the user is
    directly providing Gd, Qd and Rd. If False, then Qd and Rd are computed from
    continuous-time Q, R, and G, and Gd is set to an Identity matrix.

    The returned system will have:

    Input ports:
        (0) u[n] : control vector at timestep n
        (1) y[n] : measurement vector at timestep n

    Output ports:
        (1) x_hat[n] : state vector estimate at timestep n

    Parameters:
        plant : a `Plant` object which can be a LeafSystem or a Diagram.
        dt: float
            Time step for the discretization.
        G_func: Callable
            A function with signature G(t) -> G that represents `G` in
            the continuous-time equations (C1) and (C2).
        Q_func: Callable
            A function with signature Q(t) -> Q that represents `Q` in
            the continuous-time equations (C1) and (C2).
        R_func: Callable
            A function with signature R(t) -> R that represents `R` in
            the continuous-time equations (C1) and (C2).
        x_hat_0: ndarray
            Initial state estimate
        P_hat_0: ndarray
            Initial state covariance matrix estimate. If `None`, an Identity
            matrix is assumed.
        discretization_method: str ("euler" or "zoh")
            Method to discretize the continuous-time plant. Default is "euler".
        discretized_noise: bool
            Whether the user is directly providing Gd, Qd and Rd. Default is False.
            If True, `G_func`, `Q_func`, and `R_func` provide Gd(t), Qd(t), and
            Rd(t), respectively.
        alpha: float
            Sigma point spread to control the amount of nonlinearities taken into
            account. Usually set to a value (1e-04<= alpha <= 1.0). Default is 1.0.
        beta: float
            Scaling constant to include prior information about the distribution of
            the state. Default is 0.0.
        kappa: float
            Relatively non-critical parameter to control the kurtosis of sigma
            point distribution. Default is 0.0.
    """

    (
        forward,
        observation,
        Gd_func,
        Qd_func,
        Rd_func,
    ) = prepare_continuous_plant_for_nonlinear_kalman_filter(
        plant,
        dt,
        G_func,
        Q_func,
        R_func,
        x_hat_0,
        discretization_method,
        discretized_noise,
    )

    nx = x_hat_0.size
    # P_hat_0 may be omitted (None); fall back to an identity prior covariance.
    if P_hat_0 is None:
        P_hat_0 = jnp.eye(nx)

    # TODO: If Gd_func is None, compute Gd automatically with u = u + w

    ukf = UnscentedKalmanFilter(
        dt,
        forward,
        observation,
        Gd_func,
        Qd_func,
        Rd_func,
        x_hat_0,
        P_hat_0,
        alpha=alpha,
        beta=beta,
        kappa=kappa,
        name=name,
        ui_id=ui_id,
    )

    return ukf

from_operators(dt, forward, observation, G_func, Q_func, R_func, x_hat_0, P_hat_0, alpha=1.0, beta=0.0, kappa=0.0, name=None, ui_id=None) staticmethod

Unscented Kalman Filter (UKF) for the following system:

    x[n+1] = f(x[n], u[n]) + G(t[n]) w[n]
    y[n]   = g(x[n], u[n]) + v[n]

    E(w[n]) = E(v[n]) = 0
    E(w[n]w'[n]) = Q(t[n])
    E(v[n]v'[n]) = R(t[n])
    E(w[n]v'[n]) = N(t[n]) = 0

f and g are discrete-time functions of state x[n] and control u[n], while Q and R and G are discrete-time functions of time t[n].

Input ports

(0) u[n] : control vector at timestep n
(1) y[n] : measurement vector at timestep n

Output ports

(1) x_hat[n] : state vector estimate at timestep n

Parameters:

Name Type Description Default
dt

float Time step of the discrete-time system

required
forward

Callable A function with signature f(x[n], u[n]) -> x[n+1] that represents f in the above equations.

required
observation

Callable A function with signature g(x[n], u[n]) -> y[n] that represents g in the above equations.

required
G_func

Callable A function with signature G(t[n]) -> G[n] that represents G in the above equations.

required
Q_func

Callable A function with signature Q(t[n]) -> Q[n] that represents Q in the above equations.

required
R_func

Callable A function with signature R(t[n]) -> R[n] that represents R in the above equations.

required
x_hat_0

ndarray Initial state estimate

required
P_hat_0

ndarray Initial state covariance matrix estimate

required
alpha

float Sigma point spread to control the amount of nonlinearities taken into account. Usually set to a value (1e-04<= alpha <= 1.0). Default is 1.0.

1.0
beta

float Scaling constant to include prior information about the distribution of the state. Default is 0.0.

0.0
kappa

float Relatively non-critical parameter to control the kurtosis of sigma point distribution. Default is 0.0.

0.0
Source code in collimator/library/state_estimators/unscented_kalman_filter.py
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
@staticmethod
@with_resolved_parameters
def from_operators(
    dt,
    forward,
    observation,
    G_func,
    Q_func,
    R_func,
    x_hat_0,
    P_hat_0,
    alpha=1.0,
    beta=0.0,
    kappa=0.0,
    name=None,
    ui_id=None,
):
    """
    Unscented Kalman Filter (UKF) for the following system:

    ```
        x[n+1] = f(x[n], u[n]) + G(t[n]) w[n]
        y[n]   = g(x[n], u[n]) + v[n]

        E(w[n]) = E(v[n]) = 0
        E(w[n]w'[n]) = Q(t[n])
        E(v[n]v'[n]) = R(t[n])
        E(w[n]v'[n]) = N(t[n]) = 0
    ```

    `f` and `g` are discrete-time functions of state `x[n]` and control `u[n]`,
    while `Q` and `R` and `G` are discrete-time functions of time `t[n]`.

    Input ports:
        (0) u[n] : control vector at timestep n
        (1) y[n] : measurement vector at timestep n

    Output ports:
        (1) x_hat[n] : state vector estimate at timestep n

    Parameters:
        dt: float
            Time step of the discrete-time system
        forward: Callable
            A function with signature f(x[n], u[n]) -> x[n+1] that represents `f`
            in the above equations.
        observation: Callable
            A function with signature g(x[n], u[n]) -> y[n] that represents `g` in
            the above equations.
        G_func: Callable
            A function with signature G(t[n]) -> G[n] that represents `G` in
            the above equations.
        Q_func: Callable
            A function with signature Q(t[n]) -> Q[n] that represents
            `Q` in the above equations.
        R_func: Callable
            A function with signature R(t[n]) -> R[n] that represents `R` in
            the above equations.
        x_hat_0: ndarray
            Initial state estimate
        P_hat_0: ndarray
            Initial state covariance matrix estimate
        alpha: float
            Sigma point spread to control the amount of nonlinearities taken into
            account. Usually set to a value (1e-04<= alpha <= 1.0). Default is 1.0.
        beta: float
            Scaling constant to include prior information about the distribution of
            the state. Default is 0.0.
        kappa: float
            Relatively non-critical parameter to control the kurtosis of sigma
            point distribution. Default is 0.0.
    """

    # The filter calls Q_func(t, x, u); adapt the time-only user function to that
    # signature by ignoring the state and control arguments.
    def Q_func_aug(t, x_k, u_k):
        return Q_func(t)

    ukf = UnscentedKalmanFilter(
        dt,
        forward,
        observation,
        G_func,
        Q_func_aug,
        R_func,
        x_hat_0,
        P_hat_0,
        alpha=alpha,
        beta=beta,
        kappa=kappa,
        name=name,
        ui_id=ui_id,
    )

    return ukf

VideoSink

Bases: LeafSystem

Records RGB frames to a video file.

Parameters:

Name Type Description Default
dt float

Interval at which to record frames.

required
file_name str

Name of the video file to write to. A `.mp4` extension is appended if the name does not already end with one.

required
Source code in collimator/library/video.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
class VideoSink(LeafSystem):
    """Records RGB frames to a video file.

    Every `dt` seconds the current value of the ``frame`` input port is
    converted from RGB to BGR (OpenCV's channel order) and appended to the
    video file. The ``frame_id`` output port reports the index of the frame
    that was just written.

    Parameters:
        dt: Interval at which to record frames. The output file is written at
            ``1 / dt`` frames per second.
        file_name: Name of the video file to write to. ``.mp4`` is appended if
            the name does not already end with it. Missing parent directories
            are created when the first frame is written.
    """

    @parameters(static=["dt", "file_name"])
    def __init__(self, dt: float, file_name: str, **kwargs):
        super().__init__(**kwargs)

        self.dt = dt
        self.fps = 1 / dt
        file_name = str(file_name)
        ext = ".mp4" if not file_name.endswith(".mp4") else ""
        self.file_name = file_name + ext
        # The writer is opened lazily from the first frame, because the frame
        # size is only known once an input value arrives.
        self.writer: "VideoWriter" = None
        self.frame_id = 0

        self.declare_input_port("frame")

        def _io_cb(time, state, *inputs, **parameters) -> Array:
            # Writing to disk is a host-side effect, so it runs through
            # io_callback; the callback's result is the integer frame index.
            return io_callback(self._video_cb, cnp.intx(0), time, inputs[0])

        self.declare_output_port(
            _io_cb,
            name="frame_id",
            requires_inputs=True,
            period=dt,
            offset=dt,
        )

    def _init_video(self, frame: Array):
        """Validate the first frame and open the video writer.

        Args:
            frame: The first frame, expected to have shape (height, width, 3).

        Raises:
            StaticError: If the frame is not a 3-channel image, or if the video
                file cannot be opened for writing.
        """
        if len(frame.shape) != 3 or frame.shape[2] != 3:
            raise StaticError(
                f"Input frame must be an RGB image, got invalid shape: {frame.shape}",
                system=self,
            )

        # A note on codecs:
        # vp9 (vp09) works in browsers, but it's a bit slow to encode
        # MPEG-4 (mp4v) is faster, but not supported in browsers
        # H264 (avc1) is supported but plagued with patents
        # av1 (AV01) broke my computer

        # os.path.dirname returns "" for a bare name such as "out.mp4", and
        # os.makedirs("") raises FileNotFoundError, so only create directories
        # when the path actually has a directory component.
        dir_name = os.path.dirname(self.file_name)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        h, w, _ = frame.shape
        self.writer = cv2.VideoWriter(
            self.file_name,
            cv2.VideoWriter_fourcc(*"vp09"),
            self.fps,
            (w, h),
        )
        if not self.writer.isOpened():
            raise StaticError(
                f"Failed to open video file {self.file_name}",
                system=self,
            )

        logger.info("Writing video of size %sx%s to file: %s", w, h, self.file_name)

    def post_simulation_finalize(self) -> None:
        """Close the video file and reset the writer state."""
        if self.writer is not None:
            self.writer.release()
            # Drop the released writer (and restart the frame count) so a
            # subsequent simulation reopens the file instead of silently
            # writing to a closed handle.
            self.writer = None
            self.frame_id = 0
        return super().post_simulation_finalize()

    def _video_cb(self, time: Array, frame: Array) -> Array:
        """Host-side callback: append one RGB frame and return its index."""
        image = np.array(frame)
        if self.writer is None:
            # Validate and open before cvtColor, so an invalid shape raises the
            # descriptive StaticError rather than an opaque OpenCV error.
            self._init_video(image)
        # OpenCV expects BGR channel order.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        self.writer.write(image)

        frame_id = self.frame_id
        self.frame_id += 1
        return cnp.intx(frame_id)

VideoSource

Bases: LeafSystem

Reads frames from a video file.

Parameters:

Name Type Description Default
file_name str

Name of the video file to read from.

required
no_repeat

If True, stop at the end of the video; if False (the default), loop back to the beginning.

False
Source code in collimator/library/video.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
class VideoSource(LeafSystem):
    """Reads frames from a video file.

    The file is opened with OpenCV at construction time. Frames are emitted at
    the file's frame rate, or 30 fps if the file does not report one. At each
    tick, the frame corresponding to the current simulation time is decoded,
    converted from BGR to RGB, and emitted on the ``frame`` output port. The
    ``frame_id`` port reports the index of the most recently requested frame.
    When ``no_repeat`` is set, a ``stopped`` port reports whether the end of the
    video has been reached.

    Parameters:
        file_name: Name of the video file to read from.
        no_repeat: If True, stop at the end of the video and keep emitting the
            last frame read; otherwise (the default) loop back to the beginning.

    Raises:
        BlockInitializationError: If the video file cannot be opened.
    """

    @parameters(static=["file_name", "no_repeat"])
    def __init__(self, file_name: str, no_repeat=False, **kwargs):
        super().__init__(**kwargs)

        self.repeat = not no_repeat
        self.file_name = str(file_name)
        # Index of the most recently requested frame. This is a plain Python int;
        # `_frame_id_cb` converts it for the output port.
        self.frame_id: int = 0
        self.reached_end = False

        self.reader = cv2.VideoCapture(self.file_name)
        if not self.reader.isOpened():
            raise BlockInitializationError(
                f"Could not open video file '{self.file_name}'", system=self
            )

        self.width = int(self.reader.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.depth = 1 if bool(self.reader.get(cv2.CAP_PROP_MONOCHROME)) else 3
        # Fall back to 30 fps when no frame rate is reported (falsy value).
        self.fps = self.reader.get(cv2.CAP_PROP_FPS) or 30
        # NOTE: may be <= 0 if the frame count is unavailable; `_source_cb`
        # guards against that case.
        self.video_length = int(self.reader.get(cv2.CAP_PROP_FRAME_COUNT))

        logger.info(
            "Opened video file '%s' with size %sx%s, %s frames, %s fps",
            self.file_name,
            self.width,
            self.height,
            self.video_length,
            self.fps,
            **logdata(block=self),
        )

        # Emitted until the first frame is read, and after the end is reached
        # when not repeating. It also fixes the result shape/dtype for io_callback.
        self.last_frame = np.zeros(
            (self.height, self.width, self.depth), dtype=np.uint8
        )

        def _frame_cb(time, state, *inputs, **parameters) -> Array:
            def cb(time) -> Array:
                return self._source_cb(time)

            return io_callback(cb, self.last_frame, time)

        dt = 1 / self.fps
        self.declare_output_port(
            _frame_cb,
            name="frame",
            period=dt,
            offset=dt,
            requires_inputs=False,
        )

        def _frame_id_cb(time, state, *inputs, **parameters) -> Array:
            return io_callback(self._frame_id_cb, cnp.intx(0))

        self.declare_output_port(
            _frame_id_cb,
            name="frame_id",
            period=dt,
            offset=dt,
            default_value=cnp.intx(0),
            requires_inputs=False,
        )

        if not self.repeat:

            def _stopped_cb(time, state, *inputs, **parameters) -> Array:
                return io_callback(self._stopped_cb, cnp.bool_(False))

            self.declare_output_port(
                _stopped_cb,
                name="stopped",
                period=dt,
                offset=dt,
                default_value=cnp.bool_(False),
                requires_inputs=False,
            )

    def post_simulation_finalize(self) -> None:
        """Release the OpenCV capture handle, if still open."""
        if self.reader is not None:
            self.reader.release()
        return super().post_simulation_finalize()

    def _source_cb(self, time: float) -> Array:
        """Host-side callback: return the RGB frame for simulation time `time`.

        Raises:
            BlockRuntimeError: If no frame can be decoded, even after rewinding
                (repeat mode only).
        """
        if self.reached_end:
            return self.last_frame

        # FPS_EPSILON (defined elsewhere in this module) presumably absorbs
        # float round-off at exact frame boundaries.
        frame_index = int(time * self.fps + FPS_EPSILON)
        length_known = self.video_length > 0

        if not self.repeat and length_known and frame_index >= self.video_length:
            self.reached_end = True
            return self.last_frame

        # Wrap around when looping. Without a known frame count the modulo
        # would divide by zero, so seek directly and let a failed read below
        # handle running past the end.
        if length_known:
            self.frame_id = frame_index % self.video_length
        else:
            self.frame_id = frame_index
        self.reader.set(cv2.CAP_PROP_POS_FRAMES, self.frame_id)

        ret, frame = self.reader.read()
        if not ret:
            if self.repeat:
                # Rewind to the first frame and report it as such.
                self.frame_id = 0
                self.reader.set(cv2.CAP_PROP_POS_FRAMES, 0)
                ret, frame = self.reader.read()
            else:
                self.reached_end = True
                self.reader.release()
                self.reader = None
                return self.last_frame

        if not ret:
            raise BlockRuntimeError("Failed to read frame from video file", system=self)

        # OpenCV decodes to BGR; the output port carries RGB.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        self.last_frame = frame
        return frame

    def _frame_id_cb(self) -> Array:
        """Host-side callback: index of the most recently requested frame."""
        return cnp.intx(self.frame_id)

    def _stopped_cb(self) -> Array:
        """Host-side callback: whether the end of the video has been reached."""
        return self.reached_end

WhiteNoise

Bases: LeafSystem

Continuous-time white noise generator.

Generates a band-limited white noise signal using a sinc-interpolated random number generator. The output signal is a continuous-time signal, but the underlying random number generator is discrete-time. As a result, the signal is not truly white, but is band-limited by the sample rate. The resulting signal has the following approximate power spectral density:

S(f) = A / fs if |f| < fs / 2 else 0,

where A is the noise power and fs = 1/correlation_time is the sample rate.

See Ch. 10.4 in Baraniuk, "Signal Processing and Modeling" for details: https://shorturl.at/floRZ

The output signal will have variance A, zero mean, and will decorrelate at the sample rate.

Input ports

None

Output ports

(0) The band-limited white noise signal with variance noise_power, zero mean, and correlation time correlation_time.

Parameters:

Name Type Description Default
correlation_time

The correlation time of the output signal and the inverse of the sample rate. It is the sample period of the underlying random number generator.

required
noise_power float

The variance of the white noise signal. Also scales the amplitude of the power spectral density.

1.0
num_samples int

The number of samples to use for sinc interpolation. More samples will result in a more accurate approximation of the ideal power spectrum, but will also increase the computational cost. The default of 10 is sufficient for most applications.

10
seed int

An integer seed for the random number generator. If None, a random 32-bit seed will be generated.

None
dtype DTypeLike

Data type of the random numbers. If None, defaults to float.

None
shape ShapeLike

The shape of the output signal. If empty, the output will be a scalar.

()
Source code in collimator/library/random.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
class WhiteNoise(LeafSystem):
    """Continuous-time, band-limited white noise source.

    A discrete-time Gaussian sequence is drawn once every `correlation_time`
    seconds. The continuous output is reconstructed from the most recent
    `num_samples` draws by sinc interpolation. Because the underlying generator
    is discrete, the output is not truly white: its spectrum is limited by the
    sample rate `fs = 1 / correlation_time`. With noise power `A`, the
    approximate power spectral density is
    ```
    S(f) = A / fs if |f| < fs / 2 else 0,
    ```

    See Ch. 10.4 in Baraniuk, "Signal Processing and Modeling" for details:
        https://shorturl.at/floRZ

    The output has zero mean and (approximately) variance `A`, and decorrelates
    over one sample period.

    Input ports:
        None

    Output ports:
        (0) The band-limited noise signal, with variance `noise_power`, zero
            mean, and correlation time `correlation_time`.

    Parameters:
        correlation_time: Sample period of the underlying random number
            generator, i.e. the inverse of the sample rate. This is also the
            correlation time of the output.
        noise_power: Variance of the noise. Also scales the power spectral
            density.
        num_samples: Number of buffered samples used in the sinc interpolation.
            More samples approximate the ideal spectrum more closely, at a
            higher computational cost. The default of 10 suffices for most
            applications.
        seed: Integer seed for the random number generator. If None, a random
            32-bit seed is drawn.
        dtype: Data type of the random numbers. If None, defaults to float.
        shape: Shape of the output signal. If empty, the output is a scalar.
    """

    class RNGState(NamedTuple):
        key: Array
        samples: Array
        t_last: float = 0.0

    @parameters(
        static=["num_samples", "shape", "seed", "noise_power"],
        dynamic=["correlation_time"],
    )
    def __init__(
        self,
        correlation_time,
        noise_power: float = 1.0,
        num_samples: int = 10,
        seed: int = None,
        dtype: DTypeLike = None,
        shape: ShapeLike = (),
        **kwargs,
    ):
        super().__init__(**kwargs)

        # `dtype` is not in the declared static/dynamic parameter lists, so it
        # is kept directly on the instance for use by the sampler.
        self.dtype = dtype

        self.declare_output_port(self._output)
        # One new sample is shifted into the buffer every correlation_time.
        self.declare_periodic_update(
            self._update,
            period=correlation_time,
            offset=0.0,
        )

    def initialize(
        self,
        correlation_time,
        noise_power: float = 1.0,
        num_samples: int = 10,
        seed: int = None,
        shape: ShapeLike = (),
    ):
        self.shape = tuple(int(dim) for dim in shape)
        self.N = num_samples
        self.noise_power = noise_power

        # Per-sample offsets centred on zero: -(N - 1) / 2, ..., (N - 1) / 2.
        self.shift = np.arange(self.N) - (self.N - 1) / 2
        self.rng = partial(random.normal, dtype=self.dtype)

        if seed is None:
            seed = np.random.randint(0, 2**32, dtype=np.int64)
        else:
            seed = int(seed)

        # Discrete state holds the PRNG key plus the buffer of N samples that
        # the continuous output is sinc-interpolated from.
        key, subkey = random.split(random.PRNGKey(int(seed)))
        initial_samples = self.sample(subkey, shape=(self.N, *self.shape))
        self.declare_discrete_state(
            default_value=self.RNGState(key=key, samples=initial_samples),
            as_array=False,
        )

    def sample(self, key, shape):
        """Draw zero-mean normal samples with variance `noise_power`."""
        scale = jnp.sqrt(self.noise_power)
        return scale * self.rng(key, shape)

    def _output(self, time, state, *_inputs, **parameters):
        xd = state.discrete_state
        period = parameters["correlation_time"]

        # Sinc argument per buffered sample: elapsed time since the last
        # update, in units of samples, relative to each sample's offset.
        w = (time - xd.t_last) / period - self.shift

        # Bound the argument to limit discontinuities when the buffer shifts.
        w = jnp.clip(w, -self.N // 2, self.N // 2)

        # The sample index is axis 0 of the buffer; move it last so it lines
        # up with `w` and is the axis summed over.
        weighted = jnp.moveaxis(xd.samples, 0, -1) * jnp.sinc(w)
        return jnp.sum(weighted, axis=-1)

    def _update(self, time, state, *_inputs, **_parameters):
        previous = state.discrete_state
        next_key, draw_key = random.split(previous.key)

        # Drop the oldest sample and append one fresh draw at the end.
        fresh = self.sample(draw_key, (1, *self.shape))
        return self.RNGState(
            key=next_key,
            samples=jnp.concatenate((previous.samples[1:], fresh)),
            t_last=time,
        )

ZeroOrderHold

Bases: LeafSystem

Implements a "zero-order hold" A/D conversion.

https://en.wikipedia.org/wiki/Zero-order_hold

The block implements a "zero-order hold" with the following difference equation for input signal u and output signal y:

    y[k] = u[k]

The block does not maintain an internal state, but simply holds the value of the input signal at the previous update time. As a result, the block is "feedthrough" from its inputs to outputs and cannot be used to break an algebraic loop. The data type of this hold value is inferred from upstream blocks.

Input ports

(0) The input signal.

Output ports

(0) The "hold" value of the input signal. If the input signal is continuous, then the output will be the value of the input signal at the previous update time. If the input signal is discrete and synchronous with the block, the output will be the value of the input signal at the current time (i.e. identical to the input signal).

Parameters:

Name Type Description Default
dt

The time step of the discrete update.

required
Source code in collimator/library/primitives.py
3834
3835
3836
3837
3838
3839
3840
3841
3842
3843
3844
3845
3846
3847
3848
3849
3850
3851
3852
3853
3854
3855
3856
3857
3858
3859
3860
3861
3862
3863
3864
3865
3866
3867
3868
3869
3870
3871
3872
3873
3874
3875
3876
3877
3878
3879
class ZeroOrderHold(LeafSystem):
    """Implements a "zero-order hold" A/D conversion.

    https://en.wikipedia.org/wiki/Zero-order_hold

    For input signal `u` and output signal `y`, the block implements the
    difference equation
    ```
        y[k] = u[k]
    ```

    No internal state is kept: the output is a periodically-sampled copy of the
    input, held between update times. As a result, the block is "feedthrough"
    from its input to its output and cannot be used to break an algebraic loop.
    The data type of the held value is inferred from upstream blocks.

    Input ports:
        (0) The input signal.

    Output ports:
        (0) The "hold" value of the input signal. For a continuous input, this
            is the input's value at the most recent update time. For a discrete
            input synchronous with the block, it equals the input at the current
            time (i.e. it is identical to the input signal).

    Parameters:
        dt:
            The time step of the discrete update.
    """

    def __init__(self, dt, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dt = dt

        self.declare_input_port()
        # The held value is recomputed from the input (and discrete state)
        # at every tick of period `dt`.
        hold_dependencies = [self.input_ports[0].ticket, DependencyTicket.xd]
        self.declare_output_port(
            self._output,
            period=dt,
            offset=0.0,
            prerequisites_of_calc=hold_dependencies,
        )

    def _output(self, _time, _state, u, **_params):
        # Sampled once per period: the held value is simply the current input.
        return u

linearize(system, base_context, name=None, output_index=None)

Linearize the system about an operating point specified by the base context.

For now, only implemented for systems with one each (vector-valued) input and output. The system may have multiple output ports, but only one will be treated as the measurement.

Source code in collimator/library/linear_system.py
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
def linearize(system, base_context, name=None, output_index=None):
    """Linearize the system about an operating point specified by the base context.

    For now, only implemented for systems with one each (vector-valued) input and
    output. The system may have multiple output ports, but only one will be treated
    as the measurement.

    Args:
        system: The system to linearize. Must have exactly one input port.
        base_context: Context defining the operating point. Its continuous state
            and the value of the input port evaluated at this context are used
            as the linearization point.
        name: Optional name for the returned LTISystem.
        output_index: Index of the output port treated as the measurement.
            Defaults to 0, with a warning if the system has several output ports.

    Returns:
        An LTISystem built by `_jvp_to_ss` from Jacobian-vector products of
        (state derivatives, selected output) with respect to (state, input),
        with its context already created.

    Raises:
        AssertionError: If the system does not have exactly one input port.
    """
    assert len(system.input_ports) == 1, (
        "Linearization only implemented for systems with one input port, system "
        f"{system.name} has {len(system.input_ports)} input ports"
    )
    if len(system.output_ports) > 1 and output_index is None:
        logger.warning(
            "Multiple output ports detected when linearizing system %s, "
            "using first port as output",
            system.name,
        )

    # Default to zero output index if not specified (after issuing a warning)
    if output_index is None:
        output_index = 0

    input_port = system.input_ports[0]
    output_port = system.output_ports[output_index]

    xc0 = base_context.continuous_state
    u0 = input_port.eval(base_context)

    # Remember whether the input port had a fixed value before linearization,
    # so that it can be re-fixed to the operating-point input afterwards.
    restore_fixed_val = input_port.is_fixed

    # Map from (state, inputs) to (state derivatives, outputs)
    @jax.jit
    def f(xc, u):
        context = base_context.with_continuous_state(xc)
        with input_port.fixed(u):
            xdot = system.eval_time_derivatives(context)
            y = output_port.eval(context)
        return xdot, y

    # Directional derivative of `f` at the operating point (xc0, u0) along the
    # tangent (xc, u); only the tangents are needed.
    @jax.jit
    def jac(xc, u):
        _, tangents = jax.jvp(f, (xc0, u0), (xc, u))
        return tangents

    try:
        lin_sys = LTISystem(*_jvp_to_ss(jac, xc0, u0), name=name)
        lin_sys.create_context()
    finally:
        # Re-apply the original fixed input even if linearization fails part-way.
        if restore_fixed_val:
            input_port.fix_value(u0)

    return lin_sys