MNIST is like the secret handshake we use to introduce people into the world of machine learning. Almost all of us have taken a shot at classifying the squiggly handwritten digits in this dataset at some point in our ML careers.
In this post we dive deeper into a state-of-the-art model that achieves 99.4% validation accuracy on MNIST, and try to figure out which kinds of images leave it flummoxed! We hope this analysis will be useful to people trying to build models with near-perfect MNIST accuracy.
If you’re curious, grab yourself a coffee and join the quest! You can find the Weights and Biases dashboard used to generate these insights here.
I encourage you to dive deeper into your own models’ predictions using W&B! It’s easy to get started! Good luck!
Trends and Summary
The primary results of our exploration are summarized in the table on the left.
- As we can see, the model classifies over 99% of the digits in each class.
- A small minority of the handwritten digits that the model has trouble classifying are truly hard even for humans. For example, if we look at True Class 6 in the table, the misclassified images do reasonably look like 0s, 4s and 8s.
- Most of the other misclassified instances, e.g. those in True Class 8, seem very easy for our human brains to classify (granted, we’ve got a million years of evolution on our side). So why do they trip up the model?
- In the next section, we’ll look at which parts of the image the model focuses on when classifying each image, to gain insight into why it misclassifies these!
Let’s put our data detective hat on and dig a little deeper! We’ll start by plotting the saliency heatmaps, which help us visualize how influential each pixel is with respect to the predicted class – they basically tell us which regions of the image contribute the most to the predicted output.
- On the left, we see the saliency maps for each of the 10 classes. The pixels in red represent positive correlations; the ones in blue negative correlation; and the ones in green have no correlation.
- On the right, we create saliency maps for some of the images that were misclassified by the model.
- Let’s dig deeper into the second image, the 1 misclassified as a 2: the cluster of pixels forming a little tail at the bottom right of the image is a strong indicator of a 2, the class in which these tails are most common.
- Very few 1s in the dataset have that little tail, so the model dutifully classifies the image as a 2.
- This is an excellent insight and one of the ways we can improve the model is by collecting more examples of 1s with this tail.
- Similarly, for the bottom right image, the 7 misclassified as a 9, the model doesn’t have a lot of examples of 7s with a rounded “head”, which is more common in 9s. Adding more examples of 7s with a rounded head would make the model a more discerning classifier of 7s and 9s.
- Saliency maps serve as excellent tools for debugging the model and improving accuracy even for SOTA models!
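If you’d like to generate saliency-style maps yourself, one gradient-free way to approximate them is occlusion: zero out each pixel in turn and measure how much the predicted-class score drops. Here’s a minimal, framework-free numpy sketch of the idea — `toy_predict` and the input image are made up for illustration, not our actual model:

```python
import numpy as np

def occlusion_saliency(predict, image, target_class):
    """Approximate per-pixel saliency by zeroing each pixel and
    measuring the drop in the target-class score."""
    base = predict(image)[target_class]
    sal = np.zeros_like(image, dtype=float)
    for idx in np.ndindex(image.shape):
        perturbed = image.copy()
        perturbed[idx] = 0.0  # occlude one pixel
        sal[idx] = base - predict(perturbed)[target_class]
    return sal

# toy "model": class-0 score is the mean of the top-left 2x2 patch,
# class-1 score is the mean of the bottom-right 2x2 patch
def toy_predict(img):
    return np.array([img[:2, :2].mean(), img[2:, 2:].mean()])

img = np.ones((4, 4))
sal = occlusion_saliency(toy_predict, img, target_class=0)
# only the top-left patch influences class 0, so only those pixels
# get nonzero saliency
```

The heatmaps in the post are gradient-based (how much each input pixel influences the predicted score), but the occlusion variant above conveys the same intuition without any autodiff machinery.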
- Weights and Biases lets you generate a confusion matrix automatically when you set `log_evaluation=True` in your wandb callback!
- This is very helpful for identifying at a glance how the misclassified instances are distributed, and thereby fixing the root causes.
- Hovering over any cell in the confusion matrix on the W&B dashboard also shows you examples of what these misclassified instances look like.
- We can use the confusion matrix to determine what (true_class, predicted_class) pairs we should pay the most attention to.
- Creating saliency maps for the most misclassified instances, as determined by the confusion matrix, and finding patterns of errors in them would allow us to determine the flavor of dataset augmentations we need to undertake to improve the accuracy of the model.
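As a sketch of the idea in that last bullet, here’s how one might build a confusion matrix and rank the most frequent (true_class, predicted_class) error pairs with plain numpy. The toy labels are made up for illustration — on the W&B dashboard this view is generated for you automatically:

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes=10):
    """cm[t, p] counts instances of true class t predicted as p."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

def worst_confusions(cm, k=3):
    """Return the k most frequent off-diagonal (true, predicted) pairs."""
    off = cm.copy()
    np.fill_diagonal(off, 0)  # ignore correct predictions
    flat = np.argsort(off, axis=None)[::-1][:k]
    return [tuple(np.unravel_index(i, cm.shape)) for i in flat]

# toy labels: the 1-as-2 and 7-as-9 confusions dominate
y_true = [1, 1, 1, 7, 7, 9]
y_pred = [2, 2, 1, 9, 9, 9]
cm = confusion_matrix(y_true, y_pred)
pairs = worst_confusions(cm, k=2)
```

These worst pairs are exactly the instances we’d feed into the saliency-map analysis above to decide which augmentations to collect.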
Intermediate Class Activations
To round out the analysis, we analyze the intermediate class activations which are quite useful for understanding what features successive layers of the CNN extract from the input image. The key takeaways from the maps are:
- The first layers retain almost all information from the input image, likely because they’re learning primitive building-block features, like edges, which exist in large parts of the image.
- The deeper layers in the network take inputs from the first layers and combine them to learn edges, corners, angles and shapes. The activations look somewhat more abstract here, and are therefore a lot harder to interpret visually.
- The deepest layers retain the least information about the contents of the image. Instead they start to encode features that correlate with the class the image belongs to. The black squares in the filters correspond to complex patterns learnt by the network that are not present in this image.
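To make the layer-by-layer story concrete, here’s a tiny numpy sketch — not the actual model — showing how a first-layer edge filter produces activations that closely mirror the input, while a second layer combines them into something more abstract:

```python
import numpy as np

def conv2d(img, kernel):
    """Valid-mode 2D cross-correlation, no padding or stride."""
    kh, kw = kernel.shape
    h, w = img.shape
    out = np.zeros((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(img[i:i + kh, j:j + kw] * kernel)
    return out

relu = lambda x: np.maximum(x, 0)

# toy image with a vertical edge: left half dark, right half bright
img = np.zeros((6, 6))
img[:, 3:] = 1.0

# "layer 1": a vertical-edge filter fires exactly along the edge,
# so the activation map is easy to interpret
edge_filter = np.array([[-1.0, 1.0]])
act1 = relu(conv2d(img, edge_filter))

# "layer 2": an averaging filter combines layer-1 activations into a
# blurrier, more abstract map, as in the deeper CNN layers
act2 = relu(conv2d(act1, np.ones((2, 2)) / 4))
```

Inspecting `act1` vs `act2` shows the same trend as the post’s activation maps: early layers track the input closely, later ones smear it into class-relevant abstractions.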
Digging Deeper with Data Frames
- One of the most powerful things Weights and Biases lets you do is dive very deep into the predictions made by your model in the dashboard.
- Here we see all the predictions grouped first by true_class, then by predicted_class. This means we can, for instance, look at all the instances of 9s in the validation dataset and see the distribution of predictions made by the model.
- We can see every instance the model got wrong, how confident it was in its prediction, and the probabilities of the other classes it considered.
- This level of granularity is a delight when it comes to debugging, especially because neural networks can notoriously be such black boxes!
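As a rough sketch of this kind of analysis outside the dashboard, here’s how one might surface the misclassified instances the model was most confident about, using a toy probability table (the numbers are made up for illustration; in W&B this comes from the logged evaluation table):

```python
import numpy as np

# toy prediction table: one row of class probabilities per instance
probs = np.array([
    [0.1, 0.7, 0.2],    # true 1, predicted 1 (correct)
    [0.1, 0.3, 0.6],    # true 1, predicted 2 (wrong, confident)
    [0.2, 0.41, 0.39],  # true 2, predicted 1 (wrong, uncertain)
])
y_true = np.array([1, 1, 2])

y_pred = probs.argmax(axis=1)   # predicted class per instance
conf = probs.max(axis=1)        # model's confidence in that prediction

wrong = np.nonzero(y_pred != y_true)[0]
# sort the misclassified instances by confidence, most confident first:
# these are the errors most worth investigating
worst_first = wrong[np.argsort(-conf[wrong])]
```

Confident mistakes are usually the most informative ones to pull up and inspect, since they point at systematic gaps in the training data rather than borderline ambiguity.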
Lastly, we track the loss and accuracy metrics for the training and validation sets at each epoch. The model’s final validation accuracy was 99.48%. We could potentially have used EarlyStopping here and truncated training at around epoch 15, when we had already reached 99.3% validation accuracy!
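For reference, here’s a framework-free sketch of the patience logic that an early-stopping callback (such as Keras’s `EarlyStopping`) implements. The accuracy curve below is illustrative, not our actual run:

```python
def early_stop_epoch(val_acc, patience=3, min_delta=0.0):
    """Return the epoch at which training would halt: stop once
    `patience` epochs pass with no improvement beyond min_delta."""
    best, best_epoch = -float("inf"), 0
    for epoch, acc in enumerate(val_acc):
        if acc > best + min_delta:
            best, best_epoch = acc, epoch  # new best; reset the clock
        elif epoch - best_epoch >= patience:
            return epoch  # patience exhausted
    return len(val_acc) - 1  # never triggered; trained to the end

# illustrative curve: accuracy plateaus after epoch 3
accs = [0.97, 0.985, 0.991, 0.993, 0.9929, 0.9928, 0.9927, 0.9926]
stop = early_stop_epoch(accs, patience=3)
```

With a plateau like this, the sketch stops three epochs after the best score, saving the remaining training time at a negligible cost in accuracy.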