In this post we'll be pulling apart Breaking the Softmax Bottleneck: A High-Rank RNN Language Model by Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, and William W. Cohen.
The softmax operation is fundamentally important for many tasks in machine learning. The softmax allows you to produce a probability distribution over a set of classes - the exact type of thing you might want to do when deciding which digit an image represents, whether the photo is of a cat or dog, and which out of a dictionary of words is the right one to continue a sentence with.
Given that, it's terrifying to think we've been living with a bottlenecked softmax in many situations.
The traditional adage in deep learning tells you to increase your model's parameters until it overfits. The issue is that, for the situation caused by the softmax bottleneck, it's not practical or even possible to fix this by simply throwing more parameters at the problem.
Experimentally: achieves the best results for word level language modeling on the Penn Treebank and WikiText-2 datasets given the same parameter limit when extending the AWD-LSTM-LM codebase (Smerity et al.) and dynamic evaluation code (Krause et al.) (MoS code).
Theoretically: highlights an immense deficiency of existing softmax based models through the lens of matrix factorization, leading to a simple and readily experimentally verified diagnostic - matrix rank.
The traditional softmax function is defined by:

\[P_\theta(x|h) = \frac{\exp(h^T w_x)}{\sum_{x'} \exp(h^T w_{x'})}\]

where \(h\) is our hidden or context vector and \(w_x\) is the word vector for word \(x\).
By taking the dot product between the context / hidden vector \(h\) and the word vector \(w_x\) of a target word we produce what we call a logit. This logit then competes with all other logits, producing a probability. These probabilities form a normalized distribution, meaning that \(\sum_x P_\theta(x|h) = 1\) (i.e. the probability mass is conserved at 1).
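As a minimal sketch of the softmax described above, here is the computation in NumPy (used here instead of PyTorch only to keep the example dependency-light). The function name `softmax_probs` and the toy sizes are my own choices for illustration, not anything from the paper.

```python
import numpy as np

def softmax_probs(h, W):
    """P(x | h) for every word x, given a context h (d,) and word vectors W (V, d)."""
    logits = W @ h                    # one dot product h . w_x per word
    logits = logits - logits.max()    # shift for numerical stability; the result is unchanged
    exp = np.exp(logits)
    return exp / exp.sum()            # normalize so the probabilities sum to 1

rng = np.random.default_rng(0)
d, V = 4, 10                          # toy sizes, chosen only for illustration
h = rng.normal(size=d)
W = rng.normal(size=(V, d))
p = softmax_probs(h, W)
print(p.sum())                        # 1.0 up to floating point error
```

The word with the largest logit always gets the largest probability, since the exponential and the normalization both preserve ordering.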
The paper motivates the deficiency of the current softmax by introducing language modeling as a matrix factorization problem. I'll admit that properly explaining this is beyond my cursory analysis but I'll give the general flow of the argument as best I can, hopefully motivating you to then read the paper in detail.
Imagine you have a set of context vectors \(H \in \mathbb{R}^{N \times d}\), a set of word vectors \(W \in \mathbb{R}^{V \times d}\) (with \(V\) the vocabulary size), and a matrix \(A \in \mathbb{R}^{N \times V}\) that records the real world probability \(P^*(x|c)\) of seeing a specific next word in a given context (in the paper the entries of \(A\) are the log probabilities \(\log P^*(x|c)\)). Essentially we're running our trained model against \(N\) samples from the real world where we know the correct answers.
We want to get as close to the real world probabilities of \(A\) as possible using our approximation \(A' = H W^T\). Since the softmax exponentiates and then normalizes each row of logits, \(H W^T\) only has to match the log probabilities in \(A\) up to a constant per row. This can readily be posed and understood as a matrix factorization problem.
Whenever we multiply matrices, the rank of the product is limited by the ranks of the factors: for a product \(AB\), we have \(\text{rank}(AB) \le \min(\text{rank}(A), \text{rank}(B))\). This means many of the simple rules of matrix factorization come through - specifically, as \(H\) and \(W\) each have only \(d\) columns, the rank of \(H W^T\) is upper bounded by the embedding size \(d\).
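The rank bound is easy to check numerically. The sketch below uses random matrices with toy sizes of my own choosing and measures the rank of the logits \(H W^T\). It also measures the rank of the log probabilities after the softmax normalization, which subtracts one constant per row and so can add at most one to the rank.

```python
import numpy as np

rng = np.random.default_rng(0)
N, V, d = 200, 100, 8                 # toy sizes with d much smaller than V
H = rng.normal(size=(N, d))           # context vectors, one per row
W = rng.normal(size=(V, d))           # word vectors, one per row

logits = H @ W.T                      # shape (N, V)
rank_logits = np.linalg.matrix_rank(logits)

# log softmax: subtract the per-row log normalizer from the logits
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
rank_log_probs = np.linalg.matrix_rank(log_probs)

print(logits.shape, rank_logits, rank_log_probs)   # ranks are bounded by d and d + 1
```

No matter how large \(N\) and \(V\) are made here, the ranks stay pinned to the embedding size.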
If the real world matrix \(A\) is of a higher rank than \(d\) - which is likely the case given the complexity and intricacies of natural language - we're limiting ourselves to a low rank approximation \(A'\) of the real world data. Note that the matrix \(A\) could attain a full rank of \(V\). This is a problem when we want to decide which word to use from our vocabulary of size \(V\) and have \(V \gg d\).
A simple approach would be to just increase \(d\) until it's of comparable size to our vocabulary of size \(V\). Unfortunately the issues are two-fold. First, the number of parameters for the model would grow rapidly as we increased \(d\). The vocabulary of word level language models is easily into the tens of thousands, if not hundreds of thousands, and each word is represented by a word vector of size \(d\). Second, increasing the dimensionality of the word vectors beyond a certain point (usually \(d = 400\)) has not been found helpful in practice. This has been confirmed experimentally in language modeling and machine translation amongst other tasks.
How can we increase the rank of our resulting matrix with minimal work? The paper introduces the mixture of softmaxes (MoS) as a potential solution. This essentially boils down to performing \(K\) different softmaxes and mixing them - with some tricks to avoid excess parameter use.
The mixture of softmaxes can be seen as jointly training an ensemble of \(K\) different models with minimal overhead. An ensemble is just where you train multiple models and average their predictions, a practice which usually outperforms any single model. Jointly training is when all of these ensembled models are trained against a single loss, allowing the models to work together to avoid deficiencies.
All the weights, including the word embedding matrix \(W \in \mathbb{R}^{V \times d} \) (which is the largest of the matrices), are reused by each of the different components. This means that the parameter overhead when introducing the mixture of softmaxes is minimal. The amount of computation, however, increases substantially - it's as if we're computing \(K\) traditional softmaxes.
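The only two weights introduced are, in the paper's formulation, a set of projections \(W_{h,k}\) that turn the context vector \(h\) into \(K\) component context vectors \(h_k = \tanh(W_{h,k} h)\), one for each softmax, and a set of vectors \(w_{\pi,k}\) that produce the mixture weights \(\pi_k = \frac{\exp(w_{\pi,k}^T h)}{\sum_{k'} \exp(w_{\pi,k'}^T h)}\). The final distribution is then \(P_\theta(x|h) = \sum_{k=1}^K \pi_k \frac{\exp(h_k^T w_x)}{\sum_{x'} \exp(h_k^T w_{x'})}\).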
As it's a mixture, and as \([\pi_1, \ldots, \pi_K]\) is the result of a softmax, \(\sum_k \pi_k = 1\) and \(\sum_x P_\theta(x|h) = 1\).
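To make the mixing concrete, here is a rough NumPy sketch of a mixture of softmaxes forward pass for a single context vector. It follows the paper's general recipe (project the context into \(K\) component contexts, compute a softmax for each against the shared word vectors, then mix with softmax-normalized weights), but the names `W_proj` and `W_prior`, their shapes, and the toy sizes are hypothetical choices for illustration rather than the released code.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)    # stability shift
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mixture_of_softmaxes(h, W, W_proj, W_prior):
    """
    h:       (d,)       context vector
    W:       (V, d)     word vectors, shared by all K components
    W_proj:  (K, d, d)  hypothetical per-component context projections
    W_prior: (K, d)     hypothetical projection producing the mixture weights
    """
    pi = softmax(W_prior @ h)                  # (K,)   mixture weights, sum to 1
    h_k = np.tanh(W_proj @ h)                  # (K, d) one context per component
    component_probs = softmax(h_k @ W.T)       # (K, V) one softmax per component
    return pi @ component_probs                # (V,)   weighted average of distributions

rng = np.random.default_rng(0)
d, V, K = 4, 10, 3                             # toy sizes for illustration
p_mos = mixture_of_softmaxes(
    rng.normal(size=d),
    rng.normal(size=(V, d)),
    rng.normal(size=(K, d, d)),
    rng.normal(size=(K, d)),
)
print(p_mos.sum())                             # 1.0 up to floating point error
```

Note that the mixing happens in probability space, after each softmax has been normalized. That ordering is what matters for the rank argument.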
A method that looks superficially very similar but isn't is a mixture of contexts, where we produce \(h^T\) from a weighted mixture of \([h_1^T, \ldots h_K^T]\).
This method is faster - primarily as we don't have to compute \(K\) different softmaxes, and the softmax can be one of the slowest operations - but doesn't actually increase the rank of our resulting matrix. The reason this is a trap is that it appears very close to the mixture of softmaxes but can be rewritten to be equivalent to a standard softmax (specifically by setting \(h' = \sum_{k=1}^K \pi_k h_k\)), so the same rank bound of \(d\) applies.
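Written out, using the same symbols as above, the mixture of contexts collapses as follows:

\[
P_\theta(x|h) = \frac{\exp\left(\left(\sum_{k=1}^K \pi_k h_k\right)^T w_x\right)}{\sum_{x'} \exp\left(\left(\sum_{k=1}^K \pi_k h_k\right)^T w_{x'}\right)} = \frac{\exp(h'^T w_x)}{\sum_{x'} \exp(h'^T w_{x'})}, \qquad h' = \sum_{k=1}^K \pi_k h_k
\]

The mixing happens before the exponential, so the logits are still a single dot product with the word vector. In the mixture of softmaxes the sum over \(k\) sits outside the normalized exponentials, and the log of that sum no longer reduces to a single dot product.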
Within the paper, they demonstrate better results for three different tasks - two within language modeling (Penn Treebank and WikiText-2), the other in dialog (Switchboard). Beyond getting better results, they show that the mixture of softmaxes model produces a resulting matrix with higher rank. The matrix we analyze is the result of taking the probability distribution for each sample in a dataset and stacking them together. Hence, if we have \(V\) words in our vocabulary and \(S\) samples in our dataset, the matrix we analyze is in \(\mathbb{R}^{S \times V}\).
We're going to perform a simple recreation of that here. To illustrate the problem and solution we'll set up a simple PyTorch experiment within a Jupyter Notebook (Colaboratory link).
There are only three components: a linear layer that projects the \(H\) dimensional context vector (this \(H\) is the \(d\) from earlier) to the vocabulary size \(V\), an optional "mixer" that takes the context vector and produces \(K\) different projections of it \([h_1, \ldots, h_K]\), and \(S\) random context vectors as input samples. For this experiment we will assume that each mixture component is equally weighted (i.e. \(\pi_k = \frac{1}{K}\)). None of these are trained, all starting from random initialization.
Within our experiment, we set \(H = 32, V = 1000, S = 2048\) and look at the rank of the resulting matrix that predicts the vocabulary distribution for each context vector. This resulting matrix is \((2048, 1000)\) - i.e. for each sample we have a probability distribution for the 1000 different vocabulary items. The full potential rank of this matrix is 1000 - so that's what we're hoping for with mixture of softmaxes!
Softmax Type | k=1 | k=2 | k=3 | k=4 | k=5 |
---|---|---|---|---|---|
Traditional softmax | 34 | 34 | 34 | 34 | 34 |
Mixture of Contexts (MoC) | 34 | 34 | 34 | 34 | 34 |
Mixture of Softmaxes (MoS) | 34 | 629 | 979 | 995 | 997 |
For those deeply interested, the matrix rank here is computed numerically, so it depends on a tolerance and can be off by a little. I initially expected the 34 to be 32. One plausible accounting, assuming the rank is measured on the log probabilities: the logits contribute at most 32, the bias of the linear layer adds at most one more, and the log normalizer (a different constant subtracted from each row) adds at most one more, giving 34.
As we can see quite readily, both the baseline and mixture of contexts maintain a low rank approximation. The mixture of softmaxes result achieves nearly full rank for the resulting matrix, showing the potential of the technique. Note that these models are not trained at all, simply a result of random initialization, so by default models using the mixture of softmaxes already have additional expressivity in the form of a higher rank result matrix.
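For readers who want to poke at this themselves, below is a rough NumPy re-creation of the experiment. It is a sketch under several assumptions of mine rather than the original notebook: the rank is measured on log probabilities (which is what the paper's analysis uses), the mixer is a `tanh` of a random linear projection, weights are Gaussian with a \(1/\sqrt{d}\) scale, and the function name `log_prob_rank` is made up. Exact numbers will differ from the table above, but the pattern should be the same.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def log_prob_rank(kind, K=1, d=32, V=1000, S=2048, seed=0):
    """Numerical rank of the (S, V) log probability matrix for an untrained model."""
    rng = np.random.default_rng(seed)
    scale = 1.0 / np.sqrt(d)                        # assumed init scale
    X = rng.normal(size=(S, d))                     # random context vectors
    W = rng.normal(size=(V, d)) * scale             # shared output layer weights
    b = rng.normal(size=V) * scale                  # output layer bias
    mixers = rng.normal(size=(K, d, d)) * scale     # assumed form of the "mixer"
    H_k = np.tanh(np.einsum("sd,ked->ske", X, mixers))   # (S, K, d)

    if kind == "softmax":
        log_p = log_softmax(X @ W.T + b)
    elif kind == "moc":                             # average contexts, then one softmax
        log_p = log_softmax(H_k.mean(axis=1) @ W.T + b)
    elif kind == "mos":                             # K softmaxes, then average probabilities
        comp = np.stack([log_softmax(H_k[:, k] @ W.T + b) for k in range(K)])
        log_p = np.log(np.exp(comp).mean(axis=0))
    else:
        raise ValueError(f"unknown kind: {kind}")
    return int(np.linalg.matrix_rank(log_p))

# Smaller sizes than in the post so this runs in well under a second.
for kind in ("softmax", "moc", "mos"):
    print(kind, [log_prob_rank(kind, K=k, d=8, V=100, S=300) for k in (1, 2, 3)])
```

With these settings the softmax and mixture of contexts rows cannot exceed \(d + 2\) (logits, bias, and row-wise normalizer), while the mixture of softmaxes row is free to climb towards \(V\) as \(K\) grows. Passing the default arguments reproduces the sizes used in the post at the cost of a slower rank computation.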
How much of the performance of ensembles may come just from increasing the rank of the prediction matrices? This applies equally for \(K\) different models of the same type as well as dynamic evaluation methods such as Krause et al. style dynamic eval, pointer, or continuous cache methods.
Will we see the same benefits within machine translation as we see in language modeling, where both of the tasks can have softmaxes over a stupendously large vocabulary?
What is the value of \(K\) needed for achieving full rank if we have set values for \((d, V)\)? Is achieving full rank enough or, as in other deep learning tasks, do we want to essentially overparameterize the model? In the paper they set \(K = 15\) but provide no justification or analysis of this.
How are architectures such as the Transformer Network either already employing the mixture of softmaxes inadvertently or could potentially employ it in components such as the Multi-Head Attention mechanism?
What other situations, beyond the softmax over the vocabulary, are we being bitten by these low rank approximations? Is low rank when computing large scale attention via softmax an issue?
Is there any way to minimize the computational overhead whilst maintaining the full rank advantage of the mixture of softmaxes?
The code for the mixture of softmaxes is available on GitHub and extends the AWD-LSTM-LM codebase that my colleagues and I released for PyTorch and the dynamic evaluation codebase of Krause et al.
There are certainly many more opportunities to use this - go forth and prosper =]