One of the holy grails of natural language processing is a generic system for question answering. The Facebook bAbi tasks are a synthetic dataset of 20 tasks released by the Facebook AI Research team that help evaluate systems hoping to do just that.
An example from the second task, Two Supporting Facts (QA2), is below:
Whilst this may seem trivial to you, it can represent quite a challenge even for advanced machine learning models. The bAbi tasks cover far more than trivial comprehension, however: they're intended as prerequisites towards an AI-Complete question answering solution. Each task is designed to exercise a distinct aspect of text understanding and reasoning, testing different capabilities of the learning models. To answer the questions correctly, the models must be able to perform induction, deduction, fact chaining, and more.
Whilst doing well on this task requires advanced tools, we can implement a baseline solution in only a few lines using the Keras machine learning library. Given only 1000 training samples and without any hyperparameter tuning, the results are comparable (and occasionally superior) to those for the LSTM baseline provided in Weston et al.'s Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks.
The question answering code this article refers to is now part of the Keras distribution: babi_rnn.py in the examples directory. When you run the example, Keras will automatically download the dataset and start training.
Best of all, unlike most deep learning tasks, training each of these models only takes a few minutes!
As bAbi is a synthetic dataset, you may ask why we're interested in doing well on it, or why we even created it at all.
Real world data is noisy. Rarely does it provide a clear and simple answer for you to train on. Additionally, even a well curated dataset from the real world is littered with nuance, complexities, and errors.
Instead of relying on real world data, we can challenge machine learning models with tasks generated by a simulation reminiscent of a classic text adventure game. Because the world is artificial, we know its exact state and the exact set of rules by which it runs. Thanks to this, generating training and testing data is trivial.
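To make "generating data is trivial" concrete, here is a minimal sketch of a toy world in the spirit of bAbi's simulation. It is not the actual bAbi generator: the `generate_story` function, the people and places, and the single "moved to" action are all hypothetical. Because the simulator tracks every person's location, it gets the correct answer and the supporting fact for free.

```python
import random

# Hypothetical toy world in the spirit of bAbi's simulation (not the real generator):
# people move between locations, and we always know the true world state.
PEOPLE = ["Mary", "John", "Sandra", "Daniel"]
LOCATIONS = ["bathroom", "hallway", "kitchen", "garden"]


def generate_story(num_moves, rng):
    """Return (sentences, question, answer, supporting_fact_id) for a toy story."""
    sentences = []
    location_of = {}     # exact world state: person -> current location
    last_move_line = {}  # person -> 1-based line ID of their latest move
    for line_id in range(1, num_moves + 1):
        person = rng.choice(PEOPLE)
        place = rng.choice(LOCATIONS)
        location_of[person] = place
        last_move_line[person] = line_id
        sentences.append(f"{person} moved to the {place}.")
    # Only ask about someone who has moved, so an answer always exists.
    person = rng.choice(sorted(location_of))
    question = f"Where is {person}?"
    return sentences, question, location_of[person], last_move_line[person]


if __name__ == "__main__":
    rng = random.Random(0)
    story, question, answer, support = generate_story(5, rng)
    for i, sentence in enumerate(story, start=1):
        print(i, sentence)
    print(question, answer, "supporting fact:", support)
```

The answer is read straight from the simulator's state rather than inferred from the text, which is exactly why labelled data costs nothing to produce.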
As opposed to real world material, the data is also well curated. The vocabulary (set of words) is constrained, the sentences are always well structured (the only noise is the noise we want to challenge the model with), and performance on specific tasks can be tested without other tasks interfering. As we know the exact state of the world and how it got to that point, we can also provide additional helpful information, such as pointing out precisely how the answer can be reached (the supporting facts in bold above).
With the synthetic dataset, all the commonsense knowledge and reasoning required for the test set should be contained in the training set. That way, if a machine learning model fails to solve the task, we know that the challenge lies in the model itself, and not in the data (or lack of data) it was exposed to.
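The supporting facts ship with the data files themselves. Below is a sketch of a parser, assuming the public bAbi text format: every line starts with an ID that resets to 1 when a new story begins, and question lines hold the question, the answer, and the supporting fact IDs separated by tabs. `parse_babi` is a hypothetical helper, not the parser used in babi_rnn.py.

```python
def parse_babi(lines):
    """Parse bAbi-style lines into (story_sentences, question, answer, supporting_ids)."""
    samples = []
    story = {}  # line ID -> sentence, for the story read so far
    for raw in lines:
        if not raw.strip():
            continue
        line_id, text = raw.strip().split(" ", 1)
        line_id = int(line_id)
        if line_id == 1:
            story = {}  # IDs reset to 1 at the start of a new story
        if "\t" in text:
            # Question line: question<TAB>answer<TAB>space-separated supporting IDs.
            question, answer, support = text.split("\t")
            support_ids = [int(i) for i in support.split()]
            context = [story[i] for i in sorted(story)]
            samples.append((context, question.strip(), answer, support_ids))
        else:
            story[line_id] = text
    return samples


example = [
    "1 Mary got the milk there.",
    "2 John moved to the bedroom.",
    "3 Mary travelled to the hallway.",
    "4 Where is the milk? \thallway\t1 3",
]
print(parse_babi(example)[0][3])  # [1, 3]
```

Keeping the supporting IDs around costs nothing, even though the simple RNN baseline does not use them.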
One of the easiest ways to approach a task is to code a baseline solution. Baseline solutions are meant to provide the best "bang for the buck": the best result possible for the minimal amount of work. In this situation, a recurrent neural network (RNN) is the baseline we can turn to.
Recurrent neural networks such as the Long Short Term Memory (LSTM) and the Gated Recurrent Unit (GRU) process a sequence of inputs, updating their internal state as they read more data. This enables them to learn long term dependencies such as bracket matching. Once each word is encoded as a vector, we can treat a sentence as a sequence of words and feed them one at a time into our RNN.
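The "one word at a time" update can be sketched in a few lines of NumPy. This is a plain (Elman) RNN with a tanh update, not an LSTM or GRU, which add gating to the same loop; the weights are random and the sizes are arbitrary assumptions. The point is that a 4-word sentence and a 40-word story both end up as a vector of the same size.

```python
import numpy as np


def rnn_encode(word_vectors, W_xh, W_hh, b_h):
    """Run a plain tanh RNN over a sequence and return its final hidden state.

    word_vectors has shape (seq_len, embed_dim), one row per word. The hidden
    state is updated once per word, so any length maps to a hidden_dim vector.
    """
    h = np.zeros(W_hh.shape[0])
    for x in word_vectors:  # feed the words in one at a time
        h = np.tanh(W_xh @ x + W_hh @ h + b_h)
    return h


rng = np.random.default_rng(0)
embed_dim, hidden_dim = 8, 16  # arbitrary sizes for illustration
W_xh = rng.normal(scale=0.1, size=(hidden_dim, embed_dim))
W_hh = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
b_h = np.zeros(hidden_dim)

short_sentence = rng.normal(size=(4, embed_dim))
long_story = rng.normal(size=(40, embed_dim))
print(rnn_encode(short_sentence, W_xh, W_hh, b_h).shape)  # (16,)
print(rnn_encode(long_story, W_xh, W_hh, b_h).shape)      # (16,)
```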
Instead of implementing these models ourselves, we can use an existing implementation from Keras, a deep learning library for Python.
As an open source project, Keras has strong documentation, an active community, and a good leader. Only three hours after I submitted my pull request for this example code, François Chollet (fchollet) merged it in. The project's rapid turnaround and strong examples make it a good library for getting started with deep learning. Keras also builds on Theano, a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
Our idea is as follows: each task has a story component and a query component. We run an RNN over each of these components, converting each long sequence of words into a fixed-size vector representation that should hopefully encapsulate all of the relevant input. Finally, we feed these two fixed-size vectors into a traditional dense neural network, which can look at the encoded query and the encoded story together and hopefully answer the question correctly.
One additional component is the word vector representation. This is where we hope to convert a word into a fixed-size vector that encapsulates extra knowledge about it. Word vectors aim to capture the meaning behind a word, so that related words are considered similar and act in similar ways.
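Here is a NumPy sketch of this architecture's forward pass, under stated assumptions: plain tanh RNNs stand in for the LSTM/GRU encoders, all sizes are arbitrary, and the weights are random rather than trained. Story and query each get their own embedding and encoder. Their final states are concatenated, and a dense softmax layer produces a probability for every vocabulary entry, treating the answer as a single vocabulary entry.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 30, 8
story_hidden, query_hidden = 16, 12  # arbitrary sizes for illustration


def init_rnn(hidden_dim):
    """Random (untrained) weights for a plain tanh RNN."""
    return (rng.normal(scale=0.1, size=(hidden_dim, embed_dim)),
            rng.normal(scale=0.1, size=(hidden_dim, hidden_dim)),
            np.zeros(hidden_dim))


def encode(word_ids, embedding, params):
    """Look up each word's vector, run the RNN, and return the final state."""
    W_xh, W_hh, b_h = params
    h = np.zeros(W_hh.shape[0])
    for word_id in word_ids:
        h = np.tanh(W_xh @ embedding[word_id] + W_hh @ h + b_h)
    return h


def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()


# Separate embeddings and encoders for the story and the query.
story_embedding = rng.normal(size=(vocab_size, embed_dim))
query_embedding = rng.normal(size=(vocab_size, embed_dim))
story_rnn, query_rnn = init_rnn(story_hidden), init_rnn(query_hidden)
W_out = rng.normal(scale=0.1, size=(vocab_size, story_hidden + query_hidden))
b_out = np.zeros(vocab_size)


def answer_distribution(story_ids, query_ids):
    story_vec = encode(story_ids, story_embedding, story_rnn)
    query_vec = encode(query_ids, query_embedding, query_rnn)
    merged = np.concatenate([story_vec, query_vec])  # one fixed-size joint vector
    return softmax(W_out @ merged + b_out)           # probability per vocab entry


probs = answer_distribution([3, 7, 1, 9, 4, 2], [5, 6, 3])
print(probs.shape, int(probs.argmax()))
```

Training would adjust all of these weights so that the correct answer's probability rises. Keras handles that part, which is why the real code is so short.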
This might be important if, for example, we had the two sentences:
John put down the apple.
John dropped the apple.
but are only interested in answering the question "Does John have the apple?", where the nuance between putting something down and dropping it is unimportant.
Whilst we can learn good word vector representations for the small set of words in these tasks, those vectors wouldn't carry any broader knowledge. For a real task, where knowing extra information might be useful (such as frog ~= toad), we could use existing word vectors trained on billions of words, such as Stanford's GloVe.
Luckily, the code for this is stunningly simple thanks to Keras. You can see the full code at babi_rnn.py, but the core recurrent network code is short enough to include below.
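Using pretrained vectors mostly comes down to copying them into the embedding matrix. A minimal sketch follows. It assumes GloVe's plain-text format (one word per line followed by its space-separated vector components) and a word-to-index mapping that starts at 1, so row 0 can stay zero for padding. The helper names are hypothetical, and the three-dimensional vectors in the example are made up for illustration, not real GloVe values.

```python
import numpy as np


def load_glove_matrix(lines, word_to_index, dim, seed=0):
    """Build an embedding matrix (one row per index) from GloVe-style text lines.

    Assumes each line is "word v1 v2 ... v_dim" and that word_to_index starts
    at 1, so row 0 stays all zeros for padding. Words without a pretrained
    vector keep a small random initialisation and can still be learned.
    """
    rng = np.random.default_rng(seed)
    matrix = rng.normal(scale=0.1, size=(len(word_to_index) + 1, dim))
    matrix[0] = 0.0
    for line in lines:
        word, *values = line.rstrip().split(" ")
        if word in word_to_index and len(values) == dim:
            matrix[word_to_index[word]] = np.asarray(values, dtype=float)
    return matrix


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


# Toy, made-up 3-d vectors standing in for real GloVe lines.
toy_lines = ["frog 0.9 0.1 0.3", "toad 0.8 0.2 0.3", "piano -0.5 0.9 -0.1"]
word_to_index = {"frog": 1, "toad": 2, "piano": 3, "apple": 4}
E = load_glove_matrix(toy_lines, word_to_index, dim=3)
print(cosine_similarity(E[1], E[2]) > cosine_similarity(E[1], E[3]))  # True
```

The resulting matrix could then initialise the embedding layer in place of random weights. "apple" has no pretrained vector here, so it keeps its random row.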
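```python
# Early (0.x) Keras API: layers take explicit input and output sizes.
from keras.models import Sequential
from keras.layers.embeddings import Embedding
from keras.layers.core import Dense, Merge
# RNN, vocab_size and the *_SIZE constants are defined earlier in babi_rnn.py;
# RNN is a recurrent layer class such as an LSTM or GRU.

# Story encoder: embed each word (index 0 is padding), keep only the final RNN state.
sentrnn = Sequential()
sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, mask_zero=True))
sentrnn.add(RNN(EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, return_sequences=False))

# Query encoder: same structure, with its own weights.
qrnn = Sequential()
qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE))
qrnn.add(RNN(EMBED_HIDDEN_SIZE, QUERY_HIDDEN_SIZE, return_sequences=False))

# Concatenate both encodings and predict the answer with a softmax over the vocabulary.
model = Sequential()
model.add(Merge([sentrnn, qrnn], mode='concat'))
model.add(Dense(SENT_HIDDEN_SIZE + QUERY_HIDDEN_SIZE, vocab_size, activation='softmax'))
```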
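In this model, `mask_zero=True` tells the story's embedding layer that index 0 is padding, so word indices must start at 1 and every story must be padded to a common length. Likewise, the final `Dense` layer has `vocab_size` outputs because the answer is predicted as a single vocabulary entry. Below is a plain-Python sketch of that preprocessing. The helper names are hypothetical, and the vectorization code in babi_rnn.py may differ in details such as where the padding goes.

```python
def build_vocab(samples):
    """Map every word to an index starting at 1; index 0 is reserved for padding."""
    words = sorted({w for story, query, answer in samples for w in story + query + [answer]})
    return {w: i + 1 for i, w in enumerate(words)}


def pad(ids, length):
    """Keep the last `length` ids and left-pad with zeros to exactly `length`."""
    ids = ids[-length:]
    return [0] * (length - len(ids)) + ids


def vectorize(samples, word_to_index, story_len, query_len):
    """Turn (story_tokens, query_tokens, answer) samples into padded ids and one-hot answers."""
    vocab_size = len(word_to_index) + 1  # +1 for the padding index 0
    X, Q, Y = [], [], []
    for story, query, answer in samples:
        X.append(pad([word_to_index[w] for w in story], story_len))
        Q.append(pad([word_to_index[w] for w in query], query_len))
        one_hot = [0] * vocab_size
        one_hot[word_to_index[answer]] = 1  # softmax target over the vocabulary
        Y.append(one_hot)
    return X, Q, Y


samples = [(["Mary", "went", "to", "the", "kitchen", "."], ["Where", "is", "Mary", "?"], "kitchen")]
w2i = build_vocab(samples)
X, Q, Y = vectorize(samples, w2i, story_len=8, query_len=5)
print(X[0])  # two leading zeros, then six word ids
```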
In this section, I compare the final results for the Keras based question answering system with the LSTM baseline provided by the Facebook paper.
Given only 1000 training samples and without any hyperparameter tuning, the results are comparable (and occasionally superior) to those for the LSTM baseline provided in Weston et al.'s Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks. The same model is used across all tasks.
Unfortunately, the baseline is just that. Using traditional recurrent neural networks, such as the LSTM or GRU, won't give you substantially better performance even if you scale up the network tremendously. For better results, new neural network configurations have been suggested and used, such as Facebook's Memory Network (further improved in the paper presenting the bAbi dataset), Google's Neural Turing Machine, and MetaMind's Dynamic Memory Networks.
All of these models can take advantage of knowing where the supporting facts are, learning where to focus attention in the input, and performing multiple "lookups" to track down relevant information. I'm hoping to implement a simple version of one of these models in the near future.
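The "attention" and "multiple lookups" idea can be sketched as repeated soft attention over encoded sentences. This is a simplified, hypothetical version loosely in the style of end-to-end memory networks: it reuses one set of sentence vectors for every hop and has no learned weights. It is not an implementation of any of the models named above.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def multi_hop_lookup(memories, query, hops=2):
    """Repeated soft-attention lookups over encoded story sentences.

    memories: (num_sentences, dim) array, one encoded vector per sentence.
    query:    (dim,) encoded question.
    Returns the final query state and the attention weights used at each hop.
    """
    u = query
    attention_per_hop = []
    for _ in range(hops):
        weights = softmax(memories @ u)  # relevance of each sentence to the current query
        read = weights @ memories        # weighted sum of sentence vectors
        u = u + read                     # fold what was read into the next lookup
        attention_per_hop.append(weights)
    return u, attention_per_hop


memories = np.eye(3)  # three toy "sentences"
u, attention = multi_hop_lookup(memories, np.array([0.0, 5.0, 0.0]))
print(int(attention[0].argmax()))  # 1: the sentence that matches the query
```

Because each hop updates the query with what it just read, a second hop can attend to a different sentence. This is how such models chain supporting facts together.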
For now, however, I'm content with my simple baseline.
| Task | FB LSTM Baseline (test accuracy, %) | Keras QA (test accuracy, %) |
|---|---|---|
| QA1 - Single Supporting Fact | 50 | 52.1 |
| QA2 - Two Supporting Facts | 20 | 37.0 |
| QA3 - Three Supporting Facts | 20 | 20.5 |
| QA4 - Two Arg. Relations | 61 | 62.9 |
| QA5 - Three Arg. Relations | 70 | 61.9 |
| QA6 - Yes/No Questions | 48 | 50.7 |
| QA7 - Counting | 49 | 78.9 |
| QA8 - Lists/Sets | 45 | 77.2 |
| QA9 - Simple Negation | 64 | 64.0 |
| QA10 - Indefinite Knowledge | 44 | 47.7 |
| QA11 - Basic Coreference | 72 | 74.9 |
| QA12 - Conjunction | 74 | 76.4 |
| QA13 - Compound Coreference | 94 | 94.4 |
| QA14 - Time Reasoning | 27 | 34.8 |
| QA15 - Basic Deduction | 21 | 32.4 |
| QA16 - Basic Induction | 23 | 50.6 |
| QA17 - Positional Reasoning | 51 | 49.1 |
| QA18 - Size Reasoning | 52 | 90.8 |
| QA19 - Path Finding | 8 | 9.0 |
| QA20 - Agent's Motivations | 91 | 90.7 |
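As a quick check on "comparable (and occasionally superior)", the table's numbers can be summarised in a few lines of Python. The values below are copied from the table; nothing else is assumed.

```python
# Accuracies (%) copied from the results table: (task, FB LSTM baseline, Keras QA).
RESULTS = [
    ("QA1", 50, 52.1), ("QA2", 20, 37.0), ("QA3", 20, 20.5), ("QA4", 61, 62.9),
    ("QA5", 70, 61.9), ("QA6", 48, 50.7), ("QA7", 49, 78.9), ("QA8", 45, 77.2),
    ("QA9", 64, 64.0), ("QA10", 44, 47.7), ("QA11", 72, 74.9), ("QA12", 74, 76.4),
    ("QA13", 94, 94.4), ("QA14", 27, 34.8), ("QA15", 21, 32.4), ("QA16", 23, 50.6),
    ("QA17", 51, 49.1), ("QA18", 52, 90.8), ("QA19", 8, 9.0), ("QA20", 91, 90.7),
]

fb_mean = sum(fb for _, fb, _ in RESULTS) / len(RESULTS)
keras_mean = sum(k for _, _, k in RESULTS) / len(RESULTS)
deltas = {task: round(k - fb, 1) for task, fb, k in RESULTS}  # Keras minus baseline
worse = [task for task, d in deltas.items() if d < 0]
biggest = max(deltas, key=deltas.get)

print(f"mean accuracy: FB {fb_mean:.1f}, Keras {keras_mean:.1f}")
print("Keras below baseline on:", worse)
print("largest gain:", biggest, deltas[biggest])
```

This gives mean accuracies of 49.2 for the FB LSTM baseline and 57.8 for Keras QA. Keras QA falls below the baseline only on QA5, QA17, and QA20, and the largest gain is on QA18 (+38.8 points).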
The results above show a large performance difference between the Facebook LSTM baseline and the Keras QA system on QA18 - jumping from 52 to 91.
Whilst investigating I found that there were numerous duplicated statements and questions in the QA18 training and testing datasets.
This is also an issue in QA17 and possibly others.
Given that there are only 1000 training and 1000 testing data points (which you can confirm by running `grep "?" tasks_1-20_v1-2/en/qa18_size-reasoning_train.txt | wc -l`), repetitions could cause serious issues.
I'll be emailing the maintainers of the dataset once I perform a full analysis, in the hope that this will be fixed for Version 1.3 of the data.
Having finished the duplicate analysis, I can confirm that there are issues in the dataset that need to be fixed. Luckily, the dataset has already been released in a versioned state, though it is unfortunate that the papers published using the dataset do not report which version they used, and that historical versions are not available.
The duplicate analysis was performed by finding only unique (story, query, answer) tuples within the training set and the test set, then finding if there were any intersections between those unique tuples.
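The analysis just described can be sketched as follows, assuming each sample has already been parsed into a (story sentences, query, answer) tuple. `duplicate_report` is a hypothetical helper. It reports the overlap as a fraction of each split's unique samples, which is one reasonable normalisation but not necessarily the one behind the figures below.

```python
def duplicate_report(train, test):
    """Count repeats within each split and unique samples shared across splits."""

    def key(sample):
        story, query, answer = sample
        return (tuple(story), query, answer)  # tuples are hashable, lists are not

    train_unique = {key(s) for s in train}
    test_unique = {key(s) for s in test}
    shared = train_unique & test_unique
    return {
        "train_duplicates": len(train) - len(train_unique),  # repeats inside train
        "test_duplicates": len(test) - len(test_unique),     # repeats inside test
        "shared_unique": len(shared),                        # present in both splits
        "shared_fraction_of_train": len(shared) / len(train_unique),
        "shared_fraction_of_test": len(shared) / len(test_unique),
    }


train = [(["John went to the garden."], "Where is John?", "garden")] * 2
test = [(["John went to the garden."], "Where is John?", "garden"),
        (["Mary went to the kitchen."], "Where is Mary?", "kitchen")]
print(duplicate_report(train, test))
```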
The most extreme issue is that one of the tasks, QA4, has about 13% of the unique samples present in both training and testing.
Another issue is duplicates within the training and testing sets, which is especially problematic in QA15, QA17, and QA18. This matters all the more given that the algorithms are trained on only 1000 samples.
QA15, QA17, and QA18 have numerous duplicates in training and testing.
These issues become even more extreme when the 10,000-sample versions of the bAbi tasks are used.
For details, refer to the full results.