
Optimization

collimator.optimization

AutoTuner

PID autotuning (without a measurement filter) with constraints in the frequency domain.

Supports only SISO systems.

Supports only continuous-time plants (TODO: extend to discrete-time systems)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `plant` | `LeafSystem` or `Diagram` | The plant to be controlled. If `plant` is not an `LTISystem`, the operating points `x_op` and `u_op` must be provided for linearization. | required |
| `n` | `int`, optional | Filter coefficient for the continuous-time PID controller. | `100` |
| `sim_time` | `float`, optional | Simulation time used to compute the error metric. | `2.0` |
| `metric` | `str`, optional | Error metric to minimize: `"IAE"` (integral of the absolute error) or `"IE"` (integral of the error). | `'IAE'` |
| `x_op` | `np.ndarray`, optional | Operating point of the state vector for linearization. | `None` |
| `u_op` | `np.ndarray`, optional | Operating point of the control vector for linearization. | `None` |
| `pid_gains_0` | list or Array, optional | Initial guess for the PID gains `[kp, ki, kd]`. | `[1.0, 10.0, 0.1]` |
| `pid_gains_upper_bounds` | list or Array, optional | Upper bounds for the PID gains `[kp, ki, kd]`. Lower bounds are fixed at 0; if `None`, the upper bounds are infinite. | `None` |
| `Ms` | `float`, optional | Upper bound on the peak magnitude of the sensitivity S = 1/(1 + PC), evaluated on 1000 log-spaced frequencies from 0.01 to 100 rad/s. | `100.0` |
| `Mt` | `float`, optional | Upper bound on the peak magnitude of the complementary sensitivity T = PC/(1 + PC), evaluated on the same frequency grid. | `100.0` |
| `add_filter` | `bool`, optional | Add a measurement filter (currently not implemented). | `False` |
| `method` | `str`, optional | Optimization method. Available options: `"scipy-slsqp"`, `"scipy-cobyla"`, `"scipy-trust-constr"`, `"ipopt"`, `"nlopt-slsqp"`, `"nlopt-cobyla"`, `"nlopt-ld_mma"`, `"nlopt-isres"`, `"nlopt-ags"`, `"nlopt-direct"`. NLopt methods classified as global (`NLOPT_METHODS_GLOBAL`) require finite `pid_gains_upper_bounds`. | `'scipy-slsqp'` |
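To illustrate how the `Ms` and `Mt` bounds become optimization constraints, here is a minimal NumPy sketch that mirrors `Ps`, `Cs`, and `constraints` from the `AutoTuner` source listing: it evaluates the loop gain L = PC on a log-spaced frequency grid, forms S = 1/(1 + L) and T = 1 - S, and returns `[Ms - max|S|, Mt - max|T|]`, which must be nonnegative for the gains to be feasible. The first-order plant P(s) = 1/(s + 1), the gains, and the bounds are hypothetical, and NumPy stands in for JAX, so this is a sketch rather than the collimator implementation.

```python
import numpy as np

# Hypothetical first-order plant P(s) = 1 / (s + 1) in state-space form.
A = np.array([[-1.0]])
B = np.array([[1.0]])
C = np.array([[1.0]])
D = np.array([[0.0]])


def plant_tf(s):
    """Evaluate P(s) = C (sI - A)^{-1} B + D for a SISO plant."""
    n_states = A.shape[0]
    P = C @ np.linalg.inv(s * np.eye(n_states) - A) @ B + D
    return P[0, 0]


def pid_tf(kp, ki, kd, s):
    """Ideal (unfiltered) PID controller C(s) = kp + ki/s + kd*s."""
    return kp + ki / s + kd * s


def constraints(pid_params, Ms, Mt, omega_grid):
    """Return [Ms - max|S|, Mt - max|T|]; both entries >= 0 means feasible."""
    kp, ki, kd = pid_params
    s_grid = 1j * omega_grid
    L = np.array([plant_tf(s) * pid_tf(kp, ki, kd, s) for s in s_grid])
    S = 1.0 / (1.0 + L)  # sensitivity
    T = 1.0 - S  # complementary sensitivity
    return np.array([Ms - np.max(np.abs(S)), Mt - np.max(np.abs(T))])


omega_grid = 10.0 ** np.linspace(-2, 2, 1000)
g = constraints([1.0, 10.0, 0.1], Ms=2.0, Mt=2.0, omega_grid=omega_grid)
print(g)
```

As in the listing, the controller here is the ideal PID kp + ki/s + kd·s, without the filter coefficient `n` that the time-domain simulation uses.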

Notes:

The utilities plot_freq_response, plot_time_response, and plot_freq_and_time_responses can be used to visualize the frequency and time responses of the closed-loop system.

After initialization, call the tune method to obtain the optimal PID gains. See notebooks/opt_framework/pid_autotuning.ipynb for an example.
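The quantity that `tune` minimizes is the integral of the error signal (absolute for `"IAE"`, signed for `"IE"`) over `[0, sim_time]`, divided by `sim_time`, as computed in `objective`. A rough plain-Python sketch of that metric follows, assuming a hypothetical plant dx/dt = -x + u, a PI controller (kd = 0), a unit-step reference, and forward-Euler integration; collimator instead simulates the closed-loop diagram built by `make_closed_loop_pid_system`, whose reference signal is not shown on this page.

```python
def normalized_error_metric(kp, ki, metric="IAE", sim_time=2.0, dt=1e-3):
    """Sketch of the AutoTuner objective for a hypothetical plant dx/dt = -x + u
    under PI control, tracking an assumed unit-step reference (forward Euler)."""
    x = 0.0  # plant state, also the plant output y
    z = 0.0  # integral of the tracking error (PI integrator state)
    cost = 0.0  # running integral of |e| ("IAE") or e ("IE")
    for _ in range(int(round(sim_time / dt))):
        e = 1.0 - x  # tracking error against the unit step
        u = kp * e + ki * z  # PI control law
        cost += (abs(e) if metric == "IAE" else e) * dt
        z += e * dt
        x += (-x + u) * dt
    # As in AutoTuner.objective, normalize the integral by the simulation time.
    return cost / sim_time


print(normalized_error_metric(1.0, 10.0, "IAE"), normalized_error_metric(1.0, 10.0, "IE"))
```

Because |e| ≥ e pointwise, the IAE value is never smaller than the IE value; they differ once the response overshoots the reference.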

Source code in collimator/optimization/pid_autotuning.py
class AutoTuner:
    """
    PID autotuning (without a measurement filter) with constraints in the frequency
    domain.

    Supports only SISO systems.

    Supports only continuous-time plants (TODO: extend to discrete-time systems)

    Parameters:
        plant: LeafSystem or a Diagram.
            If plant is not an LTISystem, operating points x_op and u_op must be
            provided for linearization.
        n: int, optional
            Filter coefficient for the continuous-time PID controller
        sim_time: float, optional
            Simulation time for computation of the error metric
        metric: str, optional
            Error metric to be minimized. Options are "IAE" and "IE"
                "IAE": Integral of the absolute error
                "IE": Integral of the error
        x_op: np.ndarray, optional
            Operating point of state vector for linearization
        u_op: np.ndarray, optional
            Operating point of control vector for linearization
        pid_gains_0: list or Array, optional
            Initial guess for PID gains [kp, ki, kd]
        pid_gains_upper_bounds: list or Array, optional
            Upper bounds for PID gains [kp, ki, kd]. Lower bounds are set to 0
        Ms: float, optional
            Upper bound on the peak sensitivity, max |S(jw)| with S = 1/(1 + PC)
        Mt: float, optional
            Upper bound on the peak complementary sensitivity, max |T(jw)| with
            T = PC/(1 + PC)
        add_filter: bool, optional
            Add measurement filter (currently not implemented)
        method: str, optional
            The method for optimization. Available options are:
                - "scipy-slsqp"
                - "scipy-cobyla"
                - "scipy-trust-constr"
                - "ipopt"
                - "nlopt-slsqp"
                - "nlopt-cobyla"
                - "nlopt-ld_mma"
                - "nlopt-isres"
                - "nlopt-ags"
                - "nlopt-direct"

    Notes:

    The utilities `plot_freq_response`, `plot_time_response`, and
    `plot_freq_and_time_responses` can be used to visualize the frequency and time
    responses of the closed-loop system.

    After initialization, call the `tune` method to obtain the optimal PID
    gains. See `notebooks/opt_framework/pid_autotuning.ipynb` for an example.

    """

    def __init__(
        self,
        plant,
        n=100,
        sim_time=2.0,
        metric="IAE",
        x_op=None,
        u_op=None,
        pid_gains_0=[1.0, 10.0, 0.1],
        pid_gains_upper_bounds=None,
        Ms=100.0,
        Mt=100.0,
        add_filter=False,  # add measurement filter (currently not implemented)
        method="scipy-slsqp",
    ):
        if isinstance(plant, LTISystem):  # LTISystem includes TransferFunction
            linear_plant = plant
        else:
            if x_op is None or u_op is None:
                raise ValueError("Operating point x_op and u_op must be provided")

            _, linear_plant = linearize_plant(plant, x_op, u_op)

        if linear_plant.B.shape[1] != 1 or linear_plant.C.shape[0] != 1:
            raise ValueError("Plant must be SISO")

        self.A, self.B, self.C, self.D = (
            linear_plant.A,
            linear_plant.B,
            linear_plant.C,
            linear_plant.D,
        )

        self.pid, self.integrator, self.diagram = make_closed_loop_pid_system(
            linear_plant,
            metric,
            n=n,
        )

        self.lb = [0.0] * 3
        if pid_gains_upper_bounds is None:
            self.ub = [jnp.inf] * 3
        else:
            self.ub = pid_gains_upper_bounds

        self.base_context = self.diagram.create_context()

        self.n = n
        self.sim_time = sim_time
        self.metric = metric
        self.pid_gains_0 = pid_gains_0
        self.Ms = Ms
        self.Mt = Mt
        self.add_filter = add_filter
        self.method = method

        self.omega_grid = 10.0 ** jnp.linspace(-2, 2, 1000)
        # self.omega_grid = 10.0 ** jnp.linspace(-1, 2, 150)
        self.options = SimulatorOptions(
            enable_autodiff=True,
            max_major_step_length=0.01,  # rtol=1e-08, atol=1e-10
        )

        self.circle_constraint_vectorized = jax.vmap(
            self.circle_constraint_, in_axes=(None, None, None, 0, None, None)
        )  # Deprecated

        self.Ps_vectorized = jax.vmap(self.Ps, in_axes=0)
        self.Cs_vectorized = jax.vmap(self.Cs, in_axes=(None, None, None, 0))
        self.vec_absolute = jax.vmap(jnp.absolute)

    @partial(jax.jit, static_argnums=(0,))
    def objective(self, pid_params):
        kp, ki, kd = pid_params
        C = jnp.array(
            [[(ki * self.n), ((kp * self.n + ki) - (kp + kd * self.n) * self.n)]]
        )
        D = jnp.array([[(kp + kd * self.n)]])
        pid_subcontext = self.base_context[self.pid.system_id].with_parameters(
            {"C": C, "D": D}
        )
        context = self.base_context.with_subcontext(self.pid.system_id, pid_subcontext)
        sol = collimator.simulate(
            self.diagram, context, (0.0, self.sim_time), options=self.options
        )
        return self.integrator.output_ports[0].eval(sol.context) / self.sim_time

    @partial(jax.jit, static_argnums=(0,))
    def Ps(self, s):
        P = (
            self.C @ jnp.linalg.inv(s * jnp.eye(self.A.shape[0]) - self.A) @ self.B
            + self.D
        )
        return P[0, 0]

    @partial(jax.jit, static_argnums=(0,))
    def Cs(self, kp, ki, kd, s):
        return kp + ki / s + kd * s

    @partial(jax.jit, static_argnums=(0,))
    def circle_constraint_(self, kp, ki, kd, omega, c, r):
        """Deprecated: this is needed for `self.constraints_` which is deprecated
        and replaced by `self.constraints`.
        """
        s = omega * 1.0j
        L = self.Ps(s) * self.Cs(kp, ki, kd, s)
        return jnp.absolute(L - c) - r

    @partial(jax.jit, static_argnums=(0,))
    def constraints_(self, pid_params):
        """Deprecated: replaced by `self.constraints`"""
        kp, ki, kd = pid_params
        Ms, Mt = self.Ms, self.Mt
        g_Ms = self.circle_constraint_vectorized(
            kp, ki, kd, self.omega_grid, -1.0, 1.0 / Ms
        )
        g_Mt = self.circle_constraint_vectorized(
            kp, ki, kd, self.omega_grid, -(Mt**2) / (Mt**2 - 1.0), Mt / (Mt**2 - 1.0)
        )
        return jnp.array([jnp.min(g_Ms), jnp.min(g_Mt)])

    @partial(jax.jit, static_argnums=(0,))
    def constraints(self, pid_params):
        kp, ki, kd = pid_params
        S_grid = 1.0 / (
            1.0
            + self.Ps_vectorized(self.omega_grid * 1.0j)
            * self.Cs_vectorized(kp, ki, kd, self.omega_grid * 1.0j)
        )
        T_grid = 1.0 - S_grid

        S_grid = self.vec_absolute(S_grid)
        T_grid = self.vec_absolute(T_grid)

        return jnp.array([self.Ms - jnp.max(S_grid), self.Mt - jnp.max(T_grid)])

    def tune(self):
        x0 = jnp.array(self.pid_gains_0)
        bounds = list(zip(self.lb, self.ub))

        obj = jax.jit(self.objective)
        cons = jax.jit(self.constraints)

        obj_grad = jax.grad(self.objective)
        obj_hess = jax.jit(jax.hessian(self.objective))

        cons_jac = jax.jit(jax.jacfwd(self.constraints))

        print(f"Tuning with {self.method}")
        if self.method in SCIPY_METHODS:
            constraints_scipy = NonlinearConstraint(cons, 0.0, jnp.inf, jac=cons_jac)

            res = minimize(
                obj,
                x0,
                jac=obj_grad,
                method=SCIPY_METHODS[self.method],
                bounds=bounds,
                constraints=constraints_scipy,
                options={"maxiter": 100},
            )

        elif self.method == "ipopt":
            cons_hess = jax.hessian(self.constraints)
            cons_hess_vp = jax.jit(
                lambda x, v: jnp.sum(
                    # pylint: disable-next=not-callable
                    jnp.multiply(v[:, jnp.newaxis, jnp.newaxis], cons_hess(x)),
                    axis=0,
                )
            )

            constraints_ipopt = [
                {"type": "ineq", "fun": cons, "jac": cons_jac, "hess": cons_hess_vp}
            ]

            res = cyipopt.minimize_ipopt(
                obj,
                x0=x0,
                jac=obj_grad,
                hess=obj_hess,
                constraints=constraints_ipopt,
                bounds=bounds,
                options={
                    "max_iter": 500,
                    "disp": 5,
                },
            )

        elif self.method in NLOPT_METHODS:
            if self.method in NLOPT_METHODS_GLOBAL and any(
                ub == jnp.inf for ub in self.ub
            ):
                raise ValueError(
                    f"Method {self.method} requires finite upper bounds for all "
                    "parameters. Please specify `pid_gains_upper_bounds`."
                )

            # Define the objective function for nlopt
            def nlopt_obj(x, grad):
                if grad.size > 0:
                    grad[:] = obj_grad(jnp.array(x))
                # pylint: disable-next=not-callable
                return float(obj(jnp.array(x)))

            # Define the constraint function for nlopt. nlopt expects constraints of
            # the form c(x) <= 0, so the constraints g(x) >= 0 are negated.
            def nlopt_cons(result, x, grad):
                if grad.size > 0:
                    # pylint: disable-next=not-callable
                    grad[:, :] = -cons_jac(jnp.array(x))
                # pylint: disable-next=not-callable
                result[:] = -cons(jnp.array(x))

            # Initialize nlopt optimizer
            method = NLOPT_METHODS[self.method]()
            opt = nlopt.opt(method, len(x0))

            # Set the objective function
            opt.set_min_objective(nlopt_obj)

            # Set the constraints
            opt.add_inequality_mconstraint(nlopt_cons, [1e-6, 1e-6])

            # Set the bounds
            lower_bounds, upper_bounds = zip(*bounds)
            opt.set_lower_bounds(lower_bounds)
            opt.set_upper_bounds(upper_bounds)

            # Set stopping criteria
            opt.set_maxeval(500)
            opt.set_ftol_rel(1e-5)
            opt.set_xtol_rel(1e-6)
            opt.set_maxtime(30.0)

            # Run the optimization
            x_opt = opt.optimize(x0)
            print(f"{x_opt=}")
            minf = opt.last_optimum_value()

            nlopt_success_codes = {
                nlopt.SUCCESS: "SUCCESS",
                nlopt.STOPVAL_REACHED: "STOPVAL_REACHED",
                nlopt.FTOL_REACHED: "FTOL_REACHED",
                nlopt.XTOL_REACHED: "XTOL_REACHED",
                nlopt.MAXEVAL_REACHED: "MAXEVAL_REACHED",
                nlopt.MAXTIME_REACHED: "MAXTIME_REACHED",
            }

            nlopt_error_codes = {
                nlopt.FAILURE: "FAILURE",
                nlopt.INVALID_ARGS: "INVALID_ARGS",
                nlopt.OUT_OF_MEMORY: "OUT_OF_MEMORY",
                nlopt.ROUNDOFF_LIMITED: "ROUNDOFF_LIMITED",
                nlopt.FORCED_STOP: "FORCED_STOP",
            }

            nlopt_status_codes = {**nlopt_success_codes, **nlopt_error_codes}

            res = OptResults(
                x=x_opt,
                fun=minf,
                success=opt.last_optimize_result() in nlopt_success_codes,
                message=nlopt_status_codes[opt.last_optimize_result()],
            )

        else:
            raise ValueError("Invalid method")
        return res.x, res

    def plot_freq_response(
        self, pid_params, plant_tf_num, plant_tf_den, Ms=None, Mt=None
    ):
        if Ms is None:
            Ms = self.Ms

        if Mt is None:
            Mt = self.Mt

        kp, ki, kd = pid_params
        Cs = ct.TransferFunction([kd, kp, ki], [1, 0], name="PID")
        Ps = ct.TransferFunction(plant_tf_num, plant_tf_den, name="Plant")

        # Plot Gang of Four transfer functions
        ct.gangof4_plot(Ps, Cs, omega=self.omega_grid)

        fig1 = plt.gcf()
        axs = fig1.get_axes()

        axs[3].set_title(r"$T = \dfrac{PC}{1+PC}$")
        axs[1].set_title(r"$PS = \dfrac{P}{1+PC}$")
        axs[2].set_title(r"$CS = \dfrac{C}{1+PC}$")
        axs[0].set_title(r"$S = \dfrac{1}{1+PC}$")

        if Ms is not None:
            axs[0].hlines(
                Ms,
                self.omega_grid.min(),
                self.omega_grid.max(),
                colors="r",
                linestyles="--",
            )

        if Mt is not None:
            axs[3].hlines(
                Mt,
                self.omega_grid.min(),
                self.omega_grid.max(),
                colors="b",
                linestyles="--",
            )

        # Set x-axis labels for the bottom plots
        axs[2].set_xlabel("Frequency (rad/sec)")
        axs[3].set_xlabel("Frequency (rad/sec)")

        fig1.tight_layout()

        # Plot Nyquist plot
        fig2 = plt.figure()
        ct.nyquist_plot(
            Ps * Cs, omega=self.omega_grid, warn_nyquist=False, warn_encirclements=False
        )
        axs = fig2.get_axes()
        ax = axs[0]

        def gen_circle_points(c, r):
            t = jnp.linspace(-jnp.pi / 2, jnp.pi / 2, 100)
            return jnp.array([c + r * jnp.cos(t), r * jnp.sin(t)])

        if Ms is not None:
            c1, r1 = -1.0, 1.0 / Ms
            xc1, yc1 = gen_circle_points(c1, r1)
            ax.plot(xc1, yc1, "r--")

        if Mt is not None:
            c2, r2 = -(Mt**2) / (Mt**2 - 1.0), Mt / (Mt**2 - 1.0)
            xc2, yc2 = gen_circle_points(c2, r2)
            ax.plot(xc2, yc2, "b--")

        ax.set_title("Nyquist Plot")
        fig2.tight_layout()

        return fig1, fig2

    def plot_time_response(self, pid_params):
        kp, ki, kd = pid_params
        C = jnp.array(
            [[(ki * self.n), ((kp * self.n + ki) - (kp + kd * self.n) * self.n)]]
        )
        D = jnp.array([[(kp + kd * self.n)]])
        pid_subcontext = self.base_context[self.pid.system_id].with_parameters(
            {"C": C, "D": D}
        )
        context = self.base_context.with_subcontext(self.pid.system_id, pid_subcontext)

        recorded_signals = {
            "objective": self.diagram["integrator"].output_ports[0],
            "ref": self.diagram["ref"].output_ports[0],
            "plant": self.diagram.output_ports[0],
            "pid": self.diagram["pid"].output_ports[0],
        }

        sol = collimator.simulate(
            self.diagram,
            context,
            (0.0, self.sim_time),
            recorded_signals=recorded_signals,
        )

        fig, (ax1, ax2, ax3) = plt.subplots(3, 1)

        ax1.plot(sol.time, sol.outputs["plant"], label=r"plant: $y$")
        ax1.plot(sol.time, sol.outputs["ref"], label=r"reference: $y_r$")
        ax2.plot(sol.time, sol.outputs["objective"], label=f"objective: {self.metric}")
        ax3.plot(sol.time, sol.outputs["pid"], label=r"pid-control: $u$")

        ax3.set_xlabel("Time (s)")
        for ax in (ax1, ax2, ax3):
            ax.legend()

        fig.suptitle("Time Response")
        fig.tight_layout()

        print(
            f"objective = "
            f"{self.integrator.output_ports[0].eval(sol.context)/self.sim_time}"
        )
        return fig

    def plot_freq_and_time_responses(
        self, pid_params, plant_tf_num, plant_tf_den, Ms=None, Mt=None
    ):
        fig1, fig2 = self.plot_freq_response(
            pid_params, plant_tf_num, plant_tf_den, Ms, Mt
        )
        fig3 = self.plot_time_response(pid_params)
        return fig1, fig2, fig3
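`objective` and `plot_time_response` only write the output matrices `C` and `D` of the PID block. Assuming its state matrices are A = [[0, 1], [0, -n]] and B = [[0], [1]] (an assumption: `make_closed_loop_pid_system` is not shown on this page), the sketch below checks numerically that those `C` and `D` realize the filtered PID kp + ki/s + kd·n·s/(s + n), which approaches the ideal form kp + ki/s + kd·s used by `Cs` as n grows. `pid_state_space`, `eval_tf`, and `filtered_pid_tf` are hypothetical helper names.

```python
import numpy as np


def pid_state_space(kp, ki, kd, n):
    """Hypothetical realization of the filtered PID: C and D as in
    AutoTuner.objective, A and B assumed (controllable canonical form)."""
    A = np.array([[0.0, 1.0], [0.0, -n]])
    B = np.array([[0.0], [1.0]])
    C = np.array([[ki * n, (kp * n + ki) - (kp + kd * n) * n]])
    D = np.array([[kp + kd * n]])
    return A, B, C, D


def eval_tf(A, B, C, D, s):
    """Evaluate C (sI - A)^{-1} B + D at a complex frequency s."""
    return (C @ np.linalg.solve(s * np.eye(A.shape[0]) - A, B) + D)[0, 0]


def filtered_pid_tf(kp, ki, kd, n, s):
    """Reference form: kp + ki/s + kd * n * s / (s + n)."""
    return kp + ki / s + kd * n * s / (s + n)


kp, ki, kd, n = 1.0, 10.0, 0.1, 100.0
A, B, C, D = pid_state_space(kp, ki, kd, n)
for omega in (0.1, 1.0, 10.0, 1000.0):
    s = 1j * omega
    assert np.isclose(eval_tf(A, B, C, D, s), filtered_pid_tf(kp, ki, kd, n, s))
```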

circle_constraint_(kp, ki, kd, omega, c, r)

Deprecated: this is needed for self.constraints_ which is deprecated and replaced by self.constraints.
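The centers and radii used with `circle_constraint_` follow from writing each peak bound in terms of the loop gain L = PC, with S = 1/(1 + L) and T = L/(1 + L). The Ms bound keeps L outside a circle of radius 1/Ms around -1; squaring |L| ≤ Mt|1 + L| with L = x + iy and completing the square gives the Mt circle (valid for Mt > 1, which the code's division by Mt² - 1 also requires):

```latex
\begin{aligned}
|S| \le M_s
  &\iff |1 + L| \ge \frac{1}{M_s}
  &&\text{(circle: } c = -1,\ r = 1/M_s\text{)} \\
|T| \le M_t
  &\iff |L|^2 \le M_t^2\,|1 + L|^2 \\
  &\iff (M_t^2 - 1)(x^2 + y^2) + 2M_t^2 x + M_t^2 \ge 0 \\
  &\iff \Bigl(x + \tfrac{M_t^2}{M_t^2 - 1}\Bigr)^2 + y^2 \ge \tfrac{M_t^2}{(M_t^2 - 1)^2}
  &&\text{(circle: } c = -\tfrac{M_t^2}{M_t^2 - 1},\ r = \tfrac{M_t}{M_t^2 - 1}\text{)}
\end{aligned}
```

`circle_constraint_` returns |L - c| - r, so a nonnegative value means the point L(jω) lies outside the forbidden circle.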


constraints_(pid_params)

Deprecated: replaced by self.constraints
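The deprecated `constraints_` takes the minimum of the circle margins |L - c| - r over the frequency grid, while `constraints` uses Ms - max|S| and Mt - max|T|. The two forms measure different quantities, but they should agree on feasibility. A NumPy spot check on random complex loop-gain samples (hypothetical data, not a real Nyquist curve); `circle_margins` and `peak_margins` are illustrative names:

```python
import numpy as np


def circle_margins(L, Ms, Mt):
    """Circle form, as in constraints_: min over samples of |L - c| - r."""
    g_Ms = np.abs(L - (-1.0)) - 1.0 / Ms
    c_t, r_t = -(Mt**2) / (Mt**2 - 1.0), Mt / (Mt**2 - 1.0)
    g_Mt = np.abs(L - c_t) - r_t
    return np.min(g_Ms), np.min(g_Mt)


def peak_margins(L, Ms, Mt):
    """Peak form, as in constraints: Ms - max|S| and Mt - max|T|."""
    S = 1.0 / (1.0 + L)
    T = 1.0 - S
    return Ms - np.max(np.abs(S)), Mt - np.max(np.abs(T))


rng = np.random.default_rng(0)
Ms, Mt = 1.6, 1.3
for _ in range(200):
    # A few random loop-gain samples standing in for L(jw) on a grid.
    L = rng.normal(size=5) + 1j * rng.normal(size=5)
    circle = circle_margins(L, Ms, Mt)
    peak = peak_margins(L, Ms, Mt)
    # The two formulations agree on feasibility (sign), not on magnitude.
    assert (circle[0] >= 0) == (peak[0] >= 0)
    assert (circle[1] >= 0) == (peak[1] >= 0)
```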


CompositeTransform

Bases: Transform

A composite transformation that applies a list of transformations in sequence.

Source code in collimator/optimization/framework/base/transformations.py
class CompositeTransform(Transform):
    """
    A composite transformation that applies a list of transformations in sequence.
    """

    def __init__(self, transformations):
        self.transformations = transformations

    def transform(self, params: dict):
        for transformation in self.transformations:
            params = transformation.transform(params)
        return params

    def inverse_transform(self, params: dict):
        for transformation in reversed(self.transformations):
            params = transformation.inverse_transform(params)
        return params
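Order matters in `CompositeTransform`: `inverse_transform` must undo the last forward step first, which is why it iterates over `reversed(self.transformations)`. The sketch below re-declares a minimal `Transform` base and the composite so it runs without collimator; `LogTransform` and `ScaleTransform` are hypothetical examples, not collimator classes, and the base-class interface is assumed from the methods used above.

```python
import math


class Transform:
    """Stand-in for collimator's Transform base class (assumed interface)."""

    def transform(self, params: dict) -> dict:
        raise NotImplementedError

    def inverse_transform(self, params: dict) -> dict:
        raise NotImplementedError


class LogTransform(Transform):
    """Hypothetical: optimize log(p) instead of p for positive parameters."""

    def transform(self, params):
        return {k: math.log(v) for k, v in params.items()}

    def inverse_transform(self, params):
        return {k: math.exp(v) for k, v in params.items()}


class ScaleTransform(Transform):
    """Hypothetical: divide every parameter by a fixed scale."""

    def __init__(self, scale):
        self.scale = scale

    def transform(self, params):
        return {k: v / self.scale for k, v in params.items()}

    def inverse_transform(self, params):
        return {k: v * self.scale for k, v in params.items()}


class CompositeTransform(Transform):
    """Same logic as the listing: forward in order, inverse in reverse order."""

    def __init__(self, transformations):
        self.transformations = transformations

    def transform(self, params):
        for t in self.transformations:
            params = t.transform(params)
        return params

    def inverse_transform(self, params):
        for t in reversed(self.transformations):
            params = t.inverse_transform(params)
        return params


composite = CompositeTransform([LogTransform(), ScaleTransform(2.0)])
forward = composite.transform({"kp": math.e**4})  # log first, then divide by 2
recovered = composite.inverse_transform(forward)  # multiply by 2, then exp
```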

DistributionConfig dataclass

Structure of attributes for specifying the distributions of stochastic variables.

Source code in collimator/optimization/framework/base/optimizable.py
@dataclass
class DistributionConfig:
    """
    Structure of attributes for specifying distributions for stochastic variables
    """

    names: list[str]
    shapes: list[tuple]
    distributions: list[str]
    distributions_configs: list[dict]
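`DistributionConfig` describes the stochastic variables through four parallel lists, presumably matched by index (entry i of each list belongs to the same variable). The sketch below re-declares the dataclass so it runs standalone and fills it with hypothetical values; the distribution names (`"normal"`, `"uniform"`) and the config keys are placeholders, not values documented on this page.

```python
from dataclasses import dataclass


@dataclass
class DistributionConfig:
    """Copy of the dataclass above so the sketch runs without collimator."""

    names: list[str]
    shapes: list[tuple]
    distributions: list[str]
    distributions_configs: list[dict]


# Hypothetical example: two stochastic variables described by parallel lists.
# Distribution names and config keys are placeholders, not a documented API.
config = DistributionConfig(
    names=["mass", "friction"],
    shapes=[(), (3,)],
    distributions=["normal", "uniform"],
    distributions_configs=[
        {"mean": 1.0, "std": 0.1},
        {"min": 0.0, "max": 0.5},
    ],
)

# Walk the lists together, one variable at a time.
for name, shape, dist, cfg in zip(
    config.names, config.shapes, config.distributions, config.distributions_configs
):
    print(f"{name}: shape={shape}, {dist}({cfg})")
```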

Evosax

Bases: Optimizer

Population based global optimizers from Evosax.

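Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `optimizable` | `Optimizable` | The optimizable object. It must not have constraints; otherwise a `ValueError` is raised. | required |
| `opt_method` | `str` | The optimization method to use. See `evosax.Strategies` for available methods. | `'CMA_ES'` |
| `opt_method_config` | `dict` | Configuration for the optimization method, applied on top of the strategy's default parameters. | `None` |
| `pop_size` | `int` | The population size. | `10` |
| `num_generations` | `int` | The number of generations. | `100` |
| `print_every` | `int` | Log progress every `print_every` generations; `None` disables progress logging. | `1` |
| `metrics_writer` | `MetricsWriter` or `None` | Optional writer that records the best fitness after every generation (for example, to a CSV file). | `None` |
| `seed` | `int` | The random seed. If `None`, a seed is drawn with `np.random.randint`. | `None` |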
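When the optimizable does not provide `init_min_max_flat`, `Evosax.__init__` falls back through a few rules to choose the range from which the first population is sampled (separately from `bounds_flat`, which always becomes the clipping range `clip_min`/`clip_max`). The NumPy function below restates those rules for illustration; `init_range` is a hypothetical helper, not part of collimator, and `strategy_default` stands in for the strategy's own `(init_min, init_max)` defaults.

```python
import numpy as np


def init_range(params_0, bounds=None, init_min_max=None, strategy_default=(0.0, 0.0)):
    """Restatement of the initialization-range rules in Evosax.__init__.

    Returns (init_min, init_max) arrays, or None to keep the strategy defaults.
    """
    params_0 = np.asarray(params_0, dtype=float)
    if init_min_max is not None:
        # 1. Explicit initialization ranges take precedence.
        lo, hi = zip(*init_min_max)
        return np.array(lo), np.array(hi)
    if bounds is not None:
        # 2. Use the bounds, replacing infinite ends with -0.1 / 0.1.
        lo = [-0.1 if b[0] == -np.inf else b[0] for b in bounds]
        hi = [0.1 if b[1] == np.inf else b[1] for b in bounds]
        return np.array(lo), np.array(hi)
    if strategy_default == (0.0, 0.0):
        # 3. No usable defaults: scale the initial guess down/up by a factor of 10.
        factor = 10.0
        return params_0 / factor, params_0 * factor
    # 4. Otherwise keep the strategy's own non-zero defaults.
    return None
```

With rule 3, a negative entry p of the initial guess yields an inverted range, since p/10 > 10p for p < 0; supplying explicit initialization ranges avoids that case.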
Source code in collimator/optimization/framework/optimizers_evosax.py
class Evosax(Optimizer):
    """
    Population based global optimizers from Evosax.

    Parameters:
        optimizable (Optimizable):
            The optimizable object.
        opt_method (str):
            The optimization method to use. See `evosax.Strategies` for
            available methods.
        opt_method_config (dict):
            Configuration for the optimization method.
        pop_size (int):
            The population size.
        num_generations (int):
            The number of generations.
        print_every (int):
            Print progress every `print_every` generations.
        metrics_writer (MetricsWriter|None):
            Optional CSV file to write metrics to.
        seed (int):
            The random seed.
    """

    def __init__(
        self,
        optimizable: Optimizable,
        opt_method="CMA_ES",
        opt_method_config=None,
        pop_size=10,
        num_generations=100,
        print_every=1,
        seed=None,
        metrics_writer: MetricsWriter = None,
    ):
        self.optimizable = optimizable
        self.opt_method = opt_method
        self.pop_size = pop_size
        self.num_generations = num_generations
        self.print_every = print_every
        self.metrics_writer = metrics_writer
        self.optimal_params = None

        self.num_dims = optimizable.params_0_flat.size

        if self.optimizable.has_constraints:
            raise ValueError(
                f"Optimization method evosax:{self.opt_method} "
                "does not support constraints."
            )

        if opt_method not in evosax.Strategies:
            raise ValueError(f"Unknown optimization method: {opt_method}")

        if opt_method_config is None:
            opt_method_config = {}

        self.strategy = evosax.Strategies[opt_method](self.pop_size, self.num_dims)
        self.es_params = self.strategy.default_params.replace(**opt_method_config)

        # Create bounds
        if optimizable.bounds_flat is not None:
            lower_bounds, upper_bounds = zip(*optimizable.bounds_flat)
            lb = jnp.array(lower_bounds)
            ub = jnp.array(upper_bounds)
            self.es_params = self.es_params.replace(clip_min=lb, clip_max=ub)

        # Create initialization bounds
        if optimizable.init_min_max_flat is not None:
            init_min, init_max = zip(*optimizable.init_min_max_flat)
            imin = jnp.array(init_min)
            imax = jnp.array(init_max)
            self.es_params = self.es_params.replace(init_min=imin, init_max=imax)

        else:
            # if bounds are specified, unless they are infinity, we can use
            # them for initialization. If infinity, we initialize in [-0.1,0.1]
            if optimizable.bounds_flat is not None:
                bounds = [
                    (
                        -0.1 if b[0] == -jnp.inf else b[0],
                        0.1 if b[1] == jnp.inf else b[1],
                    )
                    for b in optimizable.bounds_flat
                ]
                lower_bounds, upper_bounds = zip(*bounds)
                lb = jnp.array(lower_bounds)
                ub = jnp.array(upper_bounds)
                self.es_params = self.es_params.replace(init_min=lb, init_max=ub)

            # if strategy defaults are not zero, they are likely set to sensible values,
            # so we use them, otherwise we scale the initial params by a factor of 10
            elif self.es_params.init_min == 0 and self.es_params.init_max == 0:
                factor = 10.0
                imin = jnp.full(self.num_dims, self.optimizable.params_0_flat / factor)
                imax = jnp.full(self.num_dims, self.optimizable.params_0_flat * factor)
                self.es_params = self.es_params.replace(init_min=imin, init_max=imax)

        self.key = jr.PRNGKey(
            np.random.randint(0, 2**32, dtype=np.int64) if seed is None else seed
        )

    def optimize(self):
        """Run optimization"""
        fitness_func = jax.jit(self.optimizable.batched_objective_flat)

        state = self.strategy.initialize(self.key, self.es_params)

        # https://github.com/RobertTLange/evosax/issues/45
        state = state.replace(best_fitness=jnp.finfo(jnp.float64).max)

        for gen in range(self.num_generations):
            self.key, subkey = jr.split(self.key)
            x, state = self.strategy.ask(subkey, state, self.es_params)
            fitness = fitness_func(x)
            state = self.strategy.tell(x, fitness, state, self.es_params)

            if self.print_every is not None and (gen + 1) % self.print_every == 0:
                logger.info(
                    "# Gen: %3d|Fitness: %.6f|Params: %s",
                    gen + 1,
                    state.best_fitness,
                    state.best_member,
                )
            if self.metrics_writer is not None:
                self.metrics_writer.write_metrics(best_fitness=state.best_fitness)

        params = state.best_member
        self.optimal_params = self.optimizable.unflatten_params(params)
        if self.optimizable.transformation is not None:
            self.optimal_params = self.optimizable.transformation.inverse_transform(
                self.optimal_params
            )
        return self.optimal_params
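The generation loop in `optimize` follows evosax's ask/evaluate/tell pattern: sample a population, score every member with one batched objective call, and feed the scores back to the strategy while tracking the best member. The sketch below reproduces that loop shape with a deliberately simple Gaussian evolution strategy in NumPy; `SimpleGaussianES` and `batched_objective` are hypothetical, and this is neither CMA-ES nor the evosax API.

```python
import numpy as np


class SimpleGaussianES:
    """Toy ask/tell strategy (not evosax, not CMA-ES): sample around a mean,
    move the mean to the average of the better half, and shrink the step size."""

    def __init__(self, pop_size, num_dims, clip_min=-np.inf, clip_max=np.inf):
        self.pop_size = pop_size
        self.num_dims = num_dims
        self.clip_min = clip_min
        self.clip_max = clip_max

    def initialize(self, rng, init_min, init_max):
        mean = rng.uniform(init_min, init_max, size=self.num_dims)
        # Like the listing, start best_fitness at the largest float64 so the
        # first evaluated generation always improves on it.
        return {
            "mean": mean,
            "sigma": 1.0,
            "best_member": mean.copy(),
            "best_fitness": np.finfo(np.float64).max,
        }

    def ask(self, rng, state):
        noise = rng.normal(size=(self.pop_size, self.num_dims))
        x = state["mean"] + state["sigma"] * noise
        return np.clip(x, self.clip_min, self.clip_max)  # respect the bounds

    def tell(self, x, fitness, state):
        order = np.argsort(fitness)  # lower fitness is better (minimization)
        elite = x[order[: self.pop_size // 2]]
        new_state = dict(state, mean=elite.mean(axis=0), sigma=0.95 * state["sigma"])
        if fitness[order[0]] < state["best_fitness"]:
            new_state["best_fitness"] = fitness[order[0]]
            new_state["best_member"] = x[order[0]].copy()
        return new_state


def batched_objective(x):
    """Hypothetical batched objective: one fitness value per row of x."""
    return np.sum((x - np.array([1.0, -2.0])) ** 2, axis=1)


rng = np.random.default_rng(0)
strategy = SimpleGaussianES(pop_size=20, num_dims=2, clip_min=-5.0, clip_max=5.0)
state = strategy.initialize(rng, init_min=-0.1, init_max=0.1)
for gen in range(100):
    x = strategy.ask(rng, state)  # sample a population
    fitness = batched_objective(x)  # evaluate every member in one call
    state = strategy.tell(x, fitness, state)  # update the search distribution
print(state["best_member"], state["best_fitness"])
```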

optimize()

Run optimization


IPOPT

Bases: Optimizer

Interior Point Optimizer (IPOPT) for optimization of the objective function with constraints.

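Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `optimizable` | `Optimizable` | The optimizable object. | required |
| `options` | `dict` | Options for the IPOPT solver. See https://coin-or.github.io/Ipopt/OPTIONS.html | `{"disp": 5}` |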

Source code in collimator/optimization/framework/optimizers_ipopt.py
class IPOPT(Optimizer):
    """
    Interior Point Optimizer (IPOPT) for optimization of the objective function
    with constraints.

    Parameters:
        optimizable (Optimizable):
            The optimizable object.
        options (dict):
            Options for the IPOPT solver.
            See https://coin-or.github.io/Ipopt/OPTIONS.html
    """

    def __init__(self, optimizable: Optimizable, options: dict = None):
        self.optimizable = optimizable
        # Use a fresh dict per instance instead of a shared mutable default.
        self.options = {"disp": 5} if options is None else options
        self.optimal_params = None

    def optimize(self):
        """Run optimization"""
        params = self.optimizable.params_0_flat
        objective = jax.jit(self.optimizable.objective_flat)
        gradient = jax.jit(jax.grad(objective))
        hessian = jax.jit(jax.hessian(objective))

        constraints = jax.jit(self.optimizable.constraints_flat)
        constraints_jac = jax.jit(jax.jacrev(constraints))
        constraints_hessian = jax.jit(jax.hessian(constraints))

        @jax.jit
        def constraints_hessian_vp(x, v):
            return jnp.sum(
                jnp.multiply(v[:, jnp.newaxis, jnp.newaxis], constraints_hessian(x)),
                axis=0,
            )

        constraints_ipopt = [
            {
                "type": "ineq",
                "fun": constraints,
                "jac": constraints_jac,
                "hess": constraints_hessian_vp,
            }
        ]

        # Handle bounds
        bounds = self.optimizable.bounds_flat

        # Jobs from the UI would put (-jnp.inf, jnp.inf) as default bounds. The user
        # may also have specified bounds this way. The IPOPT scipy interface expects
        # `None` to imply unboundedness.
        if bounds is not None:
            bounds = [
                (
                    None if b[0] == -jnp.inf else b[0],
                    None if b[1] == jnp.inf else b[1],
                )
                for b in bounds
            ]

            # Check if all bounds are None, i.e. no bounds at all, and hence
            # algorithms that do not support bounds can be used.
            flattened_bounds = [element for tup in bounds for element in tup]
            all_none = all(element is None for element in flattened_bounds)
            bounds = None if all_none else bounds

        res = cyipopt.minimize_ipopt(
            objective,
            x0=params,
            jac=gradient,
            hess=hessian,
            constraints=constraints_ipopt,
            bounds=bounds,
            options=self.options,
        )

        params = res.x

        self.optimal_params = self.optimizable.unflatten_params(params)
        if self.optimizable.transformation is not None:
            self.optimal_params = self.optimizable.transformation.inverse_transform(
                self.optimal_params
            )
        return self.optimal_params
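
The bound preprocessing in `IPOPT.optimize()` can be exercised on its own. The plain-Python sketch below mirrors that logic; `to_ipopt_bounds` is a hypothetical helper name, not part of collimator. Infinite endpoints become `None`, and a list in which every endpoint is `None` collapses to `None` for the whole problem.

```python
import math


def to_ipopt_bounds(bounds):
    """Hypothetical helper mirroring the bound handling in IPOPT.optimize().

    Infinite endpoints become None (the SciPy-style interface reads None as
    "unbounded"); if every endpoint is None, the whole list becomes None.
    """
    if bounds is None:
        return None
    converted = [
        (
            None if lo == -math.inf else lo,
            None if hi == math.inf else hi,
        )
        for lo, hi in bounds
    ]
    if all(v is None for pair in converted for v in pair):
        return None
    return converted


# A UI job that leaves every parameter at the default (-inf, inf):
print(to_ipopt_bounds([(-math.inf, math.inf), (-math.inf, math.inf)]))  # None
# A partially bounded problem keeps its finite endpoints:
print(to_ipopt_bounds([(0.0, math.inf), (-math.inf, 5.0)]))  # [(0.0, None), (None, 5.0)]
```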

optimize()

Run optimization
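
Judging from how `constraints_hessian_vp` is written, the `"hess"` entry of the constraint dictionary receives the point x and a vector v holding one multiplier per constraint, and returns the weighted sum of the constraint Hessians, sum_i v_i * H_i(x). The NumPy sketch below reproduces that broadcast-and-sum reduction for two hypothetical quadratic constraints whose Hessians are written out by hand; it illustrates the reduction only and is not cyipopt's or collimator's code.

```python
import numpy as np


def constraint_hessians(x):
    """Stacked Hessians, shape (m, n, n), for two hypothetical constraints.

    c1(x) = x0**2 + x1**2  ->  [[2, 0], [0, 2]]
    c2(x) = x0 * x1        ->  [[0, 1], [1, 0]]
    Both are constant in x because the constraints are quadratic.
    """
    return np.array([
        [[2.0, 0.0], [0.0, 2.0]],
        [[0.0, 1.0], [1.0, 0.0]],
    ])


def constraints_hessian_vp(x, v):
    # Weight each constraint's Hessian by its multiplier v_i, then sum over
    # the constraint axis, as the jnp.multiply/jnp.sum version does.
    return np.sum(v[:, np.newaxis, np.newaxis] * constraint_hessians(x), axis=0)


x = np.array([1.0, -2.0])
v = np.array([0.5, 3.0])
print(constraints_hessian_vp(x, v))
# [[1. 3.]
#  [3. 1.]]
```

The same reduction can be written as `np.einsum("i,ijk->jk", v, H)`; the broadcast form simply avoids the index string.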

IdentityTransform

Bases: Transform

A transformation that does nothing: y = x.

Source code in collimator/optimization/framework/base/transformations.py
class IdentityTransform(Transform):
    """
    A transformation that does nothing: ``` y = x ```.
    """

    def transform(self, params: dict):
        return params

    def inverse_transform(self, params: dict):
        return params

LogTransform

Bases: Transform

A transformation that applies the natural logarithm to the values of the parameters: y = log(x). The parameter values must be strictly positive; the inverse is x = exp(y).

Source code in collimator/optimization/framework/base/transformations.py
class LogTransform(Transform):
    """
    A transformation that applies the natural logarithm to the values of the parameters.
    ``` y = log(x) ```.
    """

    def transform(self, params: dict):
        return {k: jnp.log(v) for k, v in params.items()}

    def inverse_transform(self, params: dict):
        return {k: jnp.exp(v) for k, v in params.items()}
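
A minimal round-trip sketch of the log transform, using `math` in place of `jax.numpy`; `log_transform` and `log_inverse_transform` are hypothetical stand-ins for `LogTransform.transform` and `LogTransform.inverse_transform`. It also shows why the transform is useful: whatever unconstrained value the optimizer proposes in y-space, mapping it back gives a strictly positive parameter.

```python
import math


def log_transform(params):
    # y = log(x); only defined for strictly positive values
    return {k: math.log(v) for k, v in params.items()}


def log_inverse_transform(params):
    # x = exp(y); strictly positive for any real y
    return {k: math.exp(v) for k, v in params.items()}


params = {"kp": 2.0, "ki": 0.5}
y = log_transform(params)
x = log_inverse_transform(y)
print(all(math.isclose(x[k], params[k]) for k in params))  # True

# Any unconstrained value proposed by the optimizer maps back to a positive gain.
print(log_inverse_transform({"kp": -10.0})["kp"] > 0)  # True
```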

LogitTransform

Bases: Transform

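The logit transformation, defined as y = log(x / (1 - x)) for parameter values in the open interval (0, 1). Its inverse is the logistic sigmoid, x = 1 / (1 + exp(-y)).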

Source code in collimator/optimization/framework/base/transformations.py
class LogitTransform(Transform):
    """
    The logit transformation, defined as:
    ``` y = log(x / (1 - x)) ```
    """

    def transform(self, params: dict):
        return {k: jnp.log(v / (1.0 - v)) for k, v in params.items()}

    def inverse_transform(self, params: dict):
        return {k: 1.0 / (1.0 + jnp.exp(-v)) for k, v in params.items()}
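
The logit transform suits parameters confined to (0, 1), such as fractions: the optimizer works on the whole real line and `inverse_transform` maps the result back inside the interval. A scalar sketch with `math` instead of `jax.numpy`; `logit` and `sigmoid` are hypothetical helper names.

```python
import math


def logit(x):
    # Maps the open interval (0, 1) onto the whole real line
    return math.log(x / (1.0 - x))


def sigmoid(y):
    # Inverse of logit; always lands strictly inside (0, 1)
    return 1.0 / (1.0 + math.exp(-y))


for x in (0.1, 0.5, 0.9):
    assert math.isclose(sigmoid(logit(x)), x)

print(logit(0.5))  # 0.0
```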

NLopt

Bases: Optimizer

Optimizers using the NLopt library.

Parameters:

Name Type Description Default
optimizable Optimizable

The optimizable object.

required
opt_method str

The optimization method to use.

required
ftol_rel float

Relative tolerance on function value.

1e-06
ftol_abs float

Absolute tolerance on function value.

1e-06
xtol_rel float

Relative tolerance on optimization parameters.

1e-06
xtol_abs float

Absolute tolerance on optimization parameters.

1e-06
cons_tol float

Tolerance on constraints.

1e-06
maxeval int

Maximum number of function evaluations.

500
maxtime float

Maximum time in seconds. A non-positive value disables the time limit.

0
Source code in collimator/optimization/framework/optimizers_nlopt.py
class NLopt(Optimizer):
    """
    Optimizers using the NLopt library.

    Parameters:
        optimizable (Optimizable):
            The optimizable object.
        opt_method (str):
            The optimization method to use.
        ftol_rel (float):
            Relative tolerance on function value.
        ftol_abs (float):
            Absolute tolerance on function value.
        xtol_rel (float):
            Relative tolerance on optimization parameters.
        xtol_abs (float):
            Absolute tolerance on optimization parameters.
        cons_tol (float):
            Tolerance on constraints.
        maxeval (int):
            Maximum number of function evaluations.
        maxtime (float):
            Maximum time in seconds (a non-positive value disables the limit).
    """

    def __init__(
        self,
        optimizable: Optimizable,
        opt_method: str,
        ftol_rel=1e-06,
        ftol_abs=1e-06,
        xtol_rel=1e-06,
        xtol_abs=1e-06,
        cons_tol=1e-06,
        maxeval=500,
        maxtime=0,
    ):
        self.optimizable = optimizable
        self.opt_method = opt_method
        self.ftol_rel = ftol_rel
        self.ftol_abs = ftol_abs
        self.xtol_rel = xtol_rel
        self.xtol_abs = xtol_abs
        self.cons_tol = cons_tol
        self.maxeval = maxeval
        self.maxtime = maxtime
        self.optimal_params = None

    def optimize(self):
        """Run optimization"""
        params = self.optimizable.params_0_flat
        objective = jax.jit(self.optimizable.objective_flat)
        gradient = jax.jit(jax.grad(objective))

        constraints = jax.jit(self.optimizable.constraints_flat)
        constraints_jac = jax.jit(jax.jacrev(constraints))

        def nlopt_obj(x, grad):
            if grad.size > 0:
                grad[:] = gradient(jnp.array(x))
            return float(objective(jnp.array(x)))

        def nlopt_cons(result, x, grad):
            if grad.size > 0:
                grad[:, :] = -constraints_jac(jnp.array(x))
            result[:] = -constraints(jnp.array(x))

        if (
            self.optimizable.bounds_flat is not None
            and self.opt_method not in SUPPORTS_BOUNDS
        ):
            raise ValueError(
                f"Optimization method nlopt:{self.opt_method} does not support bounds."
            )

        if (
            self.optimizable.has_constraints
            and self.opt_method not in SUPPORTS_CONSTRAINTS
        ):
            raise ValueError(
                f"Optimization method nlopt:{self.opt_method} "
                "does not support constraints."
            )

        if self.opt_method not in ALL_METHODS:
            raise ValueError(
                f"Optimization method nlopt:{self.opt_method} is not supported."
            )

        # Initialize nlopt optimizer
        opt_method = ALL_METHODS[self.opt_method]()
        opt = nlopt.opt(opt_method, len(params))

        # Set the objective function
        opt.set_min_objective(nlopt_obj)

        # Set the constraints
        if self.optimizable.has_constraints:
            num_constraints = self.optimizable.constraints_flat(jnp.array(params)).size
            opt.add_inequality_mconstraint(
                nlopt_cons, [self.cons_tol] * num_constraints
            )

        # Set the bounds
        if self.optimizable.bounds_flat is not None:
            lower_bounds, upper_bounds = zip(*self.optimizable.bounds_flat)
            opt.set_lower_bounds(lower_bounds)
            opt.set_upper_bounds(upper_bounds)

        # Set stopping criteria
        opt.set_ftol_rel(self.ftol_rel)
        opt.set_ftol_abs(self.ftol_abs)
        opt.set_xtol_rel(self.xtol_rel)
        opt.set_xtol_abs(self.xtol_abs)
        opt.set_maxeval(self.maxeval)
        opt.set_maxtime(self.maxtime)

        # Run the optimization
        params = opt.optimize(params)

        self.optimal_params = self.optimizable.unflatten_params(params)
        if self.optimizable.transformation is not None:
            self.optimal_params = self.optimizable.transformation.inverse_transform(
                self.optimal_params
            )
        return self.optimal_params
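
Judging from the IPOPT wrapper, which passes `constraints_flat` as a SciPy-style `"ineq"` constraint, a constraint g(x) is satisfied when g(x) >= 0. NLopt's `add_inequality_mconstraint` instead expects c(x) <= 0, which is why `nlopt_cons` negates both the values and the Jacobian. The NumPy sketch below simulates the callback with two hypothetical linear constraints; no `nlopt` import is needed because the arrays NLopt would pass are built by hand.

```python
import numpy as np


def g(x):
    # Hypothetical constraints in ">= 0 means feasible" form:
    #   g0: x0 + x1 - 1 >= 0
    #   g1: 2 - x0      >= 0
    return np.array([x[0] + x[1] - 1.0, 2.0 - x[0]])


def g_jac(x):
    # Jacobian of g, shape (m, n) = (2, 2)
    return np.array([[1.0, 1.0], [-1.0, 0.0]])


def nlopt_cons(result, x, grad):
    # NLopt wants c(x) <= 0, so c = -g and dc/dx = -dg/dx.
    # Derivative-free algorithms receive an empty grad array.
    if grad.size > 0:
        grad[:, :] = -g_jac(x)
    result[:] = -g(x)


# Simulate the call NLopt would make for a gradient-based algorithm.
x = np.array([0.5, 0.8])
result = np.empty(2)
grad = np.empty((2, 2))
nlopt_cons(result, x, grad)
print(np.all(result <= 0))  # True: x is feasible
```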

optimize()

Run optimization

NegativeNegativeLogTransform

Bases: Transform

A transformation that applies the negative of the natural logarithm to the negated parameter values: y = -log(-x). The parameter values must be strictly negative; the inverse is x = -exp(-y).

Source code in collimator/optimization/framework/base/transformations.py
class NegativeNegativeLogTransform(Transform):
    """
    A transformation that applies the negative of the natural logarithm of the negative
    of the values of the parameters.
    ``` y = -log(-x) ```
    """

    def transform(self, params: dict):
        return {k: -jnp.log(-v) for k, v in params.items()}

    def inverse_transform(self, params: dict):
        return {k: -jnp.exp(-v) for k, v in params.items()}
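
This transform plays the same role as the log transform for parameters that must stay strictly negative. A scalar round-trip sketch with `math` instead of `jax.numpy`; `neg_neg_log` and `neg_neg_log_inverse` are hypothetical helper names.

```python
import math


def neg_neg_log(x):
    # y = -log(-x); defined only for x < 0
    return -math.log(-x)


def neg_neg_log_inverse(y):
    # x = -exp(-y); strictly negative for any real y
    return -math.exp(-y)


for x in (-0.01, -1.0, -250.0):
    assert math.isclose(neg_neg_log_inverse(neg_neg_log(x)), x)

# Any unconstrained y maps back to a negative parameter value.
print(neg_neg_log_inverse(3.0) < 0)  # True
```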

NormalizeTransform

Bases: Transform

A transformation that normalizes the values of the parameters to the range [0, 1] (for values within [min, max]): y = (x - min) / (max - min).

Parameters:

- params_min: dict with the minimum value for each parameter.
- params_max: dict with the maximum value for each parameter.

Source code in collimator/optimization/framework/base/transformations.py
class NormalizeTransform(Transform):
    """
    A transformation that normalizes the values of the parameters to the range [0, 1].
    ``` y = (x - min) / (max - min) ```
    Parameters:
        - params_min: dict with the minimum values for each parameter.
        - params_max: dict with the maximum values for each parameter.
    """

    def __init__(self, params_min: dict, params_max: dict):
        self.params_min = params_min
        self.params_max = params_max

    def transform(self, params: dict):
        return {
            k: (v - self.params_min[k]) / (self.params_max[k] - self.params_min[k])
            for k, v in params.items()
        }

    def inverse_transform(self, params: dict):
        return {
            k: v * (self.params_max[k] - self.params_min[k]) + self.params_min[k]
            for k, v in params.items()
        }
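
A plain-Python round trip through the normalization, with hypothetical helper names `normalize` and `denormalize` mirroring `transform` and `inverse_transform`. Note that, as in the class above, nothing is clipped: a value outside [min, max] maps outside [0, 1], and equal min and max would divide by zero.

```python
def normalize(params, params_min, params_max):
    # y = (x - min) / (max - min), per parameter
    return {
        k: (v - params_min[k]) / (params_max[k] - params_min[k])
        for k, v in params.items()
    }


def denormalize(params, params_min, params_max):
    # x = y * (max - min) + min, per parameter
    return {
        k: v * (params_max[k] - params_min[k]) + params_min[k]
        for k, v in params.items()
    }


params_min = {"kp": 0.0, "kd": -1.0}
params_max = {"kp": 10.0, "kd": 1.0}
y = normalize({"kp": 2.5, "kd": 0.0}, params_min, params_max)
print(y)  # {'kp': 0.25, 'kd': 0.5}
print(denormalize(y, params_min, params_max))  # {'kp': 2.5, 'kd': 0.0}
```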

Optax

Bases: Optimizer

Optax optimizer without support for stochastic variables.

Parameters:

optimizable (Optimizable): The optimizable object.
opt_method (str): The optimization method to use.
learning_rate (float): The learning rate.
opt_method_config (dict): Configuration for the optimization method.
num_epochs (int): The number of epochs.
clip_range (tuple): The (min, max) range to which the gradients are clipped elementwise.
print_every (int): Print progress every print_every epochs.
metrics_writer (MetricsWriter|None): Optional CSV file to write metrics to.

Source code in collimator/optimization/framework/optimizers_optax.py
class Optax(Optimizer):
    """
    Optax optimizer without support for stochastic variables.

    Parameters:
        optimizable (Optimizable):
            The optimizable object.
        opt_method (str):
            The optimization method to use.
        learning_rate (float):
            The learning rate.
        opt_method_config (dict):
            Configuration for the optimization method.
        num_epochs (int):
            The number of epochs.
        clip_range (tuple):
            The range to clip the gradients.
        print_every (int):
            Print progress every `print_every` epochs.
        metrics_writer (MetricsWriter|None):
            Optional CSV file to write metrics to.
    """

    def __init__(
        self,
        optimizable: Optimizable,
        opt_method,
        learning_rate,
        opt_method_config,
        num_epochs=100,
        clip_range=None,
        print_every=None,
        metrics_writer: MetricsWriter = None,
    ):
        self.optimizable = optimizable
        self.opt_method = opt_method
        self.num_epochs = num_epochs
        self.clip_range = clip_range
        self.print_every = print_every
        self.metrics_writer = metrics_writer
        self.optimal_params = None
        self.losses = []

        if self.optimizable.bounds_flat is not None:
            # Jobs from the UI would put (-jnp.inf, jnp.inf) as default bounds. The user
            # may also have specified bounds this way.
            bounds = [
                (
                    None if b[0] == -jnp.inf else b[0],
                    None if b[1] == jnp.inf else b[1],
                )
                for b in self.optimizable.bounds_flat
            ]

            # Check if all bounds are None, i.e. no bounds at all, and hence Optax
            # algorithms which don't natively support bounds can be used
            flattened_bounds = [element for tup in bounds for element in tup]
            all_none = all(element is None for element in flattened_bounds)

            if not all_none:
                raise ValueError(
                    f"Optimization method {opt_method} does not support bounds."
                )

        if self.optimizable.has_constraints:
            raise ValueError(
                f"Optimization method optax:{self.opt_method} "
                "does not support constraints."
            )

        opt_func = getattr(optax, opt_method, None)
        if opt_func is None:
            raise ValueError(f"Unknown optax optimizer: {opt_method}")

        # Instantiate the optimizer with validated config
        valid_opts = _remap_and_filter_valid_params(opt_func, opt_method_config)
        self.optimizer = opt_func(learning_rate, **valid_opts)

    @partial(jax.jit, static_argnums=(0,))
    def step(self, params, opt_state):
        """Take a single optimization step"""
        loss, grads = jax.value_and_grad(self.optimizable.objective_flat)(params)
        grads = jnp.clip(grads, *self.clip_range) if self.clip_range else grads
        updates, opt_state = self.optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    def optimize(self) -> dict[str, ArrayLike]:
        """Run optimization"""
        params = self.optimizable.params_0_flat
        opt_state = self.optimizer.init(params)

        for epoch in range(self.num_epochs):
            params, opt_state, loss = self.step(params, opt_state)
            self.losses.append(jnp.mean(loss))
            if self.print_every and epoch % self.print_every == 0:
                p: dict = self.optimizable.unflatten_params(params)
                if self.optimizable.transformation is not None:
                    p = self.optimizable.transformation.inverse_transform(p)
                p = {k: v.tolist() for k, v in p.items()}
                logger.info(
                    "Epoch %s, loss: %s", epoch, jnp.mean(loss), **logdata(params=p)
                )
            if self.metrics_writer:
                self.metrics_writer.write_metrics(loss=self.losses[-1])

        self.optimal_params = self.optimizable.unflatten_params(params)
        if self.optimizable.transformation is not None:
            self.optimal_params = self.optimizable.transformation.inverse_transform(
                self.optimal_params
            )
        return self.optimal_params

    @property
    def metrics(self):
        return {"loss": self.losses}
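
Since `Optax` rejects bounds and constraints, each epoch is a single unconstrained step: evaluate loss and gradient, optionally clip the gradient elementwise to `clip_range`, and apply the optimizer update. The NumPy sketch below mimics that loop for plain gradient descent (roughly what `opt_method="sgd"` would select, without momentum); the quadratic objective, its hand-written gradient (standing in for `jax.value_and_grad`), and the `step` function are hypothetical.

```python
import numpy as np

TARGET = np.array([3.0, -1.0])


def objective(params):
    # Hypothetical quadratic with its minimum at TARGET
    return float(np.sum((params - TARGET) ** 2))


def gradient(params):
    return 2.0 * (params - TARGET)


def step(params, learning_rate=0.1, clip_range=(-1.0, 1.0)):
    loss, grads = objective(params), gradient(params)
    # Elementwise clipping, like jnp.clip(grads, *self.clip_range)
    if clip_range:
        grads = np.clip(grads, *clip_range)
    # Plain gradient-descent update: params <- params - lr * grads
    params = params - learning_rate * grads
    return params, loss


params = np.array([0.0, 0.0])
losses = []
for epoch in range(200):
    params, loss = step(params)
    losses.append(loss)

print(np.allclose(params, TARGET))  # True
```

Clipping caps each coordinate's step at `learning_rate * 1.0` while far from the minimum, so early progress is linear rather than proportional to the gradient.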

optimize()

Run optimization

step(params, opt_state)

Take a single optimization step

OptaxWithStochasticVars

Bases: Optimizer

Optax optimizer with support for stochastic variables.

Parameters:

Name Type Description Default
optimizable OptimizableWithStochasticVars

The optimizable object.

required
opt_method str

The optimization method to use.

required
learning_rate float

The learning rate.

required
opt_method_config dict

Configuration for the optimization method.

required
num_epochs int

The number of epochs.

100
batch_size int

The batch size.

1
num_batches int

The number of batches.

1
clip_range tuple

The range to clip the gradients.

None
print_every int

Print progress every print_every epochs.

None
metrics_writer MetricsWriter | None

Optional CSV file to write metrics to.

None
Source code in collimator/optimization/framework/optimizers_optax.py
class OptaxWithStochasticVars(Optimizer):
    """
    Optax optimizer with support for stochastic variables.

    Parameters:
        optimizable (OptimizableWithStochasticVars):
            The optimizable object.
        opt_method (str):
            The optimization method to use.
        learning_rate (float):
            The learning rate.
        opt_method_config (dict):
            Configuration for the optimization method.
        num_epochs (int):
            The number of epochs.
        batch_size (int):
            The batch size.
        num_batches (int):
            The number of batches.
        clip_range (tuple):
            The range to clip the gradients.
        print_every (int):
            Print progress every `print_every` epochs.
        metrics_writer (MetricsWriter|None):
            Optional CSV file to write metrics to.
    """

    def __init__(
        self,
        optimizable: OptimizableWithStochasticVars,
        opt_method: str,
        learning_rate,
        opt_method_config,
        num_epochs=100,
        batch_size=1,
        num_batches=1,
        clip_range=None,
        print_every=None,
        metrics_writer: MetricsWriter = None,
    ):
        self.optimizable = optimizable
        self.opt_method = opt_method
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.clip_range = clip_range
        self.print_every = print_every
        self.metrics_writer = metrics_writer
        self.optimal_params = None
        self.losses = []

        if optimizable.bounds_flat is not None:
            # Jobs from the UI would put (-jnp.inf, jnp.inf) as default bounds. The user
            # may also have specified bounds this way.
            bounds = [
                (
                    None if b[0] == -jnp.inf else b[0],
                    None if b[1] == jnp.inf else b[1],
                )
                for b in self.optimizable.bounds_flat
            ]

            # Check if all bounds are None, i.e. no bounds at all, and hence Optax
            # algorithms which don't natively support bounds can be used
            flattened_bounds = [element for tup in bounds for element in tup]
            all_none = all(element is None for element in flattened_bounds)

            if not all_none:
                raise ValueError(
                    f"Optimization method {opt_method} does not support bounds."
                )

        opt_func = getattr(optax, opt_method, None)
        if opt_func is None:
            raise ValueError(f"Unknown optax optimizer: {opt_method}")

        # Instantiate the optimizer with validated config
        valid_opts = _remap_and_filter_valid_params(opt_func, opt_method_config)
        self.optimizer = opt_func(learning_rate, **valid_opts)

    def batched_objective_flat(self, params, stochastic_vars_batch_flat):
        """Mean of the objective function over a batch"""
        return jnp.mean(
            self.optimizable.batched_objective_flat(params, stochastic_vars_batch_flat)
        )

    @partial(jax.jit, static_argnums=(0,))
    def step(self, params, opt_state, stochastic_vars_batch):
        """Take a single optimization step over one batch"""
        batch_loss, grads = jax.value_and_grad(self.batched_objective_flat)(
            params, stochastic_vars_batch
        )

        grads = jnp.clip(grads, *self.clip_range) if self.clip_range else grads

        updates, opt_state = self.optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, batch_loss

    def optimize(self):
        """Run optimization"""
        params = self.optimizable.params_0_flat
        opt_state = self.optimizer.init(params)

        if self.num_batches * self.batch_size == 1:
            # don't randomize over stochastic variables; use a single
            # batch of size 1 with initial stochastic variables
            data_flat, _ = ravel_pytree(self.optimizable.vars_0)
            stochastic_vars_training_data_flat = data_flat[None, None, :]
        else:
            _, stochastic_vars_training_data_flat = self.optimizable.sample_random_vars(
                self.num_batches * self.batch_size
            )

        @jax.jit
        def _scan_fun(carry, stochastic_vars_batch):
            params, opt_state = carry
            params, opt_state, batch_loss = self.step(
                params, opt_state, stochastic_vars_batch
            )
            return (params, opt_state), batch_loss

        for epoch in range(self.num_epochs):
            if self.num_batches * self.batch_size == 1:
                stochastic_vars_batches = stochastic_vars_training_data_flat
            else:
                stochastic_vars_batches = self.optimizable.generate_batches(
                    stochastic_vars_training_data_flat,
                    self.num_batches,
                    self.batch_size,
                )
            (params, opt_state), batch_losses = lax.scan(
                _scan_fun, (params, opt_state), stochastic_vars_batches
            )

            self.losses.append(jnp.mean(batch_losses))
            if self.print_every and epoch % self.print_every == 0:
                p: dict = self.optimizable.unflatten_params(params)
                if self.optimizable.transformation is not None:
                    p = self.optimizable.transformation.inverse_transform(p)
                p = {k: v.tolist() for k, v in p.items()}
                logger.info(
                    "Epoch %s, average batch loss: %s",
                    epoch,
                    jnp.mean(batch_losses),
                    **logdata(params=p),
                )
            if self.metrics_writer is not None:
                self.metrics_writer.write_metrics(loss=self.losses[-1])

        self.optimal_params = self.optimizable.unflatten_params(params)
        if self.optimizable.transformation is not None:
            self.optimal_params = self.optimizable.transformation.inverse_transform(
                self.optimal_params
            )
        return self.optimal_params

    @property
    def metrics(self):
        return {"loss": self.losses}
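
In `OptaxWithStochasticVars.optimize()`, the stochastic variables are sampled once, regrouped each epoch into an array whose leading axis indexes batches (the single-sample fallback builds shape (1, 1, n)), and `lax.scan` applies one update per batch; the epoch loss is the mean of the batch losses. The NumPy sketch below imitates that flow with a sequential loop in place of `lax.scan`. The objective (p - s)**2, the shuffling, and all names are hypothetical; in the real code the regrouping is delegated to `optimizable.generate_batches`, whose exact behavior is not shown here.

```python
import numpy as np

rng = np.random.default_rng(0)

num_batches, batch_size, n_vars = 4, 8, 1
# Sampled once, before the epoch loop, like sample_random_vars(...)
samples = rng.normal(loc=2.0, scale=0.5, size=(num_batches * batch_size, n_vars))


def batched_objective(p, batch):
    # Hypothetical per-sample objective (p - s)**2, averaged over the batch
    return np.mean((p - batch[:, 0]) ** 2)


def batched_grad(p, batch):
    return np.mean(2.0 * (p - batch[:, 0]))


def scan_step(carry, batch, lr=0.1):
    # Plays the role of _scan_fun: update params on one batch, emit its loss
    p = carry
    loss = batched_objective(p, batch)
    p = p - lr * batched_grad(p, batch)
    return p, loss


p = 0.0
epoch_losses = []
for epoch in range(50):
    # Shape (num_batches, batch_size, n_vars); the leading axis is scanned over
    batches = rng.permutation(samples).reshape(num_batches, batch_size, n_vars)
    batch_losses = []
    for batch in batches:  # sequential stand-in for lax.scan
        p, loss = scan_step(p, batch)
        batch_losses.append(loss)
    epoch_losses.append(np.mean(batch_losses))

# p settles near the sample mean, the minimizer of the averaged objective.
print(abs(p - samples.mean()))
```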

batched_objective_flat(params, stochastic_vars_batch_flat)

Mean of the objective function over a batch

optimize()

Run optimization

Source code in collimator/optimization/framework/optimizers_optax.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
def optimize(self):
    """Run optimization"""
    params = self.optimizable.params_0_flat
    opt_state = self.optimizer.init(params)

    if self.num_batches * self.batch_size == 1:
        # don't randomize over stochastic variables; use a single
        # batch of size 1 with initial stochastic variables
        data_flat, _ = ravel_pytree(self.optimizable.vars_0)
        stochastic_vars_training_data_flat = data_flat[None, None, :]
    else:
        _, stochastic_vars_training_data_flat = self.optimizable.sample_random_vars(
            self.num_batches * self.batch_size
        )

    @jax.jit
    def _scan_fun(carry, stochastic_vars_batch):
        params, opt_state = carry
        params, opt_state, batch_loss = self.step(
            params, opt_state, stochastic_vars_batch
        )
        return (params, opt_state), batch_loss

    for epoch in range(self.num_epochs):
        if self.num_batches * self.batch_size == 1:
            stochastic_vars_batches = stochastic_vars_training_data_flat
        else:
            stochastic_vars_batches = self.optimizable.generate_batches(
                stochastic_vars_training_data_flat,
                self.num_batches,
                self.batch_size,
            )
        (params, opt_state), batch_losses = lax.scan(
            _scan_fun, (params, opt_state), stochastic_vars_batches
        )

        self.losses.append(jnp.mean(batch_losses))
        if self.print_every and epoch % self.print_every == 0:
            p: dict = self.optimizable.unflatten_params(params)
            if self.optimizable.transformation is not None:
                p = self.optimizable.transformation.inverse_transform(p)
            p = {k: v.tolist() for k, v in p.items()}
            logger.info(
                "Epoch %s, average batch loss: %s",
                epoch,
                jnp.mean(batch_losses),
                **logdata(params=p),
            )
        if self.metrics_writer is not None:
            self.metrics_writer.write_metrics(loss=self.losses[-1])

    self.optimal_params = self.optimizable.unflatten_params(params)
    if self.optimizable.transformation is not None:
        self.optimal_params = self.optimizable.transformation.inverse_transform(
            self.optimal_params
        )
    return self.optimal_params
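The control flow of `optimize()` can be sketched without JAX. The sketch assumes a quadratic toy loss with an analytic gradient, and uses plain gradient descent in place of an Optax optimizer. Samples are drawn once. Each epoch then resamples `num_batches` batches with replacement, takes one step per batch, and records the epoch's mean batch loss. All names are hypothetical. In the library, the inner loop is a `lax.scan` over `self.step`.

```python
import numpy as np

def loss_and_grad(params, batch):
    # Toy loss: mean squared distance between params and each sample in the batch.
    diff = params[None, :] - batch            # shape (batch_size, n)
    loss = float(np.mean(np.sum(diff ** 2, axis=1)))
    grad = 2.0 * np.mean(diff, axis=0)        # gradient of the mean loss w.r.t. params
    return loss, grad

def run_epochs(params, data, num_epochs, num_batches, batch_size, lr=0.1, seed=0):
    rng = np.random.default_rng(seed)
    losses = []
    for _ in range(num_epochs):
        # Resample batch indices every epoch, with replacement.
        idx = rng.choice(data.shape[0], size=(num_batches, batch_size), replace=True)
        batch_losses = []
        for batch in data[idx]:               # one optimizer step per batch
            loss, grad = loss_and_grad(params, batch)
            params = params - lr * grad
            batch_losses.append(loss)
        losses.append(float(np.mean(batch_losses)))  # average batch loss per epoch
    return params, losses

rng = np.random.default_rng(42)
data = rng.normal(loc=3.0, scale=0.1, size=(200, 2))  # samples centred near (3, 3)
params, losses = run_epochs(np.zeros(2), data, num_epochs=20, num_batches=5, batch_size=16)
print(params, losses[0], losses[-1])
```

As in the library, a loss is appended once per epoch rather than once per batch, so `losses` has `num_epochs` entries.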

step(params, opt_state, stochastic_vars_batch)

Take a single optimization step over one batch

Source code in collimator/optimization/framework/optimizers_optax.py
@partial(jax.jit, static_argnums=(0,))
def step(self, params, opt_state, stochastic_vars_batch):
    """Take a single optimization step over one batch"""
    batch_loss, grads = jax.value_and_grad(self.batched_objective_flat)(
        params, stochastic_vars_batch
    )

    grads = jnp.clip(grads, *self.clip_range) if self.clip_range else grads

    updates, opt_state = self.optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, batch_loss
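When `clip_range` is set, `step` clips each gradient component to that range before passing the gradients to the Optax update. Below is a NumPy sketch of the same idea. It assumes plain SGD with a fixed learning rate instead of an Optax transform. `clipped_sgd_step` is a hypothetical name.

```python
import numpy as np

def clipped_sgd_step(params, grads, lr, clip_range=None):
    # Element-wise clipping, mirroring jnp.clip(grads, *clip_range) in `step`.
    if clip_range:
        grads = np.clip(grads, *clip_range)
    # Plain gradient-descent update; Optax would compute `updates` here instead.
    return params - lr * grads

params = np.array([0.0, 0.0, 0.0])
grads = np.array([10.0, -0.5, -20.0])
new_params = clipped_sgd_step(params, grads, lr=0.1, clip_range=(-1.0, 1.0))
# Clipped grads are [1.0, -0.5, -1.0], so new_params holds -0.1, 0.05, 0.1.
print(new_params)
```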

Optimizable

Bases: OptimizableBase

Base class for all optimizables with no stochastic variables.

For parameters, see OptimizableBase.

The abstract method prepare_context should update the context to incorporate the optimization parameters.

This class creates methods for evaluating the objective and constraints from the concrete implementations of the abstract methods. It also creates methods for batched evaluation of the objective and constraints, which are useful for optimizers that work with batches (e.g., Optax) and for population-based optimizers.

Source code in collimator/optimization/framework/base/optimizable.py
class Optimizable(OptimizableBase):
    """
    Base class for all optimizables with no stochastic variables.

    For parameters, see `OptimizableBase`.

    The abstract method `prepare_context` should update the context to incorporate the
    optimization parameters.

    This class creates methods for evaluation of the objective and constraints from the
    concrete implementation of the abstract methods. This class also creates methods for
    batched evaluation of the objective and constraints, which are useful for optimizers
    that can work with batches (e.g. Optax), and population-based optimizers.
    """

    def __init__(
        self,
        diagram,
        base_context,
        sim_t_span=(0.0, 1.0),
        params_0=None,
        bounds=None,
        transformation=None,
        init_min_max=None,
        seed=None,
    ):
        super().__init__(
            diagram,
            base_context,
            sim_t_span,
            params_0,
            bounds,
            transformation,
            init_min_max,
            seed,
        )
        self.batched_objective = jax.jit(jax.vmap(self.objective, in_axes=(0,)))
        self.batched_objective_flat = jax.jit(
            jax.vmap(self.objective_flat, in_axes=(0,))
        )

        self.batched_constraints = jax.jit(jax.vmap(self.constraints, in_axes=(0,)))
        self.batched_constraints_flat = jax.jit(
            jax.vmap(self.constraints_flat, in_axes=(0,))
        )

    @abstractmethod
    def prepare_context(self, context, params: dict):
        """
        Model-specific updates to incorporate the sample data and parameters.
        Return the updated context.
        """
        pass

    def run_simulation(self, params: dict):
        """
        Run simulation and return final results context.
        """
        context = self.base_context.with_time(self.start_time)
        context = self.prepare_context(context, params)
        results = self.simulator.advance_to(self.stop_time, context)
        return results.context

    def objective_flat(self, params: Array):
        """Objective function for optimization with flattened parameters input"""
        return self.objective(self.unflatten_params(jnp.atleast_1d(params)))

    def objective(self, params: dict):
        """Objective function for optimization with dict parameters input"""
        if self.transformation is not None:
            params = self.transformation.inverse_transform(params)
        results_context = self.run_simulation(params)
        return self.objective_from_context(results_context)

    def constraints_flat(self, params: Array):
        """Constraints function for optimization with flattened parameters input"""
        return self.constraints(self.unflatten_params(jnp.atleast_1d(params)))

    def constraints(self, params: dict):
        """Constraints function for optimization with dict parameters input"""
        if self.transformation is not None:
            params = self.transformation.inverse_transform(params)
        results_context = self.run_simulation(params)
        return self.constraints_from_context(results_context)
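`Optimizable` divides the work between subclass and base class. Subclasses supply `prepare_context`, and the base class turns it into `objective` and `constraints`. The self-contained toy below illustrates that split. A dict stands in for the context, and a closed-form "simulation" stands in for `simulator.advance_to`. Everything here (`ToyOptimizableBase`, `DecayModel`, the context keys) is hypothetical and is not the collimator API.

```python
import math
from abc import ABC, abstractmethod

class ToyOptimizableBase(ABC):
    """Hypothetical stand-in for Optimizable: subclasses only say how params enter the context."""

    def __init__(self, base_context, stop_time):
        self.base_context = base_context
        self.stop_time = stop_time

    @abstractmethod
    def prepare_context(self, context, params):
        ...

    @abstractmethod
    def objective_from_context(self, context):
        ...

    def run_simulation(self, params):
        # Copy the base context so repeated evaluations never mutate it.
        context = self.prepare_context(dict(self.base_context), params)
        # "Simulate" exponential decay x' = -k x in closed form up to stop_time.
        context["x"] = context["x0"] * math.exp(-context["k"] * self.stop_time)
        return context

    def objective(self, params):
        return self.objective_from_context(self.run_simulation(params))

class DecayModel(ToyOptimizableBase):
    def prepare_context(self, context, params):
        context["k"] = params["k"]  # inject the decision variable into the context
        return context

    def objective_from_context(self, context):
        return (context["x"] - 0.5) ** 2  # target: x(stop_time) == 0.5

model = DecayModel(base_context={"x0": 1.0, "k": 0.0}, stop_time=1.0)
print(model.objective({"k": math.log(2.0)}))  # x(1) = 0.5, so the loss is ~0
```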

constraints(params)

Constraints function for optimization with dict parameters input

Source code in collimator/optimization/framework/base/optimizable.py
def constraints(self, params: dict):
    """Constraints function for optimization with dict parameters input"""
    if self.transformation is not None:
        params = self.transformation.inverse_transform(params)
    results_context = self.run_simulation(params)
    return self.constraints_from_context(results_context)

constraints_flat(params)

Constraints function for optimization with flattened parameters input

Source code in collimator/optimization/framework/base/optimizable.py
def constraints_flat(self, params: Array):
    """Constraints function for optimization with flattened parameters input"""
    return self.constraints(self.unflatten_params(jnp.atleast_1d(params)))

objective(params)

Objective function for optimization with dict parameters input

Source code in collimator/optimization/framework/base/optimizable.py
def objective(self, params: dict):
    """Objective function for optimization with dict parameters input"""
    if self.transformation is not None:
        params = self.transformation.inverse_transform(params)
    results_context = self.run_simulation(params)
    return self.objective_from_context(results_context)
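When a `transformation` is set, the optimizer works in a transformed space. `objective` maps the parameters back with `inverse_transform` before simulating. Below is a minimal sketch of that round trip. It assumes a hypothetical log transform, which keeps a positive parameter positive while the optimizer searches over an unconstrained value.

```python
import numpy as np

class LogTransform:
    """Hypothetical transformation: optimize log(p) instead of p, for positive parameters."""

    def transform(self, params):
        return {k: np.log(v) for k, v in params.items()}

    def inverse_transform(self, params):
        return {k: np.exp(v) for k, v in params.items()}

def objective(params, transformation=None):
    # Mirrors the first step of Optimizable.objective: undo the transformation, then evaluate.
    if transformation is not None:
        params = transformation.inverse_transform(params)
    return float((params["gain"] - 4.0) ** 2)  # toy loss in the physical space

t = LogTransform()
z = t.transform({"gain": 4.0})     # what the optimizer would see: log(4)
print(z["gain"], objective(z, t))  # loss is ~0 after mapping back
```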

objective_flat(params)

Objective function for optimization with flattened parameters input

Source code in collimator/optimization/framework/base/optimizable.py
def objective_flat(self, params: Array):
    """Objective function for optimization with flattened parameters input"""
    return self.objective(self.unflatten_params(jnp.atleast_1d(params)))
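The `_flat` variants exist because SciPy and Optax operate on a single flat vector, while `prepare_context` is easier to write against a dict of named parameters. The library uses `jax.flatten_util.ravel_pytree` for the conversion. The sketch below reproduces the idea for a flat dict of NumPy arrays only. `ravel_dict` is a hypothetical helper, and unlike `ravel_pytree` it does not handle nested pytrees.

```python
import numpy as np

def ravel_dict(params):
    # Concatenate every array in key-sorted order and remember how to undo it.
    keys = sorted(params)
    shapes = [np.shape(params[k]) for k in keys]
    flat = np.concatenate([np.ravel(params[k]) for k in keys])

    def unravel(vec):
        out, start = {}, 0
        for k, shape in zip(keys, shapes):
            size = int(np.prod(shape))  # np.prod(()) == 1, so scalars take one slot
            out[k] = np.reshape(vec[start:start + size], shape)
            start += size
        return out

    return flat, unravel

params = {"kp": 1.0, "ki": np.array([10.0, 0.5])}
flat, unravel = ravel_dict(params)
print(flat)                 # values 10.0, 0.5, 1.0 ("ki" sorts before "kp")
print(unravel(flat)["kp"])  # scalar 1.0 restored
```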

prepare_context(context, params) abstractmethod

Model-specific updates to incorporate the sample data and parameters. Return the updated context.

Source code in collimator/optimization/framework/base/optimizable.py
@abstractmethod
def prepare_context(self, context, params: dict):
    """
    Model-specific updates to incorporate the sample data and parameters.
    Return the updated context.
    """
    pass

run_simulation(params)

Run simulation and return final results context.

Source code in collimator/optimization/framework/base/optimizable.py
def run_simulation(self, params: dict):
    """
    Run simulation and return final results context.
    """
    context = self.base_context.with_time(self.start_time)
    context = self.prepare_context(context, params)
    results = self.simulator.advance_to(self.stop_time, context)
    return results.context

OptimizableWithStochasticVars

Bases: OptimizableBase

Base class for all optimizables with stochastic variables. It is designed only for use with Optax optimizers and without constraints; other optimizers are unlikely to work well with stochastic variables.

This class is similar to Optimizable, with the key difference that the context must be updated with both params and vars (stochastic variables), as opposed to params alone.

Parameters:

vars_0 (dict, default None): Initial stochastic variable values. If not provided, the stochastic_vars method is used to extract them from the base context.

distribution_config_vars (DistributionConfig, default None): Configuration for the stochastic variables. If not provided, a standard normal distribution is used.
Source code in collimator/optimization/framework/base/optimizable.py
class OptimizableWithStochasticVars(OptimizableBase):
    """
    Base class for all optimizables with stochastic variables. This is designed
    only for Optax optimizers and without constraints. Other optimizers are unlikely to
    work well with stochastic variables.

    This class is similar to `Optimizable` with the key difference that both `params`
    and `vars` (stochastic variables) need to be updated as opposed to `params` alone.

    Parameters:
        vars_0: dict
            Initial stochastic variable values. If not provided, the
            `stochastic_vars` method will be used to extract these from the
            base context.
        distribution_config_vars: DistributionConfig
            Configuration for stochastic variables. If not provided, standard normal
            distribution is used.
    """

    def __init__(
        self,
        diagram,
        base_context,
        sim_t_span=(0.0, 1.0),
        params_0=None,
        vars_0=None,
        distribution_config_vars=None,
        bounds=None,
        transformation=None,
        seed=None,
    ):
        super().__init__(
            diagram,
            base_context,
            sim_t_span,
            params_0,
            bounds,
            transformation,
            init_min_max=None,
            seed=seed,
        )

        if vars_0 is None:
            self.vars_0 = self.stochastic_vars(base_context)
        else:
            self.vars_0 = vars_0

        self.vars_0_flat, self.unflatten_vars = ravel_pytree(self.vars_0)
        self.num_stochastic_vars = self.vars_0_flat.size

        self.batched_objective = jax.jit(jax.vmap(self.objective, in_axes=(None, 0)))
        self.batched_objective_flat = jax.jit(
            jax.vmap(self.objective_flat, in_axes=(None, 0))
        )

        if distribution_config_vars is None:
            logger.warning(
                "`distribution_config_vars` is not specified. Using standard normal "
                "as the default distribution"
            )
            self.distribution_config_vars = DistributionConfig(
                names=list(self.vars_0.keys()),
                shapes=[jnp.shape(x) for x in self.vars_0.values()],
                distributions=["normal"] * len(self.vars_0),
                distributions_configs=[{}] * len(self.vars_0),
            )
        else:
            self.distribution_config_vars = distribution_config_vars

    @abstractmethod
    def prepare_context(self, context, params: dict, vars: dict):
        """
        Model-specific updates to incorporate the parameters and stochastic vars.
        Return the updated context.
        """
        pass

    @abstractmethod
    def stochastic_vars(self, context) -> dict:
        """
        Extract stochastic `vars` from the context.
        These should be in the form of a dict of Pytrees.
        """
        pass

    def run_simulation(self, params: dict, vars: dict):
        """Run simulation and return final results context."""
        context = self.base_context.with_time(self.start_time)
        context = self.prepare_context(context, params, vars)
        results = self.simulator.advance_to(self.stop_time, context)
        return results.context

    def objective_flat(self, params: Array, vars: Array):
        """Objective function for optimization with flattened parameters and vars
        input"""
        return self.objective(
            self.unflatten_params(jnp.atleast_1d(params)),
            self.unflatten_vars(jnp.atleast_1d(vars)),
        )

    def objective(self, params: dict, vars: dict):
        """Objective function for optimization with dict parameters and vars input"""
        if self.transformation is not None:
            params = self.transformation.inverse_transform(params)
        results_context = self.run_simulation(params, vars)
        return self.objective_from_context(results_context)

    def sample_random_vars(self, num_samples):
        """Generate random samples of the stochastic variables"""
        names = self.distribution_config_vars.names
        shapes = self.distribution_config_vars.shapes
        distributions = self.distribution_config_vars.distributions
        distributions_configs = self.distribution_config_vars.distributions_configs
        data, flat_data = self._generate_random_data(
            names,
            shapes,
            distributions,
            distributions_configs,
            num_samples,
        )
        return data, flat_data

    def generate_batches(
        self,
        data,
        num_batches,
        batch_size,
    ):
        """
        Given all samples `data`, generate `num_batches` random batches of size
        `batch_size` each
        """
        num_samples = data.shape[0]
        self.key, subkey = jr.split(self.key)
        batch_indices = jax.random.choice(
            subkey, num_samples, (num_batches, batch_size), replace=True
        )
        batches = data[batch_indices]
        return batches

    @staticmethod
    def _distribution(name: str, key, shape, options: dict):
        # remap options from json names to jax.random names
        # FIXME: code users should be able to pass default argnames (i.e. those used
        # by jax.random), for example, "minval" and "maxval" for uniform distribution
        if name == "normal":
            mean = options.get("mean", 0.0)
            std_dev = options.get("std_dev", 1.0)
            return jr.normal(key, shape) * std_dev + mean

        if name == "lognormal":
            mean = options.get("mean", 0.0)
            sigma = options.get("std_dev", 1.0)
            return jr.lognormal(key, sigma=sigma, shape=shape) + mean

        if name == "uniform":
            minval = options.get("min", 0.0)
            maxval = options.get("max", 1.0)
            return jr.uniform(key, shape=shape, minval=minval, maxval=maxval)

        warnings.warn(f"Unknown distribution: {name}.")
        sample_func = getattr(jr, name)
        return sample_func(key, shape, **options)

    def _generate_random_data(
        self,
        names,
        shapes,
        distributions,
        distributions_configs,
        num_samples,
    ):
        data = {}
        self.key, *subkeys = jr.split(self.key, len(names) + 1)
        for key, name, shape, distribution, distribution_config in zip(
            subkeys, names, shapes, distributions, distributions_configs
        ):
            data[name] = self._distribution(
                distribution,
                key,
                (num_samples, *shape),
                distribution_config,
            )

        def _flatten(x):
            x_flat, _ = ravel_pytree(x)
            return x_flat

        return data, _flatten(data)
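The batching differs from `Optimizable`. `jax.vmap(self.objective, in_axes=(None, 0))` keeps a single set of `params` fixed and maps over a batch of stochastic `vars`. Below is a NumPy sketch of that semantics. It assumes a hypothetical toy objective in which the stochastic variable is an uncertain gain multiplying the parameter.

```python
import numpy as np

def objective(params, vars):
    # Toy loss: squared error between params * vars and a target of 1.
    return float((params * vars - 1.0) ** 2)

def batched_objective(params, vars_batch):
    # in_axes=(None, 0): params is reused unchanged, vars is sliced along axis 0.
    return np.array([objective(params, v) for v in vars_batch])

vars_batch = np.array([0.5, 1.0, 2.0])
print(batched_objective(2.0, vars_batch))  # one loss per stochastic sample: 0, 1, 9
```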

generate_batches(data, num_batches, batch_size)

Given all samples data, generate num_batches random batches of size batch_size each

Source code in collimator/optimization/framework/base/optimizable.py
def generate_batches(
    self,
    data,
    num_batches,
    batch_size,
):
    """
    Given all samples `data`, generate `num_batches` random batches of size
    `batch_size` each
    """
    num_samples = data.shape[0]
    self.key, subkey = jr.split(self.key)
    batch_indices = jax.random.choice(
        subkey, num_samples, (num_batches, batch_size), replace=True
    )
    batches = data[batch_indices]
    return batches
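`generate_batches` draws an index array of shape `(num_batches, batch_size)` with replacement and gathers rows of `data`. As a result, a sample can appear more than once in a batch, and some samples may go unused in a given epoch. The sketch below reproduces this with NumPy's `Generator`, standing in for `jax.random.choice` with a split key (an assumption made so the example runs without JAX).

```python
import numpy as np

def generate_batches(rng, data, num_batches, batch_size):
    # Draw (num_batches, batch_size) row indices with replacement, then gather rows.
    idx = rng.choice(data.shape[0], size=(num_batches, batch_size), replace=True)
    return data[idx]

rng = np.random.default_rng(0)
data = np.arange(12.0).reshape(6, 2)  # 6 samples, each a flat vector of length 2
batches = generate_batches(rng, data, num_batches=3, batch_size=4)
print(batches.shape)  # (3, 4, 2)
```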

objective(params, vars)

Objective function for optimization with dict parameters and vars input

Source code in collimator/optimization/framework/base/optimizable.py
def objective(self, params: dict, vars: dict):
    """Objective function for optimization with dict parameters and vars input"""
    if self.transformation is not None:
        params = self.transformation.inverse_transform(params)
    results_context = self.run_simulation(params, vars)
    return self.objective_from_context(results_context)

objective_flat(params, vars)

Objective function for optimization with flattened parameters and vars input

Source code in collimator/optimization/framework/base/optimizable.py
def objective_flat(self, params: Array, vars: Array):
    """Objective function for optimization with flattened parameters and vars
    input"""
    return self.objective(
        self.unflatten_params(jnp.atleast_1d(params)),
        self.unflatten_vars(jnp.atleast_1d(vars)),
    )

prepare_context(context, params, vars) abstractmethod

Model-specific updates to incorporate the parameters and stochastic vars. Return the updated context.

Source code in collimator/optimization/framework/base/optimizable.py
@abstractmethod
def prepare_context(self, context, params: dict, vars: dict):
    """
    Model-specific updates to incorporate the parameters and stochastic vars.
    Return the updated context.
    """
    pass

run_simulation(params, vars)

Run simulation and return final results context.

Source code in collimator/optimization/framework/base/optimizable.py
def run_simulation(self, params: dict, vars: dict):
    """Run simulation and return final results context."""
    context = self.base_context.with_time(self.start_time)
    context = self.prepare_context(context, params, vars)
    results = self.simulator.advance_to(self.stop_time, context)
    return results.context

sample_random_vars(num_samples)

Generate random samples of the stochastic variables

Source code in collimator/optimization/framework/base/optimizable.py
def sample_random_vars(self, num_samples):
    """Generate random samples of the stochastic variables"""
    names = self.distribution_config_vars.names
    shapes = self.distribution_config_vars.shapes
    distributions = self.distribution_config_vars.distributions
    distributions_configs = self.distribution_config_vars.distributions_configs
    data, flat_data = self._generate_random_data(
        names,
        shapes,
        distributions,
        distributions_configs,
        num_samples,
    )
    return data, flat_data
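`sample_random_vars` reads names, shapes, distributions and per-variable options from `distribution_config_vars`, then draws `num_samples` samples for each variable. `_distribution` understands the option keys "mean", "std_dev", "min" and "max". The sketch below mirrors them with NumPy's `Generator`. It is not the library's sampler, which uses `jax.random` and, for unknown names, warns and falls back to the `jax.random` function of that name instead of raising. `sample_distribution` is a hypothetical name.

```python
import numpy as np

def sample_distribution(name, rng, shape, options):
    """Hypothetical NumPy mirror of the option names read by _distribution."""
    if name == "normal":
        return rng.normal(options.get("mean", 0.0), options.get("std_dev", 1.0), size=shape)
    if name == "lognormal":
        # As in the library: "std_dev" is the log-space sigma, and "mean" is an
        # additive shift applied after sampling (not the log-space mean).
        sigma = options.get("std_dev", 1.0)
        return rng.lognormal(0.0, sigma, size=shape) + options.get("mean", 0.0)
    if name == "uniform":
        return rng.uniform(options.get("min", 0.0), options.get("max", 1.0), size=shape)
    raise ValueError(f"Unknown distribution: {name}")

rng = np.random.default_rng(0)
num_samples, var_shape = 1000, (2,)
# Leading axis is the sample axis, like (num_samples, *shape) in _generate_random_data.
u = sample_distribution("uniform", rng, (num_samples, *var_shape), {"min": -1.0, "max": 1.0})
ln = sample_distribution("lognormal", rng, (num_samples,), {"mean": 2.0, "std_dev": 0.5})
print(u.shape, u.min() >= -1.0, ln.min() > 2.0)
```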

stochastic_vars(context) abstractmethod

Extract stochastic vars from the context. These should be in the form of a dict of Pytrees.

Source code in collimator/optimization/framework/base/optimizable.py
@abstractmethod
def stochastic_vars(self, context) -> dict:
    """
    Extract stochastic `vars` from the context.
    These should be in the form of a dict of Pytrees.
    """
    pass

Scipy

Bases: Optimizer

Scipy/JAX-scipy optimizers.

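Parameters:

optimizable (Optimizable, required): The optimizable object.

opt_method (str, required): The optimization method to use; passed to scipy.optimize.minimize as method.

tol (float, default None): Tolerance for termination. For detailed control, use opt_method_config.

opt_method_config (dict, default None): Method-specific configuration, passed to the minimizer as options.

use_autodiff_grad (bool, default True): Whether to use autodiff (jax.grad) to compute the objective gradient. Only applied when the chosen method accepts a gradient.

use_jax_scipy (bool, default False): Whether to use JAX's version of optimize.minimize. In that case BFGS is always used, and bounds and constraints are ignored.

metrics_writer (MetricsWriter, default None): If provided, the objective value and scalar parameters are written through a SciPy callback during optimization.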
Source code in collimator/optimization/framework/optimizers_scipy.py
class Scipy(Optimizer):
    """
    Scipy/JAX-scipy optimizers.

    Parameters:
        optimizable (Optimizable):
            The optimizable object.
        opt_method (str):
            The optimization method to use.
        tol (float):
            Tolerance for termination. For detailed control, use `opt_method_config`.
        opt_method_config (dict):
            Configuration for the optimization method.
        use_autodiff_grad (bool):
            Whether to use autodiff for gradient computation.
        use_jax_scipy (bool):
            Whether to use JAX's version of `optimize.minimize`.
    """

    def __init__(
        self,
        optimizable: Optimizable,
        opt_method,
        tol=None,
        opt_method_config=None,
        use_autodiff_grad=True,
        use_jax_scipy=False,
        metrics_writer: MetricsWriter = None,
    ):
        self.optimizable = optimizable
        self.opt_method = opt_method
        self.tol = tol
        self.opt_method_config = opt_method_config or {}
        self.use_autodiff_grad = use_autodiff_grad
        self.use_jax_scipy = use_jax_scipy
        self.optimal_params = None
        self.metrics_writer = metrics_writer

    def optimize(self):
        """Run optimization"""
        params = self.optimizable.params_0_flat
        objective = jax.jit(self.optimizable.objective_flat)

        if self.use_jax_scipy:
            warnings.warn(
                "`use_jax_scipy` is True. JAX's version of optimize.minimize will be "
                "used. Consequently, `opt_method` will be set to `BFGS` and autodiff "
                "will be used for gradient computation. Constraints and bounds will "
                "be ignored. If you want to use scipy's version of minimize, set "
                "`use_jax_scipy` to False."
            )
            opt_res = jaxopt.minimize(
                objective,
                params,
                method="BFGS",
                tol=self.tol,
                options=self.opt_method_config,
            )
            params = opt_res.x

        else:
            use_jac = False
            if self.opt_method in ACCEPTS_GRAD and self.use_autodiff_grad:
                jac = jax.jit(jax.grad(objective))
                use_jac = True

            # Handle bounds
            bounds = self.optimizable.bounds_flat

            # Jobs from UI would put (-jnp.inf, jnp.inf) as default bounds. The user
            # may also have specified bounds this way. Scipy expects `None` to imply
            # unboundedness.
            if bounds is not None:
                bounds = [
                    (
                        None if b[0] == -jnp.inf else b[0],
                        None if b[1] == jnp.inf else b[1],
                    )
                    for b in bounds
                ]

                # Check if all bounds are None, i.e. no bounds at all, and hence
                # algorithms that do not support bounds can be used.
                flattened_bounds = [element for tup in bounds for element in tup]
                all_none = all(element is None for element in flattened_bounds)
                bounds = None if all_none else bounds

            if bounds is not None and self.opt_method not in SUPPORTS_BOUNDS:
                raise ValueError(
                    f"Optimization method scipy:{self.opt_method} "
                    "does not support bounds."
                )

            # Handle constraints
            if (
                self.optimizable.has_constraints
                and self.opt_method not in SUPPORTS_CONSTRAINTS
            ):
                raise ValueError(
                    f"Optimization method scipy:{self.opt_method} "
                    "does not support constraints."
                )

            if self.optimizable.has_constraints:
                constraints = jax.jit(self.optimizable.constraints_flat)
                constraints_jac = jax.jit(jax.jacrev(constraints))
                constraints = sciopt.NonlinearConstraint(
                    constraints, 0.0, jnp.inf, jac=constraints_jac
                )
            else:
                constraints = None

            if self.metrics_writer is not None:
                cb = (
                    self._scipy_callback_new
                    if self.opt_method in MINIMIZE_METHODS_NEW_CB
                    else partial(self._scipy_callback_legacy, objective)
                )
            else:
                cb = None

            opt_res: "sciopt.OptimizeResult" = sciopt.minimize(
                objective,
                params,
                method=self.opt_method,
                jac=jac if use_jac else None,
                bounds=bounds,
                constraints=constraints,
                tol=self.tol,
                options=self.opt_method_config,
                callback=cb,
            )

            params = opt_res.x

            # Show the raw information from scipy. This can help with debugging.
            logger.info("Optimization result:\n%s", opt_res)

            if not opt_res.success:
                logger.warning("Optimization did not converge: %s", opt_res.message)

        self.optimal_params = self.optimizable.unflatten_params(params)
        if self.optimizable.transformation is not None:
            self.optimal_params = self.optimizable.transformation.inverse_transform(
                self.optimal_params
            )
        return self.optimal_params

    # NOTE: if this turns out to be too expensive, we can throttle writes in the
    # MetricsWriter and only compute metrics when we need them.
    def _write_metrics(self, fun, x):
        metrics = {}
        if fun is not None:
            metrics["fun"] = fun
        if x is not None:
            params: dict = self.optimizable.unflatten_params(x)
            for k, v in params.items():
                if np.asarray(v).shape == ():
                    metrics[k] = v
        if len(metrics) > 0:
            self.metrics_writer.write_metrics(**metrics)

    def _scipy_callback_new(self, intermediate_result: "sciopt.OptimizeResult"):
        self._write_metrics(
            intermediate_result.get("fun"), intermediate_result.get("x")
        )

    def _scipy_callback_legacy(self, objective, intermediate_results: np.ndarray):
        fun = objective(intermediate_results)
        self._write_metrics(fun, intermediate_results)
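The wrapper passes constraints to SciPy as `NonlinearConstraint(constraints, 0.0, jnp.inf, ...)`. `constraints_from_context` is therefore expected to return values that are non-negative when the constraints are satisfied. The sketch below shows the same convention in plain SciPy on a hypothetical toy problem: minimize (x - 2)^2 subject to 1 - x >= 0. SLSQP is used as an example of a method that accepts constraints.

```python
import numpy as np
from scipy.optimize import NonlinearConstraint, minimize

def objective(x):
    return (x[0] - 2.0) ** 2

def constraint(x):
    # Feasible when >= 0, matching the (0, inf) bounds used by the Scipy wrapper.
    return 1.0 - x[0]

con = NonlinearConstraint(constraint, 0.0, np.inf)
res = minimize(objective, x0=np.array([0.0]), method="SLSQP", constraints=con)
print(res.x)  # close to [1.0]: the constraint is active at the optimum
```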

optimize()

Run optimization

Source code in collimator/optimization/framework/optimizers_scipy.py
def optimize(self):
    """Run optimization"""
    params = self.optimizable.params_0_flat
    objective = jax.jit(self.optimizable.objective_flat)

    if self.use_jax_scipy:
        warnings.warn(
            "`use_jax_scipy` is True. JAX's version of optimize.minimize will be "
            "used. Consequently, `opt_method` will be set to `BFGS` and autodiff "
            "will be used for gradient computation. Constraints and bounds will "
            "be ignored. If you want to use scipy's version of minimize, set "
            "`use_jax_scipy` to False."
        )
        opt_res = jaxopt.minimize(
            objective,
            params,
            method="BFGS",
            tol=self.tol,
            options=self.opt_method_config,
        )
        params = opt_res.x

    else:
        use_jac = False
        if self.opt_method in ACCEPTS_GRAD and self.use_autodiff_grad:
            jac = jax.jit(jax.grad(objective))
            use_jac = True

        # Handle bounds
        bounds = self.optimizable.bounds_flat

        # Jobs from UI would put (-jnp.inf, jnp.inf) as default bounds. The user
        # may also have specified bounds this way. Scipy expects `None` to imply
        # unboundedness.
        if bounds is not None:
            bounds = [
                (
                    None if b[0] == -jnp.inf else b[0],
                    None if b[1] == jnp.inf else b[1],
                )
                for b in bounds
            ]

            # Check if all bounds are None, i.e. no bounds at all, and hence
            # algorithms that do not support bounds can be used.
            flattened_bounds = [element for tup in bounds for element in tup]
            all_none = all(element is None for element in flattened_bounds)
            bounds = None if all_none else bounds

        if bounds is not None and self.opt_method not in SUPPORTS_BOUNDS:
            raise ValueError(
                f"Optimization method scipy:{self.opt_method} "
                "does not support bounds."
            )

        # Handle constraints
        if (
            self.optimizable.has_constraints
            and self.opt_method not in SUPPORTS_CONSTRAINTS
        ):
            raise ValueError(
                f"Optimization method scipy:{self.opt_method} "
                "does not support constraints."
            )

        if self.optimizable.has_constraints:
            constraints = jax.jit(self.optimizable.constraints_flat)
            constraints_jac = jax.jit(jax.jacrev(constraints))
            constraints = sciopt.NonlinearConstraint(
                constraints, 0.0, jnp.inf, jac=constraints_jac
            )
        else:
            constraints = None

        if self.metrics_writer is not None:
            cb = (
                self._scipy_callback_new
                if self.opt_method in MINIMIZE_METHODS_NEW_CB
                else partial(self._scipy_callback_legacy, objective)
            )
        else:
            cb = None

        opt_res: "sciopt.OptimizeResult" = sciopt.minimize(
            objective,
            params,
            method=self.opt_method,
            jac=jac if use_jac else None,
            bounds=bounds,
            constraints=constraints,
            tol=self.tol,
            options=self.opt_method_config,
            callback=cb,
        )

        params = opt_res.x

        # Show the raw information from scipy. This can help with debugging.
        logger.info("Optimization result:\n%s", opt_res)

        if not opt_res.success:
            logger.warning("Optimization did not converge: %s", opt_res.message)

    self.optimal_params = self.optimizable.unflatten_params(params)
    if self.optimizable.transformation is not None:
        self.optimal_params = self.optimizable.transformation.inverse_transform(
            self.optimal_params
        )
    return self.optimal_params
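The bounds and constraint handling above can be reproduced with plain SciPy. The sketch below assumes bounds arrive as a list of (lower, upper) pairs that use ±inf for unbounded sides; normalize_bounds, objective, and constraint_fn are hypothetical names, and SLSQP is used only because it accepts both bounds and nonlinear constraints. As in the listing, a NonlinearConstraint with lower limit 0 and upper limit inf encodes g(x) >= 0.

```python
import numpy as np
from scipy import optimize as sciopt


def normalize_bounds(bounds):
    """Replace infinite limits with None; return None if nothing is bounded."""
    if bounds is None:
        return None
    normalized = [
        (None if lo == -np.inf else lo, None if hi == np.inf else hi)
        for lo, hi in bounds
    ]
    # If every limit is None, drop the bounds entirely so that methods
    # without bound support could still be used.
    if all(limit is None for pair in normalized for limit in pair):
        return None
    return normalized


def objective(x):
    # Hypothetical cost with its unconstrained minimum at (2, 1)
    return (x[0] - 2.0) ** 2 + (x[1] - 1.0) ** 2


def constraint_fn(x):
    # Hypothetical constraint g(x) = 2 - x0 - x1 >= 0, i.e. x0 + x1 <= 2
    return np.array([2.0 - x[0] - x[1]])


bounds = normalize_bounds([(0.0, np.inf), (-np.inf, np.inf)])
constraint = sciopt.NonlinearConstraint(constraint_fn, 0.0, np.inf)

res = sciopt.minimize(
    objective,
    np.array([0.5, 0.5]),
    method="SLSQP",
    bounds=bounds,
    constraints=constraint,
)
print(bounds)
print(res.x)
```

Here the constraint is active, so the solution should land near (1.5, 0.5), the projection of (2, 1) onto the line x0 + x1 = 2.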

Trainer

Base class for optimizing model parameters via simulation.

Should probably get a more descriptive name once we're doing other kinds of training...

Source code in collimator/optimization/training.py
class Trainer:
    """Base class for optimizing model parameters via simulation.

    Should probably get a more descriptive name once we're doing other kinds
    of training...
    """

    def __init__(
        self,
        simulator: Simulator,
        context,
        optimizer="adamw",
        lr=1e-3,
        print_every=10,
        clip_range=(-10.0, 10.0),
        **opt_kwargs,
    ):
        self.simulator = simulator
        self.context = context
        self.opt_state = None

        # See https://optax.readthedocs.io/en/latest/api.html for supported optimizers
        self.optimizer = getattr(optax, optimizer)(lr, **opt_kwargs)

        self.print_every = print_every
        self.clip_range = clip_range

    @abc.abstractmethod
    def optimizable_parameters(self, context):
        """Extract optimizable model-specific parameters from the context.

        These should be in the form of a PyTree (e.g. tuple, dict, array, etc)
        and should be the first arguments to `prepare_context`.
        """
        pass

    @abc.abstractmethod
    def prepare_context(self, context, *data, key=None):
        """Model-specific updates to incorporate the sample data and parameters.

        `data` should be the combination of the output of `optimizable_parameters`
        along with all the per-simulation "training data".  Parameters will
        update once per epoch, and training data will update once per sample.
        """
        pass

    @abc.abstractmethod
    def evaluate_cost(self, context):
        """Model-specific cost function, evaluated on final context"""
        pass

    def make_forward(self, start_time, stop_time):
        """Create a generic forward pass through the simulation, returning loss"""

        # Take all the data and model parameters, run a simulation, return loss.
        def _simulate(key, *data):
            context = self.context.with_time(start_time)
            context = self.prepare_context(context, *data, key=key)
            results = self.simulator.advance_to(stop_time, context)
            return self.evaluate_cost(results.context)

        return _simulate

    def make_loss_fn(self, forward, params):
        """Create a loss function based on a forward pass of the simulation

        `params` here can be any PyTree - it will get flattened to a single array
        """
        # Flatten all optimizable parameters into a single array
        p0, unflatten = ravel_pytree(params)

        # Define the loss as the mean cost function over the data set
        def _loss(p, key, *batch_data):
            # Map the forward pass over all the data points and return the loss
            loss = batch_scan(partial(forward, key, unflatten(p)), *batch_data)
            return loss

        # JIT compile the loss function and return the initial parameter
        # array and unflatten function
        return jax.jit(_loss), p0, unflatten

    def train(
        self,
        training_data,
        sim_start_time,
        sim_stop_time,
        epochs=100,
        key=None,
        params=None,
        opt_state=None,
    ):
        """Run the optimization loop over the training data"""

        if (
            self.simulator.max_major_steps is None
            or self.simulator.max_major_steps <= 0
        ):
            self.simulator.max_major_steps = estimate_max_major_steps(
                self.simulator.system,
                (sim_start_time, sim_stop_time),
                self.simulator.max_major_step_length,
            )

        if key is None:
            key = jax.random.PRNGKey(np.random.randint(0, 2**32, dtype=np.int64))

        # Create a function to evaluate the forward pass through the simulation
        forward = self.make_forward(sim_start_time, sim_stop_time)

        # Pull out the optimizable parameters from the context
        if params is None:
            params = self.optimizable_parameters(self.context)

        # Initialize the optimizer and create the loss function
        loss, p, unflatten = self.make_loss_fn(forward, params)

        if opt_state is None:
            opt_state = self.optimizer.init(p)

        self.opt_state = opt_state

        @jax.jit
        def opt_step(p, opt_state, key, batch_data):
            key, subkey = jax.random.split(key)
            if batch_data:
                loss_value, grads = jax.value_and_grad(loss)(p, subkey, *batch_data)
            else:
                loss_value, grads = jax.value_and_grad(loss)(p, subkey)

            grads = jnp.clip(grads, *self.clip_range)

            updates, opt_state = self.optimizer.update(grads, opt_state, p)
            p = optax.apply_updates(p, updates)
            return p, opt_state, key, loss_value

        def _scan_fun(carry, batch_data):
            p, opt_state, key, loss_value = opt_step(*carry, batch_data)
            return (p, opt_state, key), loss_value

        # Run the optimization loop
        for epoch in range(epochs):
            (p, self.opt_state, key), batch_loss = jax.lax.scan(
                _scan_fun, (p, self.opt_state, key), training_data
            )

            if epoch % self.print_every == 0:
                logger.info("Epoch %s, loss: %s", epoch, jnp.mean(batch_loss))

        # Return the optimized parameters
        return unflatten(p)

evaluate_cost(context) abstractmethod

Model-specific cost function, evaluated on the final context.

Source code in collimator/optimization/training.py
@abc.abstractmethod
def evaluate_cost(self, context):
    """Model-specific cost function, evaluated on final context"""
    pass

make_forward(start_time, stop_time)

Create a generic forward pass through the simulation that returns the loss.

Source code in collimator/optimization/training.py
def make_forward(self, start_time, stop_time):
    """Create a generic forward pass through the simulation, returning loss"""

    # Take all the data and model parameters, run a simulation, return loss.
    def _simulate(key, *data):
        context = self.context.with_time(start_time)
        context = self.prepare_context(context, *data, key=key)
        results = self.simulator.advance_to(stop_time, context)
        return self.evaluate_cost(results.context)

    return _simulate
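The closure returned by make_forward can be mimicked without collimator. In the hypothetical sketch below, a hand-written explicit-Euler loop for dx/dt = -a * x stands in for Simulator.advance_to, the parameters and sample are passed directly rather than through prepare_context, and the squared error on the final state plays the role of evaluate_cost.

```python
import jax
import jax.numpy as jnp


def make_forward(start_time, stop_time, n_steps=100):
    """Return a closure (key, params, x0, target) -> scalar loss.

    Hypothetical stand-in: explicit Euler integration of dx/dt = -a * x
    replaces the collimator Simulator.
    """
    dt = (stop_time - start_time) / n_steps

    def _simulate(key, params, x0, target):
        del key  # unused here; kept to mirror the (key, *data) signature
        x_init = jnp.asarray(x0, dtype=jnp.float32)

        def euler_step(_, x):
            x_next = x + dt * (-params["a"] * x)
            return x_next.astype(x.dtype)  # keep the loop-carry dtype fixed

        x_final = jax.lax.fori_loop(0, n_steps, euler_step, x_init)
        # Analogue of evaluate_cost: squared error on the final state only
        return (x_final - target) ** 2

    return _simulate


forward = make_forward(0.0, 1.0)
key = jax.random.PRNGKey(0)
params = {"a": jnp.array(1.0)}
loss = forward(key, params, 1.0, jnp.exp(-1.0))
grad = jax.grad(lambda p: forward(key, p, 1.0, jnp.exp(-1.0)))(params)
print(loss, grad["a"])
```

With n_steps=100 the Euler result (0.99**100, about 0.366) is close to exp(-1), about 0.368, so the loss is small but nonzero, and jax.grad differentiates straight through the integration loop.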

make_loss_fn(forward, params)

Create a loss function based on a forward pass of the simulation.

params here can be any PyTree; it is flattened into a single array.

Source code in collimator/optimization/training.py
def make_loss_fn(self, forward, params):
    """Create a loss function based on a forward pass of the simulation

    `params` here can be any PyTree - it will get flattened to a single array
    """
    # Flatten all optimizable parameters into a single array
    p0, unflatten = ravel_pytree(params)

    # Define the loss as the mean cost function over the data set
    def _loss(p, key, *batch_data):
        # Map the forward pass over all the data points and return the loss
        loss = batch_scan(partial(forward, key, unflatten(p)), *batch_data)
        return loss

    # JIT compile the loss function and return the initial parameter
    # array and unflatten function
    return jax.jit(_loss), p0, unflatten
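The flatten/unflatten round trip and the batch-averaged loss can be sketched with JAX alone. This is a minimal illustration under two assumptions: forward is a hypothetical per-sample linear model rather than a simulation, and jax.vmap followed by jnp.mean stands in for collimator's internal batch_scan helper, which the listing's comment describes as computing the mean cost over the data set.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def forward(key, params, x, y):
    """Hypothetical per-sample cost: squared error of a linear prediction."""
    del key  # unused; kept to mirror the (key, params, *data) call order
    pred = params["gain"] * x + jnp.sum(params["offsets"])
    return (pred - y) ** 2


def make_loss_fn(forward, params):
    # Flatten the PyTree into one 1-D array plus a function that undoes it
    p0, unflatten = ravel_pytree(params)

    def _loss(p, key, xs, ys):
        # Assumption: the loss is the mean per-sample cost; jax.vmap stands
        # in for collimator's internal batch_scan helper.
        costs = jax.vmap(lambda x, y: forward(key, unflatten(p), x, y))(xs, ys)
        return jnp.mean(costs)

    return jax.jit(_loss), p0, unflatten


params = {"gain": jnp.array(2.0), "offsets": jnp.array([0.5, -0.5])}
loss, p, unflatten = make_loss_fn(forward, params)

xs = jnp.array([0.0, 1.0, 2.0])
ys = jnp.array([0.0, 2.0, 4.0])
print(p, loss(p, jax.random.PRNGKey(0), xs, ys))
```

Because the optimizer only sees the flat vector p, element-wise operations such as the gradient clipping in train apply uniformly to every parameter regardless of the original PyTree structure.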

optimizable_parameters(context) abstractmethod

Extract optimizable model-specific parameters from the context.

These should be in the form of a PyTree (e.g., a tuple, dict, or array) and form the first argument to prepare_context.

Source code in collimator/optimization/training.py
@abc.abstractmethod
def optimizable_parameters(self, context):
    """Extract optimizable model-specific parameters from the context.

    These should be in the form of a PyTree (e.g. tuple, dict, array, etc)
    and should be the first arguments to `prepare_context`.
    """
    pass

prepare_context(context, *data, key=None) abstractmethod

Model-specific updates to incorporate the sample data and parameters.

data combines the output of optimizable_parameters with the per-simulation "training data". Parameters will update once per epoch, and training data will update once per sample.

Source code in collimator/optimization/training.py
@abc.abstractmethod
def prepare_context(self, context, *data, key=None):
    """Model-specific updates to incorporate the sample data and parameters.

    `data` should be the combination of the output of `optimizable_parameters`
    along with all the per-simulation "training data".  Parameters will
    update once per epoch, and training data will update once per sample.
    """
    pass
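In make_forward the data tuple is unpacked after the key, and make_loss_fn binds unflatten(p) as the first of those data arguments, so prepare_context receives the whole parameter PyTree as one positional argument, followed by the per-sample data. The sketch below illustrates that call convention; ToyContext and its fields are hypothetical stand-ins for a collimator context, invented only for this example.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class ToyContext:
    """Hypothetical immutable stand-in for a collimator context."""

    time: float
    parameters: dict
    initial_state: np.ndarray


def optimizable_parameters(context):
    # Return the values to optimize as a PyTree (here, a tuple)
    return (context.parameters["gain"], context.parameters["bias"])


def prepare_context(context, params, x0, key=None):
    # params is the PyTree from optimizable_parameters, passed as a single
    # argument; x0 is one training sample; key is unused in this sketch.
    gain, bias = params
    new_parameters = {**context.parameters, "gain": gain, "bias": bias}
    # Return a new context instead of mutating the old one
    return replace(context, parameters=new_parameters, initial_state=np.asarray(x0))


ctx = ToyContext(
    time=0.0,
    parameters={"gain": 1.0, "bias": 0.0, "mass": 2.0},
    initial_state=np.zeros(2),
)
print(optimizable_parameters(ctx))
ctx2 = prepare_context(ctx, (1.5, -0.2), np.array([1.0, -1.0]))
```

Parameters that are not returned by optimizable_parameters (mass here) pass through unchanged, and the original context is left untouched.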

train(training_data, sim_start_time, sim_stop_time, epochs=100, key=None, params=None, opt_state=None)

Run the optimization loop over the training data

Source code in collimator/optimization/training.py
def train(
    self,
    training_data,
    sim_start_time,
    sim_stop_time,
    epochs=100,
    key=None,
    params=None,
    opt_state=None,
):
    """Run the optimization loop over the training data"""

    if (
        self.simulator.max_major_steps is None
        or self.simulator.max_major_steps <= 0
    ):
        self.simulator.max_major_steps = estimate_max_major_steps(
            self.simulator.system,
            (sim_start_time, sim_stop_time),
            self.simulator.max_major_step_length,
        )

    if key is None:
        key = jax.random.PRNGKey(np.random.randint(0, 2**32, dtype=np.int64))

    # Create a function to evaluate the forward pass through the simulation
    forward = self.make_forward(sim_start_time, sim_stop_time)

    # Pull out the optimizable parameters from the context
    if params is None:
        params = self.optimizable_parameters(self.context)

    # Initialize the optimizer and create the loss function
    loss, p, unflatten = self.make_loss_fn(forward, params)

    if opt_state is None:
        opt_state = self.optimizer.init(p)

    self.opt_state = opt_state

    @jax.jit
    def opt_step(p, opt_state, key, batch_data):
        key, subkey = jax.random.split(key)
        if batch_data:
            loss_value, grads = jax.value_and_grad(loss)(p, subkey, *batch_data)
        else:
            loss_value, grads = jax.value_and_grad(loss)(p, subkey)

        grads = jnp.clip(grads, *self.clip_range)

        updates, opt_state = self.optimizer.update(grads, opt_state, p)
        p = optax.apply_updates(p, updates)
        return p, opt_state, key, loss_value

    def _scan_fun(carry, batch_data):
        p, opt_state, key, loss_value = opt_step(*carry, batch_data)
        return (p, opt_state, key), loss_value

    # Run the optimization loop
    for epoch in range(epochs):
        (p, self.opt_state, key), batch_loss = jax.lax.scan(
            _scan_fun, (p, self.opt_state, key), training_data
        )

        if epoch % self.print_every == 0:
            logger.info("Epoch %s, loss: %s", epoch, jnp.mean(batch_loss))

    # Return the optimized parameters
    return unflatten(p)
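In train as listed above, each epoch is a single jax.lax.scan over the leading axis of training_data, and opt_step runs once per element of that axis, so the parameters are updated once per batch. The sketch below reproduces that loop structure under stated assumptions: a toy linear model replaces the simulation, plain SGD replaces the optax optimizer, and LR, opt_step, and scan_fun are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

LR = 0.1                    # hypothetical learning rate
CLIP_RANGE = (-10.0, 10.0)  # same default as Trainer's clip_range

# Hypothetical model y = w * x + b; the PyTree is flattened once up front
p0, unflatten = ravel_pytree({"w": jnp.array(0.0), "b": jnp.array(0.0)})


def loss(p, key, xs, ys):
    """Mean squared error over one batch (key unused, kept for the signature)."""
    del key
    prm = unflatten(p)
    return jnp.mean((prm["w"] * xs + prm["b"] - ys) ** 2)


@jax.jit
def opt_step(p, key, batch):
    key, subkey = jax.random.split(key)
    loss_value, grads = jax.value_and_grad(loss)(p, subkey, *batch)
    grads = jnp.clip(grads, *CLIP_RANGE)  # element-wise clipping, as in train
    p = p - LR * grads  # plain SGD stands in for the optax update
    return p, key, loss_value


def scan_fun(carry, batch):
    p, key = carry
    p, key, loss_value = opt_step(p, key, batch)
    return (p, key), loss_value


# 4 batches of 8 samples, stacked along the leading axis that scan iterates
xs = jnp.linspace(-1.0, 1.0, 32).reshape(4, 8)
ys = 3.0 * xs + 0.5
training_data = (xs, ys)

p, key = p0, jax.random.PRNGKey(0)
for epoch in range(200):
    (p, key), batch_loss = jax.lax.scan(scan_fun, (p, key), training_data)
    if epoch % 50 == 0:
        print(epoch, float(jnp.mean(batch_loss)))

print(unflatten(p))
```

With noiseless data every batch shares the same minimizer (w = 3, b = 0.5), so plain SGD converges toward it. A real Trainer instead selects an optax optimizer by name through its optimizer argument (getattr(optax, optimizer)(lr, **opt_kwargs)).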

Transform

Bases: ABC

Base class for transformations.

Source code in collimator/optimization/framework/base/transformations.py
class Transform(ABC):
    """Base class for transformations."""

    @abstractmethod
    def transform(self, params: dict) -> dict:
        """
        Take original parameters dict {key:value} and output a dict with identical keys
        but transformed `values`.
        """
        pass

    @abstractmethod
    def inverse_transform(self, params: dict) -> dict:
        """
        Take transformed parameters dict {key:value} and output a dict with identical
        keys but inverse-transformed `values`.
        """
        pass
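A concrete subclass has to implement both methods so that inverse_transform undoes transform. The sketch below re-declares the abstract interface locally so it runs without collimator; LogTransform is a hypothetical example, not part of the listed API, that lets an optimizer work on log-values so the original parameters stay positive. It uses NumPy; inside a JAX-differentiated objective, jax.numpy would presumably be needed instead.

```python
from abc import ABC, abstractmethod

import numpy as np


class Transform(ABC):
    """Local copy of the abstract interface listed above."""

    @abstractmethod
    def transform(self, params: dict) -> dict:
        pass

    @abstractmethod
    def inverse_transform(self, params: dict) -> dict:
        pass


class LogTransform(Transform):
    """Hypothetical transform: optimize log(value) so values stay positive."""

    def transform(self, params: dict) -> dict:
        # Same keys, log-transformed values
        return {k: np.log(v) for k, v in params.items()}

    def inverse_transform(self, params: dict) -> dict:
        # Same keys, exponentiated values (undoes transform)
        return {k: np.exp(v) for k, v in params.items()}


t = LogTransform()
original = {"kp": 2.0, "ki": 0.5}
z = t.transform(original)
restored = t.inverse_transform(z)
print(z, restored)
```

Because Transform is abstract, it cannot be instantiated directly; only subclasses that implement both methods can.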

inverse_transform(params) abstractmethod

Take a dict of transformed parameters {key: value} and return a dict with identical keys but inverse-transformed values.

Source code in collimator/optimization/framework/base/transformations.py
@abstractmethod
def inverse_transform(self, params: dict) -> dict:
    """
    Take transformed parameters dict {key:value} and output a dict with identical
    keys but inverse-transformed `values`.
    """
    pass
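The optimizer code listed earlier passes optimal_params through transformation.inverse_transform, which implies the optimizer itself searches in the transformed space. The sketch below mimics that flow with SciPy. to_log, from_log, objective_original, and the target value 4.0 are hypothetical, and evaluating the cost on inverse-transformed values inside the objective is an assumption about how the transformed problem is set up.

```python
import numpy as np
from scipy import optimize as sciopt


def to_log(params):
    # transform: original space -> optimizer space
    return {k: np.log(v) for k, v in params.items()}


def from_log(params):
    # inverse_transform: optimizer space -> original space
    return {k: np.exp(v) for k, v in params.items()}


def objective_original(params):
    # Hypothetical cost, only meaningful for tau > 0
    return (params["tau"] - 4.0) ** 2


keys = ["tau"]
z0 = np.array([to_log({"tau": 1.0})[k] for k in keys])


def objective_z(z):
    # The optimizer moves freely in z; every z maps to a positive tau
    return objective_original(from_log(dict(zip(keys, z))))


res = sciopt.minimize(objective_z, z0, method="BFGS")
# Map the optimizer's answer back to the original space
optimal = from_log(dict(zip(keys, res.x)))
print(optimal)
```

The unconstrained BFGS search never proposes a non-positive tau, which is the usual motivation for optimizing in a log-transformed space.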

transform(params) abstractmethod

Take a dict of original parameters {key: value} and return a dict with identical keys but transformed values.

Source code in collimator/optimization/framework/base/transformations.py
@abstractmethod
def transform(self, params: dict) -> dict:
    """
    Take original parameters dict {key:value} and output a dict with identical keys
    but transformed `values`.
    """
    pass