Saturday, December 19, 2020

First steps with Pytorch Lightning

Some time back, Quora routed a "Keras vs. Pytorch" question to me, which I decided to ignore because it seemed too much like flamebait. A couple of weeks back, after discussions with colleagues and (professional) acquaintances who had tried out libraries like Catalyst, Ignite, and Lightning, I decided to get on the Pytorch boilerplate elimination train as well, and tried out Pytorch Lightning. As I did so, my thoughts inevitably went back to the Quora question, and I came to the conclusion that, in their current form, the two libraries and their respective ecosystems are more similar than they are different, and that there is no technological reason to choose one over the other. Allow me to explain.

Neural networks learn using Gradient Descent. The central idea behind Gradient Descent can be neatly encapsulated in the weight update equation below (from the linked Gradient Descent article), and the code that applies this update over and over, batch after batch and epoch after epoch, is what is referred to as the "training loop". Of course, there are other aspects of neural networks, such as model and data definition, but it is the training loop where the differences between the earlier versions of the two libraries, and their subsequent convergence, are most apparent. So I will mostly talk about the training loop here.
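
In its simplest form, the update is the standard rule, where w is the vector of model weights, L is the loss, and η is the learning rate:

    w ← w - η * ∇L(w)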

Keras was initially conceived as a high level API over the low level graph based APIs from Theano and Tensorflow. Graph APIs require the user to first define the computation graph and then execute it. Once the graph is defined, the library will attempt to build the most efficient representation for it before execution. This makes execution more efficient, but adds a lot of boilerplate to the code, and makes it harder to debug when something goes wrong. The biggest success of Keras, in my opinion, is its ability to hide the graph machinery almost completely behind an elegant API. In particular, its "training loop" looks like this:

model.compile(optimizer=optimizer, loss=loss_fn, metrics=[train_acc])
model.fit(Xtrain, ytrain, epochs=epochs, batch_size=batch_size)
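
The model, optimizer, loss_fn, and train_acc objects referenced above are ordinary tf.keras components. A minimal sketch of how they might be set up (the architecture, layer sizes, and hyperparameters here are just illustrative placeholders) would be:

import tensorflow as tf

# a small fully connected classifier -- sizes are illustrative
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()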

Of course, the fit() method has many other parameters as well, but even at its most complex it remains a single line call. And this is probably all that is needed for most simple cases. However, as networks get slightly more complex, with maybe multiple models or loss functions, or custom update rules, the only option for Keras used to be to drop down to the underlying Tensorflow or Theano code. In these situations, Pytorch appears really attractive, with the power, simplicity, and readability of its training loop.
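
The model, loss_fn, optimizer, and train_acc objects used by the loop below are plain Pytorch objects. One possible setup -- the architecture, sizes, and the little accuracy helper are illustrative placeholders, not code from my notebooks -- might be:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# a small fully connected classifier -- sizes are illustrative
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_acc(logits, labels):
    # fraction of correct argmax predictions in the batch
    return (logits.argmax(dim=1) == labels).float().mean()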

dataloader = DataLoader(TensorDataset(Xtrain, ytrain), batch_size=batch_size)
for epoch in range(epochs):
    for batch in dataloader:
        X, y = batch
        logits = model(X)
        loss = loss_fn(logits, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # aggregate metrics
        train_acc(logits, y)

        # evaluate validation loss, etc.

However, with the release of Tensorflow 2.x, which includes Keras as its default API through the tf.keras package, it is now possible to write an almost identical training loop with Keras and Tensorflow as well.

dataset = tf.data.Dataset.from_tensor_slices((Xtrain, ytrain)).batch(batch_size)
for epoch in range(epochs):
    for batch in dataset:
        X, y = batch
        with tf.GradientTape() as tape:
            logits = model(X)
            loss = loss_fn(y_pred=logits, y_true=y)
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # aggregate metrics (Keras metrics take y_true first)
        train_acc(y, logits)

In both cases, developers accept some amount of boilerplate in return for additional power and flexibility. The approach taken by each of the three Pytorch add-on libraries I listed earlier, including Pytorch Lightning, is to provide a Trainer object. The trainer models the training loop as an event loop with hooks into which specific functionality can be injected as callbacks; functionality in these callbacks is executed at specific points in the training loop. A partial LightningModule subclass for our use case would look something like the snippet below; see the Pytorch Lightning documentation or my code examples below for more details.

class MyLightningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # same as Pytorch nn.Module subclass __init__():
        # define layers, self.loss_fn, self.train_acc, self.optimizer, etc.

    def forward(self, x):
        # same as Pytorch nn.Module subclass forward()
        ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.loss_fn(logits, y)
        acc = self.train_acc(logits, y)
        return loss

    def configure_optimizers(self):
        return self.optimizer

model = MyLightningModel()
trainer = pl.Trainer(gpus=1)
trainer.fit(model, dataloader)

If you think about it, this event loop strategy used by Lightning's trainer.fit() is pretty much how Keras manages to compress its training loop into a single model.fit() call as well, with fit()'s many parameters acting as the callbacks that control the training behavior. Pytorch Lightning is just a bit more explicit (and okay, a bit more verbose) about it. In effect, both libraries have solutions that address the other's pain points, so the only reason you would choose one or the other is personal or corporate preference.
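
To make the parallel concrete, here is roughly how the same knobs -- number of epochs, early stopping -- are expressed in each library. The monitored metric, patience, and validation split below are just illustrative values:

# Keras: behavior is controlled through fit() parameters and callbacks
# (model here is the compiled Keras model from the earlier example)
from tensorflow.keras.callbacks import EarlyStopping as KerasEarlyStopping

model.fit(Xtrain, ytrain, epochs=epochs, batch_size=batch_size,
          validation_split=0.1,
          callbacks=[KerasEarlyStopping(monitor="val_loss", patience=3)])

# Pytorch Lightning: the same knobs live on the Trainer
# (model and dataloader are the LightningModule and DataLoader from above)
from pytorch_lightning.callbacks import EarlyStopping as PLEarlyStopping

trainer = pl.Trainer(gpus=1, max_epochs=epochs,
                     callbacks=[PLEarlyStopping(monitor="val_loss", patience=3)])
trainer.fit(model, dataloader)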

In addition to hooks for each of the training, validation, and test steps, there are hooks that are called at the end of each step and at the end of each epoch, for example training_step_end() and training_epoch_end(). Another nice side effect of adopting something like Pytorch Lightning is that you get some of the default functionality of the event loop for free. For example, logging is done to Tensorboard by default, and progress bars are handled by TQDM. Finally (and this is the raison d'être for Pytorch Lightning, from the point of view of its developers), it helps you organize your Pytorch code.
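
As a sketch of what those hooks look like in practice, here are the validation hooks added to the class from before (the training_* hooks follow the same pattern); the metric names and the assumption that self.loss_fn was defined in __init__ are illustrative:

class MyLightningModel(pl.LightningModule):
    # ... __init__, forward, training_step, configure_optimizers as before ...

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        val_loss = self.loss_fn(logits, y)
        # logged values go to the default Tensorboard logger
        self.log("val_loss", val_loss)
        return val_loss

    def validation_epoch_end(self, outputs):
        # outputs is the list of values returned by validation_step this epoch
        self.log("val_loss_epoch", torch.stack(outputs).mean())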

To get familiar with Pytorch Lightning, I took three of my old notebooks, each training one major type of Neural Network architecture (from the old days) -- fully connected, convolutional, and recurrent -- and converted them to use Pytorch Lightning. You may find them useful to look at, in addition to Pytorch Lightning's extensive documentation; I have included links to them below.