Wednesday, November 15, 2017

Observations of a Keras developer learning Pytorch


In terms of toolkits, my Deep Learning (DL) journey started with using Caffe pre-trained models for transfer learning. This was followed by a brief dalliance with Tensorflow (TF), first as a vehicle for doing the exercises on the Udacity Deep Learning course, then retraining some existing TF models on our own data. Then I came across Keras, and like many others, absolutely fell in love with the simplicity, elegance and power of its object-oriented, layer-centric API. Most of the deep learning work I have done over the past couple of years has been with Keras, and much as Larry Wall intended with Perl, it has made easy things easy and hard things possible for me.

A few months ago (seven, according to GitHub), I caught the polyglot bug (but with respect to deep learning toolkits, hence polyDLot), and decided to test the waters by implementing MNIST handwritten digit classification with several different toolkits. One of these toolkits was Pytorch, which promises imperative programming and dynamic computation graphs. Although the MNIST tasks did not exploit Pytorch's ability to build dynamic computation graphs, I thought its positioning was quite unique. In terms of verbosity, it sits somewhere between Keras and TF, but in terms of flexibility and power, it seems to offer all the benefits of a low-level framework like TF with a simpler and more intuitive interface.

Lately, I have been hearing a lot about Pytorch. For example, Allen AI released the Allen NLP library, an open-source Natural Language Processing (NLP) research library built on top of Pytorch. Also, Jeremy Howard, the brain behind the fast.ai MOOCs whose goal is to make neural nets uncool again, has written about his reasons for introducing Pytorch for fast.ai. Part-1 of this course was based on Keras; Part-2 is based on a combination of TF and Pytorch. There are a few other mentions as well, but you get the idea. By the way, Part-2 of the fast.ai course is now open to the public, in case you were waiting for that.

My own interests are around applying DL to NLP, and lately I have been hitting a few Keras limits of my own. So far there has been nothing that a custom layer could not fix, but I figured that it might be worth exploring Pytorch a bit more, especially to get familiar with recurrent models. So that's what I did, and this post describes the experience.

I used the book Long Short Term Memory Networks with Python by Jason Brownlee as my source of toy examples to implement in Pytorch. Jason Brownlee runs the Machine Learning Mastery site and is the author of multiple books on Machine Learning (ML) and DL. He also sends out a regular newsletter with practical tips on ML/DL. Like his other DL books, the book provides Keras code for each of the toy examples.

I built Pytorch implementations for six toy networks, each in its own Jupyter notebook. Each notebook begins with a brief problem description, loosely extracted from the book. I have tried to make it descriptive, but if it is insufficient, please look at the code or read the description in the original. Also, if you are looking for the Keras implementation and information beyond the basic description of these toy examples, I would recommend purchasing the book. In addition to these examples, the book has good advice on the things to watch out for when building recurrent networks. I (obviously) bought a copy, and I think it was definitely worth the price. Here are the examples:

  • 06-echo-sequence-prediction.ipynb - the network is fed a fixed-size sequence of random integers, and trained to predict the integer at a specific (but unknown to the network) index in the input.
  • 07-damped-sine-wave-prediction.ipynb - the network is fed a fixed-size sequence of points from damped sine waves of varying amplitude and periodicity, and trained to predict the value of an unknown damped sine wave at the next time step given a sequence of previous values.
  • 08-moving-square-video-prediction.ipynb - a combined CNN-LSTM network that takes a sequence of images representing the movement of a point from one end of a square to another, and predicts the direction of the movement for a new sequence of images.
  • 09-addition-prediction.ipynb - an encoder-decoder network to solve addition problems represented as a sequence of digits joined by the plus sign. Output is the stringified value of the sum.
  • 10-cumsum-prediction.ipynb - a network that takes a sequence of random values between 0 and 1, and outputs 0 or 1 depending on whether the cumulative sum of the values seen so far is below or above a specific (but unknown to the network) threshold value.
  • 11-shape-generation.ipynb - a network trained on a sequence of real-valued (x, y) coordinate pairs representing a rectangle. The trained network is then used to generate polygon shapes that (should) look like rectangles.
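To make one of these problems a little more concrete, here is a rough sketch of a data generator for the cumulative-sum problem (10-cumsum-prediction). It is my own illustration, not the book's or the notebook's code: the function name make_cumsum_example and the default threshold of n_timesteps / 4 are assumptions.

```python
import torch

def make_cumsum_example(n_timesteps=10, threshold=None):
    """Hypothetical generator for the cumulative-sum toy problem.

    Returns (X, y): X has shape (n_timesteps, 1) with values drawn uniformly
    from [0, 1); y has shape (n_timesteps,) and is 1 at every step where the
    running sum is above the threshold, otherwise 0.
    """
    if threshold is None:
        # Assumption: the book may use a different threshold value.
        threshold = n_timesteps / 4.0
    x = torch.rand(n_timesteps)                          # uniform values in [0, 1)
    y = (torch.cumsum(x, dim=0) > threshold).long()      # 0/1 label per timestep
    return x.unsqueeze(-1), y                            # add a feature axis: (T, 1)
```

Because the inputs are non-negative, the running sum never decreases, so each label sequence is a run of 0s followed by a run of 1s. A batch in (N, T, F) layout can be built by calling the function several times and combining the results with torch.stack.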

And finally, here come the observations I promised in the title of this post. These examples explore Pytorch's capabilities better than the MNIST examples did, but they still don't exploit its ability to create dynamic computation graphs.


  • Models are classes - in Keras, you manipulate pre-built layer classes like Lego blocks using either the Sequential or Functional API. In Pytorch, you set up your network as a class that extends torch.nn.Module. Pytorch provides layers as building blocks similar to Keras, but you typically declare them in the class's __init__() method and define the flow in its forward() method. Because forward() is plain Python, you can use loops, conditionals and other language features instead of only chaining layer calls, which can result in much more expressive flows.
  • TimeDistributed and RepeatVector are missing - in Keras, these two components are used to declare a transformation and distribute it over time, or to replicate a vector to feed into an LSTM. Neither exists in Pytorch, because both are easy to implement in code.
  • Less insulation from component internals - the Keras API hides a lot of the messy details from the casual user. Components have sensible defaults, so you can start simple and tweak more and more parameters as you gain experience. At the other extreme, the TF API gives you complete control (and arguably more than enough rope to hang yourself), forcing you to think of all parameters at the level of matrix multiplication. While Pytorch does not go that far, it does require you to understand in general what is going on inside each component. For example, its LSTM module allows for multiple layers, and a Bidirectional LSTM (achieved by setting the parameter bidirectional=True) is internally represented as two LSTMs, one per direction. You need to know this to set the dimensions of the hidden state (h) correctly, since its first dimension is num_layers * 2 rather than num_layers. Another example is that, for CNN layers, you have to work out the output size after convolution yourself, so that you can set the input size of the fully connected layer that follows.
  • Fitting a model is a multi-step process - fitting a model in Pytorch consists of zeroing the gradients at the start of each batch, running the batch forward through the model, computing the loss, propagating the gradients backward (loss.backward()) and making the weight update (optimizer.step()). I don't know if this process varies enough to justify splitting these steps out. At least in my case, the training loop is practically identical across all my examples.
  • Torch tensors interoperate with Numpy arrays - most Keras developers never have to worry about TF/Theano and Numpy interop, at least not until they start using the backend API. Once they do, though, they have to understand the whole concept of TF sessions in order to move data between TF tensors and Numpy arrays. Pytorch interop is much simpler: there are just two operations, one (tensor.numpy()) to convert a Torch tensor to a Numpy array (for a Variable object, via its .data attribute), and another (torch.from_numpy()) to go in the opposite direction.
  • GPU/CPU mode is not transparent - both Keras and TF transparently use the GPU if one exists. In Pytorch, you have to check for it explicitly, and every time you move between Torch tensors and Numpy arrays, you also have to move GPU tensors to or from the CPU. This clutters up the code and can be a bit error prone if you move back and forth between CPU (for development) and GPU (for deployment) environments, although I suppose we could build wrapper functions and use them instead.
  • Channel first always for images - TF (and by extension Keras) offers the user a choice of representing an image as (N, C, H, W) or (N, H, W, C), i.e., channel-first or channel-last format (here N = batch size, C = number of channels, H = image height, and W = image width). Pytorch is always channel-first. I mention it here because I spent some time trying to figure out why my NHWC format tensors weren't working with my network class.
  • Batch first is optional for RNN input - Unlike Keras and TF, where inputs to RNNs are in (N, T, F), Pytorch requires input as (T, N, F) by default (here N = batch size, T = number of timesteps, F = number of features). However, you can switch over to the more familiar (N, T, F) format by setting the batch_first=True parameter. This simplifies some of the code for batch manipulation during training.
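To illustrate the "Models are classes", "TimeDistributed and RepeatVector are missing" and "Batch first" points together, here is a minimal encoder-decoder sketch loosely shaped like the addition example. The class name Seq2SeqSketch, the layer sizes and the single-layer LSTMs are my own assumptions, not the notebook's code, and it uses the current (post-0.4) Pytorch API.

```python
import torch
import torch.nn as nn

class Seq2SeqSketch(nn.Module):
    """Hypothetical encoder-decoder: one-hot input sequence -> per-step class scores."""

    def __init__(self, n_features, hidden_size, out_steps, n_classes):
        super().__init__()
        self.out_steps = out_steps
        # batch_first=True: inputs and outputs are (N, T, F), as in Keras
        self.encoder = nn.LSTM(n_features, hidden_size, batch_first=True)
        self.decoder = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        # nn.Linear acts on the last dimension, so on an (N, T, H) tensor it
        # behaves like Keras TimeDistributed(Dense(...))
        self.out = nn.Linear(hidden_size, n_classes)

    def forward(self, x):
        # x: (N, T_in, F)
        _, (h, _) = self.encoder(x)        # h: (num_layers, N, H)
        context = h[-1]                    # (N, H): final hidden state of the last layer
        # RepeatVector equivalent: (N, H) -> (N, out_steps, H)
        repeated = context.unsqueeze(1).repeat(1, self.out_steps, 1)
        dec_out, _ = self.decoder(repeated)    # (N, out_steps, H)
        return self.out(dec_out)               # (N, out_steps, n_classes)
```

For example, Seq2SeqSketch(12, 32, 2, 11) applied to a (4, 6, 12) batch returns a (4, 2, 11) tensor of scores, one row of 11 scores per output step.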
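For the bidirectional LSTM case in the "Less insulation from component internals" point, the following sketch shows the hidden state shapes Pytorch expects. The sizes are arbitrary choices for illustration, and the code assumes a current Pytorch version.

```python
import torch
import torch.nn as nn

N, T, F, H, NUM_LAYERS = 4, 7, 3, 16, 2
NUM_DIRECTIONS = 2  # bidirectional=True runs one LSTM per direction

lstm = nn.LSTM(input_size=F, hidden_size=H, num_layers=NUM_LAYERS,
               bidirectional=True, batch_first=True)

# Initial states are (num_layers * num_directions, N, H); batch_first=True
# changes only the input/output layout, not the layout of h and c.
h0 = torch.zeros(NUM_LAYERS * NUM_DIRECTIONS, N, H)
c0 = torch.zeros(NUM_LAYERS * NUM_DIRECTIONS, N, H)

output, (h_n, c_n) = lstm(torch.randn(N, T, F), (h0, c0))
# output: (N, T, 2 * H), both directions concatenated on the last axis
# h_n:    (NUM_LAYERS * 2, N, H); h_n.view(NUM_LAYERS, 2, N, H) separates directions
```

Note that batch_first=True does not apply to h and c, which is an easy thing to trip over when coming from Keras.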
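The "Channel first always for images" point and the CNN output-size point can be shown together. This is a minimal sketch that assumes single-channel 50x50 images stored channel-last; conv_out_size is a hypothetical helper implementing the output-size formula from the Pytorch Conv2d documentation, and the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn

def conv_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension for nn.Conv2d / nn.MaxPool2d."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# A batch of 8 single-channel 50x50 images in channel-last (N, H, W, C) layout,
# e.g. as produced by image-loading code written for Keras.
x_nhwc = torch.randn(8, 50, 50, 1)

# Pytorch conv layers expect channel-first (N, C, H, W): move C to axis 1.
x_nchw = x_nhwc.permute(0, 3, 1, 2).contiguous()

conv = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=2)
pool = nn.MaxPool2d(kernel_size=2)  # stride defaults to kernel_size

side = conv_out_size(50, kernel_size=2)                # 49 after the convolution
side = conv_out_size(side, kernel_size=2, stride=2)    # 24 after pooling
flat_features = 2 * side * side                        # 1152 inputs for the Linear layer

features = pool(conv(x_nchw)).reshape(x_nchw.size(0), -1)   # (8, 1152)
fc = nn.Linear(flat_features, 1)
```

Keras infers the 1152 for you after a Flatten layer; in Pytorch you either compute it like this or discover it by pushing a dummy tensor through the convolutional part.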
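The wrapper functions suggested in the "GPU/CPU mode is not transparent" point could look something like this. It is a sketch, not the notebooks' code: to_tensor and to_numpy are hypothetical helper names, and it assumes Pytorch 0.4 or later, where Variable was merged into Tensor and torch.device exists.

```python
import numpy as np
import torch

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def to_tensor(array, device=device):
    """Hypothetical helper: Numpy array -> tensor on the chosen device."""
    # torch.from_numpy shares memory with the array; .to() copies only if needed
    return torch.from_numpy(array).to(device)

def to_numpy(tensor):
    """Hypothetical helper: tensor (on CPU or GPU) -> Numpy array."""
    # .detach() drops autograd history; .cpu() is a no-op for CPU tensors
    return tensor.detach().cpu().numpy()
```

With these in place, the rest of the code never mentions the device, so the same notebook runs unchanged on a CPU development box and a GPU machine.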

A side effect of the more complex network definition is that I have almost standardized on a debugging strategy that I previously used only for Keras custom layers. The idea is to send a random input signal of the required dimensions into the network and verify that it returns a tensor of the expected dimensions. Very often, this exposes dimension mismatches inside the network, saving you some debugging grief during training.
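A minimal version of this random-input shape check might look like the following. check_output_shape is a hypothetical helper, and it assumes the model returns a single tensor (a model whose forward() returns a tuple would need a small adjustment).

```python
import torch
import torch.nn as nn

def check_output_shape(model, input_shape, expected_shape):
    """Feed a random tensor through `model` and verify the output shape.

    Hypothetical helper; returns the output so it can be inspected further.
    """
    model.eval()            # disable dropout etc. while checking
    with torch.no_grad():   # no need to build an autograd graph
        out = model(torch.randn(*input_shape))
    if tuple(out.shape) != tuple(expected_shape):
        raise ValueError(f"expected {tuple(expected_shape)}, got {tuple(out.shape)}")
    return out
```

For example, check_output_shape(nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1)), (3, 10), (3, 1)) passes, while asking for (3, 2) raises a ValueError. Since the helper calls model.eval(), remember to call model.train() again before training.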

The other thing I wanted to note is that I deliberately used the epoch/batch style of training that I have grown used to with Keras, even though it meant slightly more code. The style I have seen in Pytorch examples is to run a fixed number of iterations instead. Now that I think about it some more, this may be a non-issue, since the training loop appears to be common enough that it can be factored out.
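Here is one way the epoch/batch training loop could be factored into a Keras-style fit() function, following the steps listed in the "Fitting a model" observation. This is a minimal sketch: fit() is a hypothetical helper, not a Pytorch API, and it assumes the whole dataset fits in memory as tensors X and y (torch.utils.data.DataLoader is Pytorch's own tool for batching larger datasets).

```python
import torch
import torch.nn as nn

def fit(model, X, y, loss_fn, optimizer, epochs=10, batch_size=32, shuffle=True):
    """Hypothetical Keras-style fit(): loop over epochs, then over mini-batches.

    Returns the mean training loss for each epoch.
    """
    history = []
    n = X.size(0)
    for epoch in range(epochs):
        model.train()
        order = torch.randperm(n) if shuffle else torch.arange(n)
        total, seen = 0.0, 0
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            xb, yb = X[idx], y[idx]
            optimizer.zero_grad()             # reset gradients from the previous batch
            loss = loss_fn(model(xb), yb)     # forward pass, then compute the loss
            loss.backward()                   # backpropagate
            optimizer.step()                  # update the weights
            total += loss.item() * xb.size(0)
            seen += xb.size(0)
        history.append(total / seen)
    return history
```

Usage mirrors Keras: for instance, history = fit(nn.Linear(3, 1), X, y, nn.MSELoss(), optimizer, epochs=20) on a small regression dataset, where optimizer is built from the model's parameters beforehand.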

And that is all I had for today. I hope you find my observations useful if and when you, as a Keras or TF developer, decide to pick up Pytorch as well.
