Skip to content

Optimization

collimator.optimization

Trainer

Base class for optimizing model parameters via simulation.

Should probably get a more descriptive name once we're doing other kinds of training...

Source code in collimator/optimization/training.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
class Trainer:
    """Base class for optimizing model parameters via simulation.

    Subclasses provide three model-specific hooks (`optimizable_parameters`,
    `prepare_context` and `evaluate_cost`); `train` then runs a gradient-based
    optimization loop in which every loss evaluation is a simulation.

    Note that the hooks are decorated with `abc.abstractmethod`, but this class
    does not use `abc.ABCMeta`, so the decorators are not enforced when the
    class is instantiated.

    Should probably get a more descriptive name once we're doing other kinds
    of training...
    """

    def __init__(
        self,
        simulator: Simulator,
        context,
        optimizer="adamw",
        lr=1e-3,
        print_every=10,
        clip_range=(-10.0, 10.0),
    ):
        """Store the simulation objects and build the optimizer.

        Args:
            simulator: Simulator used for the forward pass (`advance_to`).
            context: Base context; each forward pass starts from
                `context.with_time(start_time)`.
            optimizer: Either the name of an optax optimizer factory (looked up
                with `getattr(optax, optimizer)` and called with `lr`), or an
                already-constructed optimizer object exposing `init`/`update`,
                in which case `lr` is ignored.
            lr: Learning rate passed to the optax factory.
            print_every: Print the mean loss every `print_every` epochs. A falsy
                value (`0` or `None`) disables printing.
            clip_range: `(min, max)` bounds applied elementwise to the gradient,
                or `None` to disable clipping.
        """
        self.simulator = simulator
        self.context = context

        # See https://optax.readthedocs.io/en/latest/api.html for supported optimizers
        if isinstance(optimizer, str):
            self.optimizer = getattr(optax, optimizer)(lr)
        else:
            # Allow a pre-built optimizer (e.g. a chain or a scheduled one).
            self.optimizer = optimizer

        self.print_every = print_every
        self.clip_range = clip_range

    @abc.abstractmethod
    def optimizable_parameters(self, context):
        """Extract optimizable model-specific parameters from the context.

        These should be in the form of a PyTree (e.g. tuple, dict, array, etc)
        and should be the first arguments to `prepare_context`.
        """
        pass

    @abc.abstractmethod
    def prepare_context(self, context, *data):
        """Model-specific updates to incorporate the sample data and parameters.

        `data` is the output of `optimizable_parameters` (as its first element)
        followed by the per-simulation "training data".  In `train` the
        parameters change once per optimizer step, i.e. once per leading-axis
        slice of the training data.

        Should return the context to simulate from.
        """
        pass

    @abc.abstractmethod
    def evaluate_cost(self, context):
        """Model-specific cost function, evaluated on final context"""
        pass

    def make_forward(self, start_time, stop_time):
        """Create a generic forward pass through the simulation, returning loss

        The returned function takes the parameters followed by the per-sample
        data, forwards all of them to `prepare_context`, advances the simulation
        to `stop_time` and returns `evaluate_cost` of the resulting context.
        """

        # Take all the data and model parameters, run a simulation, return loss.
        def _simulate(*data):
            context = self.context.with_time(start_time)
            context = self.prepare_context(context, *data)
            results = self.simulator.advance_to(stop_time, context)
            return self.evaluate_cost(results.context)

        return _simulate

    def make_loss_fn(self, forward, params):
        """Create a loss function based on a forward pass of the simulation

        `params` here can be any PyTree - it will get flattened to a single array

        Returns:
            A tuple `(loss, p0, unflatten)`: the JIT-compiled loss taking the
            flat parameter array followed by the batch data, the initial flat
            parameter array, and the function restoring the original PyTree.
        """
        # Flatten all optimizable parameters into a single array
        p0, unflatten = ravel_pytree(params)

        def _loss(p, *batch_data):
            # Bind the (unflattened) parameters as the first argument of the
            # forward pass; `batch_scan` (defined elsewhere) is responsible for
            # applying it across the batch and reducing to the loss value.
            loss = batch_scan(partial(forward, unflatten(p)), *batch_data)
            return loss

        # JIT compile the loss function and return the initial parameter
        # array and unflatten function
        return jax.jit(_loss), p0, unflatten

    def train(self, training_data, sim_start_time, sim_stop_time, epochs=100):
        """Run the optimization loop over the training data

        Args:
            training_data: PyTree scanned along its leading axis by
                `jax.lax.scan`; each slice is unpacked as the positional batch
                arguments of the loss, so it is expected to be a tuple.
            sim_start_time: Start time of every simulation.
            sim_stop_time: Stop time of every simulation.
            epochs: Number of passes over `training_data`.

        Returns:
            The optimized parameters, in the same PyTree structure returned by
            `optimizable_parameters`.
        """

        # The simulator needs a bound on the number of major steps; estimate
        # one if the user has not provided a positive value.  Note that this
        # updates the simulator in place.
        if (
            self.simulator.max_major_steps is None
            or self.simulator.max_major_steps <= 0
        ):
            self.simulator.max_major_steps = estimate_max_major_steps(
                self.simulator.system,
                (sim_start_time, sim_stop_time),
                self.simulator.max_major_step_length,
            )

        # Create a function to evaluate the forward pass through the simulation
        forward = self.make_forward(sim_start_time, sim_stop_time)

        # Pull out the optimizable parameters from the context
        params = self.optimizable_parameters(self.context)

        # Initialize the optimizer and create the loss function
        loss, p, unflatten = self.make_loss_fn(forward, params)
        opt_state = self.optimizer.init(p)
        value_and_grad = jax.value_and_grad(loss)

        @jax.jit
        def opt_step(p, opt_state, batch_data):
            if batch_data:
                loss_value, grads = value_and_grad(p, *batch_data)
            else:
                loss_value, grads = value_and_grad(p)

            # Elementwise gradient clipping; `None` disables it.
            if self.clip_range is not None:
                grads = jnp.clip(grads, *self.clip_range)

            updates, opt_state = self.optimizer.update(grads, opt_state, p)
            p = optax.apply_updates(p, updates)
            return p, opt_state, loss_value

        def _scan_fun(carry, batch_data):
            p, opt_state, loss_value = opt_step(*carry, batch_data)
            return (p, opt_state), loss_value

        # Run the optimization loop
        for epoch in range(epochs):
            (p, opt_state), batch_loss = jax.lax.scan(
                _scan_fun, (p, opt_state), training_data
            )

            # A falsy `print_every` (0 or None) turns logging off instead of
            # failing on the modulo.
            if self.print_every and epoch % self.print_every == 0:
                print(f"Epoch {epoch}, loss={jnp.mean(batch_loss)}")

        # Return the optimized parameters
        return unflatten(p)

evaluate_cost(context) abstractmethod

Model-specific cost function, evaluated on final context

Source code in collimator/optimization/training.py
79
80
81
82
@abc.abstractmethod
def evaluate_cost(self, context):
    """Return the model-specific cost for a simulation's final context.

    Subclasses override this hook; the base implementation does nothing and
    returns `None`.
    """

make_forward(start_time, stop_time)

Create a generic forward pass through the simulation, returning loss

Source code in collimator/optimization/training.py
84
85
86
87
88
89
90
91
92
93
94
def make_forward(self, start_time, stop_time):
    """Build the forward pass: one simulation run reduced to its cost.

    The returned closure accepts the model parameters and sample data as
    positional arguments and forwards all of them to `prepare_context`.
    """

    def _forward(*inputs):
        # Start from the stored context at `start_time` and let the model
        # fold the parameters and sample data into it.
        initial = self.prepare_context(self.context.with_time(start_time), *inputs)
        # Advance to `stop_time`, then score the context the simulator ends in.
        final = self.simulator.advance_to(stop_time, initial).context
        return self.evaluate_cost(final)

    return _forward

make_loss_fn(forward, params)

Create a loss function based on a forward pass of the simulation

`params` here can be any PyTree - it will get flattened to a single array

Source code in collimator/optimization/training.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def make_loss_fn(self, forward, params):
    """Wrap a forward pass into a JIT-compiled loss over flat parameters.

    `params` may be any PyTree; it is raveled into a single array.

    Returns the compiled loss (taking the flat array followed by the batch
    data), the initial flat array, and the function that restores the PyTree.
    """
    flat_params, restore = ravel_pytree(params)

    def _loss(flat, *batch_data):
        # Fix the restored parameters as the leading argument of the forward
        # pass, then hand the resulting function to `batch_scan` along with
        # the batch data.
        per_sample = partial(forward, restore(flat))
        return batch_scan(per_sample, *batch_data)

    return jax.jit(_loss), flat_params, restore

optimizable_parameters(context) abstractmethod

Extract optimizable model-specific parameters from the context.

These should be in the form of a PyTree (e.g. tuple, dict, array, etc) and should be the first arguments to `prepare_context`.

Source code in collimator/optimization/training.py
60
61
62
63
64
65
66
67
@abc.abstractmethod
def optimizable_parameters(self, context):
    """Return the model-specific parameters to optimize, read from `context`.

    The result should be a PyTree (tuple, dict, array, ...) and is what
    `prepare_context` receives as its first data arguments.  Subclasses
    override this hook; the base implementation returns `None`.
    """

prepare_context(context, *data) abstractmethod

Model-specific updates to incorporate the sample data and parameters.

`data` should be the combination of the output of `optimizable_parameters` along with all the per-simulation "training data". Parameters will update once per epoch, and training data will update once per sample.

Source code in collimator/optimization/training.py
69
70
71
72
73
74
75
76
77
@abc.abstractmethod
def prepare_context(self, context, *data):
    """Apply the parameters and sample data to `context` (model-specific).

    `data` combines the output of `optimizable_parameters` with all of the
    per-simulation "training data".  Parameters will update once per epoch,
    and training data will update once per sample.  Subclasses override this
    hook; the base implementation returns `None`.
    """

train(training_data, sim_start_time, sim_stop_time, epochs=100)

Run the optimization loop over the training data

Source code in collimator/optimization/training.py
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def train(self, training_data, sim_start_time, sim_stop_time, epochs=100):
    """Run the optimization loop over the training data

    Args:
        training_data: PyTree scanned along its leading axis by
            `jax.lax.scan`; each slice is unpacked as the positional batch
            arguments of the loss, so it is expected to be a tuple.
        sim_start_time: Start time of every simulation.
        sim_stop_time: Stop time of every simulation.
        epochs: Number of passes over `training_data`.

    Returns:
        The optimized parameters, in the same PyTree structure returned by
        `optimizable_parameters`.

    A falsy `self.print_every` (`0` or `None`) disables progress printing and
    `self.clip_range = None` disables gradient clipping.
    """

    # The simulator needs a bound on the number of major steps; estimate one
    # if the user has not provided a positive value.  Note that this updates
    # the simulator in place.
    if (
        self.simulator.max_major_steps is None
        or self.simulator.max_major_steps <= 0
    ):
        self.simulator.max_major_steps = estimate_max_major_steps(
            self.simulator.system,
            (sim_start_time, sim_stop_time),
            self.simulator.max_major_step_length,
        )

    # Create a function to evaluate the forward pass through the simulation
    forward = self.make_forward(sim_start_time, sim_stop_time)

    # Pull out the optimizable parameters from the context
    params = self.optimizable_parameters(self.context)

    # Initialize the optimizer and create the loss function
    loss, p, unflatten = self.make_loss_fn(forward, params)
    opt_state = self.optimizer.init(p)
    value_and_grad = jax.value_and_grad(loss)

    @jax.jit
    def opt_step(p, opt_state, batch_data):
        if batch_data:
            loss_value, grads = value_and_grad(p, *batch_data)
        else:
            loss_value, grads = value_and_grad(p)

        # Elementwise gradient clipping; `None` disables it.
        if self.clip_range is not None:
            grads = jnp.clip(grads, *self.clip_range)

        updates, opt_state = self.optimizer.update(grads, opt_state, p)
        p = optax.apply_updates(p, updates)
        return p, opt_state, loss_value

    def _scan_fun(carry, batch_data):
        p, opt_state, loss_value = opt_step(*carry, batch_data)
        return (p, opt_state), loss_value

    # Run the optimization loop
    for epoch in range(epochs):
        (p, opt_state), batch_loss = jax.lax.scan(
            _scan_fun, (p, opt_state), training_data
        )

        # A falsy `print_every` (0 or None) turns logging off instead of
        # failing on the modulo.
        if self.print_every and epoch % self.print_every == 0:
            print(f"Epoch {epoch}, loss={jnp.mean(batch_loss)}")

    # Return the optimized parameters
    return unflatten(p)