Neural networks are often the best-performing classifiers in machine learning tasks. Unfortunately, their results can be notoriously hard to interpret, their training process can be hard to debug, and comparing different neural network models is non-trivial. We present a way to visualize the responses of a neural network using a somewhat-forgotten visualization technique called the grand tour.
There are a number of great videos and blog posts that explain what a neural network is (for example, the one made by 3Blue1Brown).
There are also numerous sources explaining different aspects of the internals of a neural network.
Among the various aspects of "understanding a neural network", in this work we demonstrate a tool for understanding the resulting behavior and performance of neural networks. We first give some necessary background on neural networks. Then we explain the technique we adapt, called the grand tour, and describe the novel interaction we added on top of it for direct user manipulation.
At a very high level, a neural network is a black box that answers a very specific type of question. To take a classic example, if we want a machine that helps us read handwritten digits, a neural network can do it. The process of automatically tuning a neural network for a specific use (in this case, recognizing digits) is called training. A complete training round during which the network sees all training examples once is often referred to as an epoch.
As you will see in a later demonstration, the trained network may make incorrect claims about some images. This is mostly because the machine only sees a finite number of training examples (the training dataset), but we test it with new, unseen examples. To measure how the machine performs on real data, we often prepare a separate testing dataset that the machine has never seen during training. We could feed our testing examples into the neural network one at a time to see which of them get a wrong answer. However, to diagnose overall performance, we should not test only one example at a time; instead, we usually look at a summary over many testing examples. Following this principle, we group all testing images by a pair of their properties - (true_class, predicted_class). Presumably the network performs well on some easy classes but poorly on other, harder classes. If that is the case, we will see interesting counting patterns among the groups. If we count the number of images that fall into each group, we end up with what is called a confusion matrix. See the figure below for an example confusion matrix for our classifier, shown as a heatmap. The more yellow a cell is, the more image examples fall into the corresponding group, i.e. the higher the count. The color map is consistent through epochs, normalized by the maximum of the off-diagonal entries from epoch 25 to epoch 99. Under this design, most diagonal entries are likely to saturate, because the count of correct predictions is usually larger than the count of wrong predictions in any off-diagonal cell. You can click the play button or drag the slider handle to see how the network's performance improved throughout the training epochs. Hovering over a cell shows the actual count for that group.
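To make the bookkeeping concrete, here is a minimal sketch, in plain JavaScript, of how such a confusion matrix could be tallied from two parallel arrays of labels (the function and variable names are illustrative, not taken from this blog's actual source):

```js
// counts[t][p] = number of test examples whose true class is t and predicted class is p.
function confusionMatrix(trueLabels, predictedLabels, numClasses = 10) {
  const counts = Array.from({ length: numClasses }, () => new Array(numClasses).fill(0));
  for (let i = 0; i < trueLabels.length; i++) {
    counts[trueLabels[i]][predictedLabels[i]] += 1;
  }
  return counts;
}

// Example: three digit-7 images, one of which the network misreads as a 1.
console.log(confusionMatrix([7, 7, 7], [7, 1, 7])[7]); // row 7: [0, 1, 0, 0, 0, 0, 0, 2, 0, 0]
```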
We can see interesting things happening. For example, at epoch 19 on the MNIST dataset, we notice that the network incorrectly predicts many digit-7 images as digits 0, 1, 2, 3, and 9. However, by epoch 21, those wrong predictions are fixed. Although the confusion matrix gives a general summary of how our neural network behaves, it gives us no means of investigating the behavior of individual images, because we lose the information about how each image contributed to the summary. That is why a simple confusion matrix can never answer questions like "which images are hard to predict?". We would benefit from a more detailed, instance-by-instance look at our network.
For the purpose of understanding the internals of a neural network, we have to open the black box and look at how it works on individual images.
Similar illustrations exist elsewhere.
Mathematically, a neural network can be represented as a function. To make a neural network well modularized and easy to build and train, this function is internally composed of multiple simpler building-block functions, which are often called layers in the neural network context. These building-block functions serve different purposes. The labels on each colored rectangle denote which kinds of building-block functions were used in our network. When an example image is passed through the network, each building-block function transforms it into a different numerical representation, and as the functions work in a chain, the image is transformed from one representation to another. The representations can take the form of a scalar, a vector, or a tensor. In the diagram above, we illustrate some of these representations as gray-scale heatmaps: the whiter a pixel is, the larger the value stored in the corresponding entry. To make better use of the drawing space, we reshaped these representations to be closer to squares. A representation worth noting is the 10-vector produced by the last softmax layer. In this layer, the network is trained to produce a one-hot encoding of the image's class label, so normally only the entry corresponding to the true class of the image lights up. For example, a digit-0 image should show a white spot in the first column, first row, while a digit-1 image should light up the cell in the first column, second row.
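To make the idea of a chain of building-block functions concrete, here is a purely schematic sketch in JavaScript; the layer sizes and weights are made up for illustration and do not correspond to the architecture in the figure above:

```js
// Schematic only: each "layer" is a function mapping one representation to the next.
const relu  = v => v.map(x => Math.max(0, x));                  // elementwise nonlinearity
const dense = (W, b) => x =>                                    // fully-connected layer y = Wx + b
  W.map((row, i) => row.reduce((sum, w, j) => sum + w * x[j], b[i]));

// A toy 2-layer network acting on a 3-pixel "image" (weights are invented, just to run).
const layer1  = dense([[0.2, -0.1, 0.4], [0.0, 0.3, -0.2]], [0.1, 0.0]); // 3 -> 2
const layer2  = dense([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]);            // 2 -> 2
const network = x => layer2(relu(layer1(x)));                            // the composition

console.log(network([0.9, 0.1, 0.5])); // the representation produced for this "image"
```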
One nice property of the softmax layer is that each entry is non-negative and all entries sum to 1. This property gives us an attractive interpretation of this layer: for any given image, the final softmax layer represents the network's confidence that the image belongs to each class. Instead of looking at one image at a time, we would like to see the distribution of many images in this nicely interpretable softmax layer. To see how this 10-vector is activated by a set of images, we turned to an old technique called the grand tour to reveal the spatial relations among its dimensions.
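As a quick sanity check of those two properties, here is a minimal softmax sketch (again in plain JavaScript; the real layer operates on the network's 10 logits):

```js
// Softmax: exponentiate and normalize, so every entry is non-negative and they sum to 1.
function softmax(logits) {
  const maxLogit = Math.max(...logits);                 // subtract the max for numerical stability
  const exps = logits.map(z => Math.exp(z - maxLogit));
  const total = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / total);
}

const confidences = softmax([2.0, -1.0, 0.5]);
console.log(confidences);                             // roughly [0.79, 0.04, 0.18]
console.log(confidences.reduce((a, b) => a + b, 0));  // 1 (up to floating point error)
```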
In the grand tour, we regard a data point with n attributes as a point in an n-dimensional Euclidean space. The grand tour then smoothly rotates all of these points in that space and orthogonally projects them onto the 2D screen; as the rotation changes over time, the animation eventually shows the data from (almost) every possible viewing direction, so no direction in the data stays permanently hidden.
Thinking in terms of what the grand tour does to high-dimensional data points, we should be able to examine the neuron activations of any layer in a deep neural network. However, the final softmax layer is the only layer whose axes we can directly associate with a meaning: the i-th axis corresponds to the network's confidence that the input belongs to the i-th class.
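There are several ways to implement a grand tour; one simple "torus" style, sketched below in JavaScript with illustrative names, spins every pair of coordinate planes at its own speed and then keeps only the first two coordinates for the screen (this is a sketch of the general idea, not this blog's actual implementation):

```js
// One frame of a torus-style grand tour: rotate each (i, j) coordinate plane by a small,
// plane-specific angle, then project onto the first two coordinates for drawing.
function grandTourStep(points, dim, t) {
  const rotated = points.map(p => p.slice()); // copy so we rotate fresh data every frame
  let plane = 0;
  for (let i = 0; i < dim; i++) {
    for (let j = i + 1; j < dim; j++) {
      // Each plane spins at its own (irrational-looking) speed so the tour never repeats.
      const angle = t * 0.1 * Math.sqrt(plane + 2);
      const c = Math.cos(angle), s = Math.sin(angle);
      for (const p of rotated) {
        const xi = p[i], xj = p[j];
        p[i] = c * xi - s * xj;
        p[j] = s * xi + c * xj;
      }
      plane++;
    }
  }
  // Orthogonal projection onto the screen: drop all but the first two coordinates.
  return rotated.map(p => [p[0], p[1]]);
}

// Usage: points is an array of 10-dimensional softmax vectors; t advances every animation frame.
// const screenXY = grandTourStep(softmaxActivations, 10, frameCount);
```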
We experimented with three common 10-class datasets: the MNIST digits, fashion-MNIST, and CIFAR-10.
Our models, as shown in the network structure figure, are much simpler in structure than the state-of-the-art models used for the same tasks.
We trained for 99 epochs, computed the grand tour of the activations in the final softmax layer for every snapshot, and, in the animation, linearly interpolated corresponding points between two consecutive epochs.
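The per-frame blend between two recorded epochs is just a linear interpolation of each point's coordinates; a small sketch (the array layout and names are assumptions):

```js
// Blend each image's softmax vector between two consecutive epoch snapshots.
// epochA and epochB are arrays of 10-vectors for the same test images, in the same order;
// alpha moves from 0 (epoch A) to 1 (epoch B) as the animation progresses.
function interpolateEpochs(epochA, epochB, alpha) {
  return epochA.map((pointA, i) =>
    pointA.map((value, d) => (1 - alpha) * value + alpha * epochB[i][d])
  );
}
```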
For this classification task, our neural networks aim to match an image to a 'one-hot' encoding of its corresponding class after the softmax layer.
Naturally, if we treat the one-hot encoding as a 10-dimensional vector, in the ideal case we want the neural network to cluster instances of each class tightly around the corresponding one-hot vector, i.e. around a corner of the 10-simplex.
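Concretely, the target for an image of class k is the 10-vector that is 1 at position k and 0 everywhere else; a one-line sketch:

```js
// One-hot encoding of a class label: a 10-vector with a single 1 at the label's position.
const oneHot = (label, numClasses = 10) =>
  Array.from({ length: numClasses }, (_, i) => (i === label ? 1 : 0));

console.log(oneHot(1)); // [0, 1, 0, 0, 0, 0, 0, 0, 0, 0] - the ideal output for every digit-1 image
```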
First, we can observe different levels of difficulty across the datasets by looking at their grand tour plots. The MNIST digit dataset is known to be the easiest of the three. When looking at the MNIST dataset (click this text to see it), we see that most of the digits are pushed away from the origin and directly toward a corner of the 10-simplex, while on the other two datasets (fashion-MNIST, CIFAR-10) the pushing is not as extreme and more points live between two corners. This reveals that the inputs confused the neural network model more on the two challenging datasets.
Next, we can easily locate 2-way and 3-way confusions in the trained fashion-MNIST and CIFAR-10 models. When we plot the grand tour of the fashion-MNIST dataset, we find that the model confuses sandals with sneakers and sneakers with ankle boots, but not sandals with ankle boots: we can see two lines in the grand tour plot filled by sandal and sneaker examples and by sneaker and ankle boot examples, while fewer points live close to the midpoint between the sandal and ankle boot classes. Within the same dataset, we can also see a triangular plane filled in by examples from the pullover, coat, and shirt classes. This means that for some of those images the model has similar confidence levels across the three classes, so the images span the whole interior of a triangular face. From this observation we can confirm that our model struggles to distinguish these three categories, which makes sense to us because pullovers, coats, and shirts have similar shapes and are hard to tell apart, even for humans. When we saw that all trouser examples cluster tightly around the trousers corner (hover over the legend to confirm it), we hypothesized that the trousers class would have the best classification precision among all classes. Indeed, among the 1000 test examples we show in the grand tour, the trousers class has the greatest precision, 0.981 (103/105). Among all 10000 testing examples of fashion-MNIST, the trousers class is also in the top 2 in terms of prediction precision (precision 0.955; the best classified class, sandal, has precision 0.959). In the CIFAR-10 dataset we can see confusion between dogs and cats, and between airplanes and ships. However, since our CIFAR-10 model was not trained as well as the MNIST and fashion-MNIST models, the confusion is not as obvious.
Moreover, looking at the temporal dynamics of training gives us some clues about how the network was trained. When we look at the MNIST dataset, the images move directly toward the axes of their true classes and stabilize after about 50 epochs. On the other hand, the test images in CIFAR-10 keep oscillating between two or more classes, hinting that the model cannot fit all examples at once and suggesting that it may be stuck in a local optimum or be over-fitting.
Although we have observed interesting phenomena with the grand tour, there are other ways of looking at the same kind of data, some of which we compare against in a later section.
In addition to the original grand tour, we incorporate a novel handle-dragging interaction into our plot. When the user drags any of the colored bubbles at the tips of the class axes, the corresponding axis moves to the desired place. With this interaction, while the grand tour is animating, the user suggests a different starting point for the animation; while the grand tour is paused, the user has full control over the linear projection through direct manipulation of the axes.
In the next section, we will show small multiples of random linear projections that serve as key frames of the grand tour, instead of the actual animation. Animations are sometimes not ideal, as viewers have to stare at the playback for a long time in order to digest the important patterns revealed by the time dimension.
However, in our specific case, we value the animation a little more than static plots. Since the projection is linear, points that form a straight line in the high-dimensional space still form a straight line in the low-dimensional projection. The converse does not hold: because the mapping goes from a high-dimensional space to a lower one, a straight line on screen may not correspond to a real line in the data - it may be an illusion created by the information-losing projection. With a static plot this doubt can never be eliminated, since we only see one projection. With an animation that shows many projections, it is much easier to visually infer and verify the existence of a line in the high-dimensional space: as the animation plays, the probability that unaligned data points keep forming a line on the 2D screen decays monotonically toward 0. The animated view therefore acts as a kind of visual hypothesis test (with the null hypothesis being "the line I see is created by the projection, not the data"). As the figure in the next section may convince you, it is much harder to identify lines with high confidence in static plots than in the grand tour animation, let alone more complicated structures seen through a single projection.
The math behind the axis dragging is straightforward. The key idea is that when the user drags an axis tip, they directly specify where the linear transformation should map that axis on the screen. However, this specification may not correspond to a valid orthogonal projection. Therefore, each time the user makes a small change on the canvas, we pull the specification back to the nearest orthogonal matrix.
As a real algorithm, let us represent data points as row vectors with n entries, so that projecting a point means multiplying it by an n-by-n rotation matrix and keeping the first two coordinates. When the user drags the tip of the i-th axis, we add the on-screen displacement to the first two entries of the i-th row of this matrix.
Then, to make the result a valid rotation matrix, we simply apply the Gram-Schmidt process to the rows of the new user-specified matrix.
The JavaScript code running under this blog presentation implements exactly this procedure.
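The actual source running in this page is more involved; as a rough sketch of the key step (the function names and the row-major matrix layout here are assumptions, not the blog's real code), the drag-then-orthonormalize update could look like:

```js
// Pull a user-edited matrix back to a nearby orthogonal one by Gram-Schmidt on its rows.
function gramSchmidtRows(M) {
  const rows = M.map(r => r.slice());
  for (let i = 0; i < rows.length; i++) {
    // Subtract projections onto the already-orthonormalized rows above.
    for (let j = 0; j < i; j++) {
      const dot = rows[i].reduce((s, v, k) => s + v * rows[j][k], 0);
      rows[i] = rows[i].map((v, k) => v - dot * rows[j][k]);
    }
    // Normalize to unit length.
    const norm = Math.sqrt(rows[i].reduce((s, v) => s + v * v, 0));
    rows[i] = rows[i].map(v => v / norm);
  }
  return rows;
}

// When the user drags the tip of axis i by (dx, dy) on screen, nudge that row of the
// grand tour matrix and re-orthonormalize so it stays a valid rotation.
function applyAxisDrag(matrix, i, dx, dy) {
  matrix[i][0] += dx;
  matrix[i][1] += dy;
  return gramSchmidtRows(matrix);
}
```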
In this section, we plot the final softmax layer with a number of other dimensionality reduction methods: t-SNE, UMAP, PCA, and random linear projections.
In addition, we use this section to highlight a shortcoming of our present technique, namely that it only shows a single projection at a time. With multiple plots, we can use linked brushing across multiple views: selecting points in one region of a plot highlights the corresponding points in the other projections, which helps users build an intuitive understanding of each projection.
Both t-SNE and UMAP are nonlinear dimensionality reduction methods, so the structures they show can be harder to attribute to the data itself rather than to the embedding algorithm.
PCA is a natural method for linear dimensionality reduction, since it picks the linear projection that preserves the most variance in the data. In the case of data from machine-learning tasks, however, we can expect the distribution of the data to be essentially spherical, since the class vectors are orthogonal to one another in the dataset. As a result, even though the projections are interpretable and (mostly) consistent through training epochs, the principal components tend not to be informative. Random projections tend to suffer from similar issues.
On the other hand, the grand tour offers the chance to see many projections smoothly animated. Combined with user interaction, it enables analysts to quickly select projections of interest that highlight important features in the data. In the projections we have selected for this comparison, those features are the two- and three-way confusions in the classifiers.
This is best illustrated by the fashion-MNIST dataset: while the three-way confusion among sandals, sneakers, and ankle boots is most straightforwardly seen in the manual projection we picked with the grand tour, it is much harder to claim the same thing when looking at t-SNE, UMAP, or PCA. UMAP captures this interesting confusion by grouping points from these three classes into an isolated cluster, but we cannot claim with confidence that this is indeed a feature of the data rather than a feature of the UMAP embedding algorithm.
In the MNIST dataset, we observed differences in the "gotcha" moments at which the network learns each of the 10 classes. For example, digit 1 is only learned around epoch 14, far later than the majority of the classes. This interesting phenomenon was first observed in our animated grand tour plot, while it is difficult to see, or even verify, in the other plots, especially t-SNE and PCA.
Theoretically, we could use the same technique to visualize any hidden layer in the network; however, there are a few difficulties and technical challenges we hope to tackle in future work. First, from the perspective of understanding neural networks, an intermediate hidden layer carries less contextual meaning than the softmax layer. The 10 dimensions of the softmax layer have a well-defined meaning - they represent the network's 'confidence' in each class. If we were to plot a hidden layer with, for example, 100 neurons, we would have no clue what those 100 dimensions really represent. With some special choices of hidden layers, though, there may be other stories to tell: for example, the grand tour could be directly useful for observing the latent space of an autoencoder. Another technical issue with visualizing more than 10 dimensions is the increased computational cost of plotting the grand tour. In our experiments, with 1000 data points of more than 50 dimensions, our current implementation could not draw the grand tour animation at 60 fps. In such cases dimensionality reduction might be useful, but for the time being we do not know whether it is the only or best choice.
In addition to comparing two neural networks, we are also interested in whether different runs of the same network structure fall into distinct categories. Do networks get trapped by the same local minima? How many such "traps" are there? Can we avoid them by improving the distribution of the networks' initial weights? These questions are of great interest to us, and we would like to search for clues with the grand tour.
We host a separate web page for supplementary material here:
[link]
This material includes the mathematics of the grand tour and further comparisons with other dimensionality reduction techniques.