Using external optimizers

The PyTorch package already ships with a number of optimizers. These, however, are typically suited for problems with large parameter spaces (several hundred parameters or more). For smaller problems other optimizers, e.g. the Levenberg-Marquardt algorithm, are more efficient. The SciPy package offers various such optimizers via its scipy.optimize.least_squares function. This notebook shows how to use such external optimizers together with the DiPAS framework.

In PyTorch the job of the optimizer is to update the parameters of the simulation, given their gradients with respect to some cost function. That is, it performs the following steps:

  1. Compute the new parameter values (based on the current values and gradients),
  2. Apply these new values, i.e. make the changes effective.

We use the external optimizer to perform step (1) and then apply the changes manually, which is straightforward, as the sketch below illustrates.
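As a minimal sketch of these two steps, using plain PyTorch (not DiPAS-specific; the parameter and values are made up):

import torch

p = torch.nn.Parameter(torch.tensor(1.0))

# Built-in optimizer: steps (1) and (2) both happen inside `opt.step()`.
opt = torch.optim.SGD([p], lr=0.1)
(p ** 2).backward()
opt.step()

# External optimizer: it hands us the new value and we apply it ourselves.
new_value = 0.5  # e.g. computed by a SciPy routine
with torch.no_grad():
    p.copy_(torch.tensor(new_value))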

Setup

In this example we will estimate quadrupole gradient errors by matching the orbit response matrix (ORM). In the following we load the lattice, set some random gradient errors and compute the reference ORM.

In [1]:
from dipas.build import from_file
import dipas.compute as compute
from dipas.elements import elements, Marker, Quadrupole, VKicker, VMonitor, tensor
import numpy as np
import torch


elements['multipole'] = lambda **kwargs: Marker()  # all multipoles are turned off

lattice = from_file('example.madx')
quadrupoles = lattice[Quadrupole]

k1 = np.array([q.k1.numpy() for q in quadrupoles])  # nominal quadrupole gradients
errors = np.random.normal(scale=0.01*abs(k1), size=len(quadrupoles))  # 1% r.m.s. relative errors

for q, dk1 in zip(quadrupoles, errors):
    q.dk1 = tensor(dk1)
    q.update_transfer_map()
    
orm_ref = compute.orm(lattice, kickers=VKicker, monitors=VMonitor, order=1)[1]  # vertical response only

Let's recall the signature of least_squares. It expects a function func which receives N parameters and returns M residuals, both as Numpy arrays. The "lm" (Levenberg-Marquardt) solver, which we will use below, requires that N <= M. In our case the N inputs are the current estimates of the quadrupole errors and the M outputs are the residuals of the (flattened) vertical ORM. The lattice contains 36 quadrupoles, 12 vertical correctors and 12 BPMs, so we have N = 36 and M = 12**2 = 144.
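As a toy illustration of that interface (the function below is made up, with N = 2 parameters and M = 3 residuals):

import numpy as np
from scipy.optimize import least_squares

def func(x):  # x has shape (N,) = (2,)
    return np.array([x[0] - 1.0, x[1] - 2.0, x[0]*x[1] - 2.0])  # shape (M,) = (3,)

result = least_squares(func, x0=np.zeros(2), method='lm')
print(result.x)  # approximately [1., 2.]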

Within that function func we need to make the new parameter estimate effective, similar to above, then compute the ORM and return its difference to the reference ORM as a flattened Numpy array of residuals.

In [2]:
print(f'# Quadrupoles: {len(quadrupoles)}')
print(f'# Correctors: {len(lattice[VKicker])}')
print(f'# BPMs: {len(lattice[VMonitor])}')


def orm_func(estimate):
    for q, dk1 in zip(quadrupoles, estimate):
        q.dk1 = tensor(dk1)
        q.update_transfer_map()
    orm = compute.orm(lattice, kickers=VKicker, monitors=VMonitor, order=1)[1]
    residuals = (orm - orm_ref).numpy().ravel()
    return residuals
# Quadrupoles: 36
# Correctors: 12
# BPMs: 12

The above function is sufficient; however, the algorithm would still need to estimate the Jacobian itself via finite-difference approximation. The number of forward passes, i.e. ORM computations, is on the order of the number of parameters: 37 in our case (1 baseline + 1 offset for each parameter). This means the algorithm would need to compute 37 different ORMs before it can estimate the Jacobian and provide a new parameter estimate. We can be more efficient by reusing the computation graph that is built by the simulation and computing all the gradients with a single forward and backward pass. For this we can use torch.autograd.functional.jacobian. In order to do everything at once and also cache the result of the computation we wrap everything in a dedicated class.
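To see what jacobian does, here is a minimal sketch on a toy function (unrelated to DiPAS): it takes a function and an input tensor and returns the matrix of partial derivatives of the outputs with respect to the inputs.

import torch
from torch.autograd.functional import jacobian

def f(x):
    return torch.stack([x[0] ** 2 + x[1], x[0] * x[1]])

print(jacobian(f, torch.tensor([1.0, 2.0])))
# tensor([[2., 1.],
#         [2., 1.]])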

In [3]:
from collections import deque

from torch.autograd.functional import jacobian


class MatchORM:
    def __init__(self, lattice, orm_ref):
        self.lattice = lattice
        self.orm_ref = orm_ref
        self.quadrupoles = lattice[Quadrupole]
        # Single-slot caches; least_squares evaluates `fun` before it requests `jac`
        # for the same estimate, so one cached entry is sufficient.
        self.jac_cache = deque(maxlen=1)
        self.orm_cache = deque(maxlen=1)
        self.step = 0
        
    def __call__(self, estimate):
        """This function gets invoked by least_squares, i.e. Numpy inputs/outputs."""
        self.step += 1
        # `jacobian` invokes `self._compute`, which also fills `self.orm_cache`, so a
        # single differentiated forward pass yields both the residuals and the Jacobian.
        jac = jacobian(self._compute, torch.from_numpy(estimate))
        self.jac_cache.append(jac.numpy())
        residuals = self.orm_cache.pop().detach().numpy()
        print(f'[{self.step:02d}] mse = {np.mean(residuals**2)}')
        return residuals
    
    def _compute(self, estimate):
        """This function gets invoked by jacobian, i.e. PyTorch inputs/outputs."""
        for q, dk1 in zip(self.quadrupoles, estimate):
            q.dk1 = dk1
            q.update_transfer_map()
        orm = compute.orm(self.lattice, kickers=VKicker, monitors=VMonitor, order=1)[1]
        residuals = torch.flatten(orm - self.orm_ref)
        self.orm_cache.append(residuals)
        return residuals
    
    def jac(self, estimate):
        """This function gets invoked by least_squares, i.e. Numpy output is required."""
        return self.jac_cache.pop()

Now we can invoke the optimizer by passing it an instance of this class:

In [4]:
from scipy.optimize import least_squares


match = MatchORM(lattice, orm_ref)

result = least_squares(match, jac=match.jac, x0=np.zeros_like(errors), method='lm')
print(result)
[01] mse = 2.3596129695680395
[02] mse = 2.3596129695680395
[03] mse = 2.3596129695680395
[04] mse = 0.04447263051378097
[05] mse = 0.019456070405029766
[06] mse = 6.416326049531289e-06
[07] mse = 3.7826057702729345e-09
[08] mse = 1.895252659297642e-19
[09] mse = 1.427573236173852e-27
[10] mse = 8.101062664219795e-28
 active_mask: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
        cost: 5.832765118238252e-26
         fun: array([ 0.00000000e+00,  5.32907052e-15, -5.32907052e-15,  8.88178420e-15,
        1.24344979e-14, -1.50990331e-14, -2.13162821e-14, -2.79221091e-14,
        5.32907052e-15,  0.00000000e+00, -4.08562073e-14,  1.77635684e-15,
        2.39808173e-14, -7.10542736e-15, -2.48689958e-14, -3.55271368e-15,
        2.13162821e-14, -8.88178420e-15, -3.64153152e-14, -1.24344979e-14,
       -1.99840144e-14, -5.15143483e-14, -2.22044605e-15,  7.10542736e-15,
        3.55271368e-15, -2.22044605e-14,  8.88178420e-15, -1.42108547e-14,
       -1.59872116e-14,  1.24344979e-14,  2.30926389e-14,  2.35367281e-14,
        7.10542736e-15,  1.73194792e-14,  2.48689958e-14,  4.44089210e-15,
       -2.57571742e-14,  1.42108547e-14,  3.55271368e-15, -1.95399252e-14,
       -2.13162821e-14,  7.10542736e-15,  3.10862447e-14, -1.77635684e-15,
        3.61932706e-14,  6.03961325e-14, -1.33226763e-15, -1.06581410e-14,
       -3.55271368e-15,  1.06581410e-14,  1.77635684e-15,  4.79616347e-14,
        1.59872116e-14, -2.30926389e-14, -2.66453526e-14,  1.77635684e-15,
       -2.66453526e-14, -3.13082893e-14, -2.30926389e-14,  1.73194792e-14,
        1.73194792e-14, -2.30926389e-14, -6.21724894e-15,  1.42108547e-14,
       -1.24344979e-14, -1.15463195e-14, -5.15143483e-14, -7.10542736e-15,
       -1.33226763e-14, -5.32907052e-14, -1.19904087e-14, -3.55271368e-15,
       -1.06581410e-14, -1.92068583e-14, -1.77635684e-15, -4.66293670e-14,
       -3.19744231e-14, -8.88178420e-15,  2.84217094e-14,  3.55271368e-15,
        3.37507799e-14,  3.73034936e-14,  4.08562073e-14,  1.68753900e-14,
       -2.17603713e-14,  5.86197757e-14,  1.32671651e-14, -2.84217094e-14,
        4.84057239e-14,  3.73034936e-14,  4.88498131e-14,  7.99360578e-15,
       -3.73034936e-14,  4.61852778e-14,  1.77635684e-14, -1.77635684e-14,
        8.88178420e-15,  2.57571742e-14,  0.00000000e+00,  3.44169138e-14,
        3.55271368e-14,  2.08721929e-14, -4.79616347e-14,  0.00000000e+00,
       -4.44089210e-14, -3.10862447e-14, -3.55271368e-14, -4.44089210e-15,
        1.24344979e-14, -6.03961325e-14, -5.32907052e-15,  1.42108547e-14,
       -6.90558721e-14, -6.39488462e-14,  2.66453526e-15, -1.24344979e-14,
        4.08562073e-14, -3.55271368e-14, -2.48689958e-14,  2.30926389e-14,
       -1.42108547e-14, -9.76996262e-15,  7.10542736e-15, -2.93098879e-14,
       -1.95399252e-14, -8.99280650e-15,  5.50670620e-14, -3.55271368e-15,
        1.42108547e-14,  2.48689958e-14,  2.48689958e-14,  8.88178420e-15,
       -1.24344979e-14,  6.75015599e-14, -5.32907052e-15, -1.42108547e-14,
        8.52651283e-14,  6.57252031e-14, -9.99200722e-15,  1.77635684e-15,
       -4.75175455e-14,  4.61852778e-14,  4.26325641e-14, -3.73034936e-14])
        grad: array([ 3.87809560e-11,  9.41901644e-11,  4.08196496e-11, -2.94468927e-11,
       -7.38044359e-11, -3.31068032e-11,  1.07038273e-11,  3.19154575e-11,
        1.49221711e-11, -2.63053813e-11, -6.83893696e-11, -3.06306051e-11,
        3.21767308e-11,  8.42847027e-11,  3.79789357e-11, -5.96603917e-11,
       -1.47527321e-10, -6.44967040e-11,  3.37752664e-11,  8.59025720e-11,
        3.78130009e-11, -5.35994856e-11, -1.30448338e-10, -5.51987925e-11,
        2.30886293e-11,  5.42990156e-11,  2.14683893e-11, -4.96429512e-11,
       -1.13401927e-10, -4.85920750e-11,  3.48510858e-11,  8.64770499e-11,
        3.79478486e-11, -4.45879025e-11, -1.11363277e-10, -4.75798380e-11])
         jac: array([[  38.68344469,   76.53415205,   33.34548917, ...,   30.03831099,
          96.83563101,   47.25805687],
       [  39.26902437,   94.05279423,   40.22871482, ...,  -34.32228054,
         -89.98246921,  -39.05381579],
       [ -48.90590701,  -96.19572536,  -36.84507096, ...,  -21.05573832,
         -73.2725364 ,  -37.02738227],
       ...,
       [ -37.01005316,  -66.87722996,  -24.06253685, ...,  -47.28657934,
        -117.95490619,  -50.10215072],
       [ -62.12543355, -144.03738216,  -60.89289824, ...,   33.55563152,
          62.51086505,   22.10732741],
       [  55.86103047,  110.64148798,   42.57719828, ...,   45.47746351,
          98.50820674,   43.19884077]])
     message: '`xtol` termination condition is satisfied.'
        nfev: 8
        njev: 7
  optimality: 1.4752732096279823e-10
      status: 3
     success: True
           x: array([ 0.00792058,  0.00719378,  0.00116074,  0.00029487,  0.00290322,
       -0.0005206 ,  0.00414219,  0.00465046,  0.00059435,  0.00303476,
        0.00179031, -0.00035887,  0.00297726,  0.00078953, -0.00053425,
       -0.00600123,  0.00011263,  0.00023978, -0.00287374, -0.00282262,
       -0.00067364,  0.00067982, -0.00018771,  0.00066961,  0.00207471,
       -0.00108232, -0.00017001,  0.00115687,  0.00464513, -0.00012917,
        0.00098835,  0.00054944, -0.00027458,  0.00034512, -0.00637738,
        0.00036608])

Finally, let's compare the estimated errors with the actual ones:

In [5]:
%matplotlib inline
import matplotlib.pyplot as plt


fig, ax = plt.subplots(figsize=(18, 6))
ax.set(xlabel='Quadrupole #', ylabel='Quadrupole gradient error [1/m^2]')
width = 0.25
x = np.arange(len(errors))
ax.bar(x - width/2, errors, width, label='Reference')
ax.bar(x + width/2, result.x, width, label='Estimate')
ax.set_xticks(x)
ax.legend()
Out[5]:
<matplotlib.legend.Legend at 0x7f5b83ae7210>