
Training a CIFAR-10 classifier in the cloud using TensorFlow and Google Colab

Google Colab is a free, cloud-based Jupyter notebook environment that runs entirely on Google Cloud infrastructure. You can create and run any Jupyter notebook you want directly in the browser. At the time of writing, Google generously gives each user a free Tesla K80 GPU with 12 GB of memory for up to 12 hours at a time for small-scale, private machine learning projects.

In this article, we will write a Jupyter notebook that trains a simple object classifier on images from the CIFAR-10 dataset. The classifier uses the TensorFlow Keras API, an easy-to-use abstraction layer on top of TensorFlow that greatly simplifies machine learning programming while preserving the performance of bare-bones TensorFlow.

Prerequisites

All you need to use Google Colab is a Google account.

Data set

We want to train a classifier on the well-known CIFAR-10 data set. It consists of 60,000 images of everyday objects and their corresponding classes, namely:

airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.

The images have a resolution of 32 x 32 pixels with 3 color channels (RGB). The official release provides 50,000 training images and 10,000 test images; here, the 10,000 held-out images are further split into a validation set (5,000 samples) and a test set (5,000 samples).
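A minimal sketch of the validation/test split, assuming the held-out images arrive as NumPy arrays (for example from `tf.keras.datasets.cifar10.load_data()`, which returns `(x_train, y_train), (x_test, y_test)`). The function name `split_validation_test` and the random shuffling are illustrative assumptions; the notebook may split the data differently.

```python
import numpy as np

def split_validation_test(x_holdout, y_holdout, n_val=5000, seed=0):
    """Split held-out CIFAR-10 images into a validation part and a test part.

    Hypothetical helper. x_holdout has shape (N, 32, 32, 3); y_holdout holds
    integer labels of shape (N,) or (N, 1).
    """
    rng = np.random.default_rng(seed)
    # Shuffle once so both halves contain a mix of all classes.
    idx = rng.permutation(len(x_holdout))
    val_idx, test_idx = idx[:n_val], idx[n_val:]
    validation = (x_holdout[val_idx], y_holdout[val_idx])
    test = (x_holdout[test_idx], y_holdout[test_idx])
    return validation, test
```

With the 10,000 CIFAR-10 test images as input, this yields two disjoint sets of 5,000 images each. Fixing the seed keeps the split reproducible across notebook restarts.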

Network architecture

Since we are being lazy, we use the handy tensorflow.keras API, which provides useful abstractions for specifying a suitable network architecture. We use a variant of the ResNet architecture, which achieved groundbreaking accuracy in image recognition thanks to a novel approach called shortcut connections. These connections allow the network to be deeper while reducing the risk of exploding or vanishing gradients, which are typically a nuisance when training very deep neural networks.

The network architecture can be found below:

Resnet architecture implemented using the tf.keras API
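The embedded model code is not reproduced here. To illustrate the idea behind shortcut connections, the following NumPy sketch (not the notebook's tf.keras model) computes the forward pass of one fully connected residual block, y = x + F(x), with the residual branch F(x) = relu(x W1) W2. Real ResNets use convolutions and batch normalization; the dense layers and weight shapes here are simplifying assumptions.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def residual_block(x, w1, w2):
    """Forward pass of a simple residual block: y = x + F(x).

    x:  input batch, shape (batch, d)
    w1: first weight matrix, shape (d, h)
    w2: second weight matrix, shape (h, d), mapping back to d so the
        shortcut (the unchanged input) can be added element-wise
    """
    f = relu(x @ w1) @ w2  # residual branch F(x)
    return x + f           # identity shortcut adds the input back

def plain_block(x, w1, w2):
    """The same two layers without the shortcut, for comparison."""
    return relu(x @ w1) @ w2
```

If the residual branch's second weight matrix is all zeros, the residual block reduces to the identity mapping, while the plain block outputs zeros. A deep stack of residual blocks can therefore always fall back to passing its input through, and the shortcut gives gradients a direct path to earlier layers. In tf.keras, the addition is typically expressed with `tf.keras.layers.Add()` applied to the shortcut and the branch output.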

The Notebook

The Google Colab notebook containing all steps can be found HERE.

To run the notebook on a Google Colab GPU, go to Runtime > Change runtime type and choose Python 3 as the runtime type and ‘GPU’ as the hardware accelerator from the drop-down menus.

Now you can execute the whole notebook by pressing Cmd+F9 (Ctrl+F9 on Windows and Linux).

Results

After training the network for 10 epochs, we reach a validation loss of 1.4956 and a validation accuracy of 72.86%.

Below are plots of the training loss and training accuracy over the number of training steps (with heavy smoothing applied):

Training loss
Training accuracy
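The exact smoothing used for these plots is not stated. One common choice, similar in spirit to TensorBoard's smoothing slider, is an exponential moving average; the sketch below assumes that choice, and the function name `smooth` and the default `weight` are illustrative.

```python
def smooth(values, weight=0.9):
    """Exponential moving average of a sequence of scalars.

    weight must lie in [0, 1): 0 returns the raw values, and values closer
    to 1 give heavier smoothing.
    """
    smoothed = []
    # Start from the first value so the curve does not begin at zero.
    last = values[0] if values else 0.0
    for v in values:
        last = weight * last + (1.0 - weight) * v
        smoothed.append(last)
    return smoothed
```

Applied to the per-step loss values logged during training, a weight around 0.9 to 0.99 removes most of the batch-to-batch noise while keeping the overall trend visible.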

Running the trained classifier on the test set images yields the following confusion matrix:

Confusion matrix of CIFAR-10 classifier
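A confusion matrix like this one can be computed from integer labels with a few lines of NumPy. The sketch assumes the standard CIFAR-10 label order and predictions obtained as, for example, `np.argmax(model.predict(x_test), axis=1)` for a trained Keras `model`. TensorFlow also provides `tf.math.confusion_matrix` for the same purpose.

```python
import numpy as np

# Standard CIFAR-10 label order (label 0 = airplane, ..., label 9 = truck).
CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

def confusion_matrix(y_true, y_pred, num_classes=10):
    """Rows index the true class, columns the predicted class."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    rows = np.asarray(y_true).ravel()  # Keras CIFAR-10 labels have shape (N, 1)
    cols = np.asarray(y_pred).ravel()
    # np.add.at accumulates correctly even when (row, col) pairs repeat.
    np.add.at(cm, (rows, cols), 1)
    return cm
```

The diagonal holds the correct predictions, so `np.trace(cm) / cm.sum()` gives the overall accuracy, and each off-diagonal entry `cm[i, j]` counts images of class `CLASS_NAMES[i]` mistaken for `CLASS_NAMES[j]`.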

Most test images are classified correctly. The more interesting part of the confusion matrix, however, is the mistakes the classifier makes. It occasionally seems to confuse airplanes with birds, automobiles with trucks, or dogs with cats. These three pairs of classes have many similarities. Airplanes and birds both have wings and an elongated body in between. Trucks and automobiles undoubtedly share very similar features, making it hard for the classifier to distinguish between the two classes. The same applies to dogs and cats.
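To find such pairs systematically rather than by eye, one option is to add each off-diagonal entry to its mirror entry and rank the class pairs by total confusion. The helper below is a hypothetical sketch; it takes the confusion matrix (rows = true class, columns = predicted class) and the class names in label order.

```python
import numpy as np

def most_confused_pairs(cm, class_names, k=3):
    """Return the k unordered class pairs with the most mutual confusions.

    cm[i, j] counts images of true class i predicted as class j.
    Each result is (error_count, class_a, class_b).
    """
    cm = np.asarray(cm)
    sym = cm + cm.T  # count i->j and j->i mistakes together
    n = cm.shape[0]
    pairs = [(int(sym[i, j]), class_names[i], class_names[j])
             for i in range(n) for j in range(i + 1, n)]
    pairs.sort(key=lambda p: p[0], reverse=True)
    return pairs[:k]
```

Called with the test-set confusion matrix and the ten CIFAR-10 class names, it lists the class pairs worth inspecting first; the observation above suggests that pairs such as automobile/truck and cat/dog would rank near the top.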

On the NVIDIA Tesla K80, training the classifier should take about 7 minutes per epoch, so the 10-epoch run takes roughly 70 minutes, well within a 12-hour session. With the hyper-parameters selected in the notebook, we already get some pretty good results. However, we are not even close to the state of the art on CIFAR-10, which is currently (as of Nov 2018) at a whopping 96.53% top-1 accuracy.
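Top-1 accuracy counts a prediction as correct only if the single highest-scoring class matches the true label; top-k accuracy relaxes this to any of the k highest-scoring classes. A minimal NumPy sketch, assuming `scores` holds one row of class scores (for example softmax outputs) per image:

```python
import numpy as np

def top_k_accuracy(scores, labels, k=1):
    """Fraction of samples whose true label is among the k highest scores.

    scores: array of shape (N, num_classes)
    labels: integer labels of shape (N,) or (N, 1)
    """
    scores = np.asarray(scores)
    labels = np.asarray(labels).ravel()
    # Column indices of the k largest scores in each row.
    top_k = np.argsort(scores, axis=1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(hits.mean())
```

With `k=1` this is the plain validation or test accuracy reported by Keras for single-label classification, such as the 72.86% above.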

Feel free to play around with the hyper-parameters and with the notebook itself to find a better set of hyper-parameters. You can fork the notebook from https://github.com/jzuern/cifar-classifier.
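A simple way to structure such experiments is an exhaustive grid search. In the sketch below, `train_and_evaluate` is a hypothetical function you would write around the notebook's training code: it takes hyper-parameters as keyword arguments and returns the validation accuracy. The parameter names in the usage comment are assumptions, not the notebook's actual settings. Scoring on the validation set keeps the test set untouched for the final evaluation.

```python
import itertools

def grid_search(train_and_evaluate, grid):
    """Try every combination in `grid`; return (best_score, best_params).

    train_and_evaluate: hypothetical callable taking keyword hyper-parameters
        and returning a validation score (higher is better).
    grid: dict mapping hyper-parameter name -> list of candidate values.
    """
    names = list(grid)
    best_score, best_params = float("-inf"), None
    for values in itertools.product(*(grid[name] for name in names)):
        params = dict(zip(names, values))
        score = train_and_evaluate(**params)
        if score > best_score:
            best_score, best_params = score, params
    return best_score, best_params

# Example (hypothetical names):
# grid_search(train_and_evaluate,
#             {"learning_rate": [1e-2, 1e-3], "batch_size": [64, 128]})
```

Keep the grid small: at roughly 7 minutes per epoch on the K80, every combination is a full training run, so a 3 x 3 grid with 10 epochs each already needs about 10 hours.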

Thanks for reading and happy classifying! 🎉

PhD student in robotics
