
Question answering on the Facebook bAbi dataset using recurrent neural networks and 175 lines of Python + Keras

One of the holy grails of natural language processing is a generic system for question answering. The Facebook bAbi tasks are a synthetic dataset of 20 tasks released by the Facebook AI Research team that help evaluate systems hoping to do just that.

An example from the second task, Two Supporting Facts (QA2), is below:

1 John moved to the bedroom.
2 Mary grabbed the football there.
3 Sandra journeyed to the bedroom.
4 Sandra went back to the hallway.
5 Mary moved to the garden.
6 Mary journeyed to the office.
7 Where is the football? office 2 6

Whilst this may seem trivial to you, it can represent quite a challenge even for advanced machine learning models. The bAbi tasks cover far more than trivial comprehension however - they're supposed to represent a prerequisite towards an AI-Complete question answering solution. Each task aims to require a unique aspect of text and reasoning, testing the different capabilities of the learning models. To answer the questions correctly, the models must be able to perform induction, deduction, fact chaining, and more.

Whilst doing well on these tasks requires advanced tools, we can implement a baseline solution in only a few lines using the Keras machine learning library. The results are comparable (and occasionally superior) to those for the LSTM baseline provided in Weston et al.'s Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks, given only 1000 samples and without any hyperparameter tuning.

Try it yourself!

The question answering code this article refers to is now part of the Keras distribution: babi_rnn.py in the examples directory. When you run the example, Keras will automatically download the dataset and start training.

Best of all, as opposed to most deep learning tasks, training these models will only take a few minutes each!

Why a synthetic dataset?

As bAbi is a synthetic dataset, you may ask why we're interested in doing well on it, or why we even created it at all.

Real world data is noisy. Rarely does it provide a clear and simple answer for you to train on. Additionally, even a well curated dataset from the real world is littered with nuance, complexities, and errors.

Instead of relying on real world data, we can challenge the machine learning models using a simulation reminiscent of a classic text adventure game. By using an artificial world we know the exact state the world is in and the exact set of rules by which it runs. Thanks to this, generating training and testing data is trivial.

As opposed to real world material, the data is also well curated. The vocabulary (set of words) is constrained, the sentences are always well structured (the only noise is the noise we want to challenge the model with), and the performance on specific tasks can be tested without other tasks interfering. As we know the exact state of the world and how it got to that point, we can also provide additional helpful information, such as pointing out precisely how the answer can be reached (the supporting facts, given as line numbers after the answer in the example above).
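To make the structure of a sample concrete, here's a minimal sketch of a parser for lines like the QA2 example above. It assumes the raw files separate the question, the answer, and the supporting fact IDs with tabs, and that the line ID resets to 1 whenever a new story starts. The function name `parse_babi` is my own for illustration and isn't taken from babi_rnn.py.

```python
def parse_babi(lines):
    """Parse bAbI-style lines into (story, question, answer, supporting_ids) samples.

    Assumptions: every line starts with an integer ID, the ID resets to 1 at the
    start of a new story, and question lines hold tab-separated question,
    answer, and supporting fact IDs.
    """
    samples = []
    story = []  # (line_id, sentence) pairs seen so far in the current story
    for line in lines:
        line = line.strip()
        if not line:
            continue
        line_id, text = line.split(" ", 1)
        line_id = int(line_id)
        if line_id == 1:
            story = []  # a new story begins
        if "\t" in text:
            question, answer, supporting = text.split("\t")
            supporting_ids = [int(i) for i in supporting.split()]
            # Copy the story so later statements don't leak into this sample
            samples.append((list(story), question.strip(), answer, supporting_ids))
        else:
            story.append((line_id, text))
    return samples


example = [
    "1 John moved to the bedroom.",
    "2 Mary grabbed the football there.",
    "3 Sandra journeyed to the bedroom.",
    "4 Sandra went back to the hallway.",
    "5 Mary moved to the garden.",
    "6 Mary journeyed to the office.",
    "7 Where is the football? \toffice\t2 6",
]

samples = parse_babi(example)
story, question, answer, supporting_ids = samples[0]
facts = [sentence for line_id, sentence in story if line_id in supporting_ids]
print(question, "->", answer)
print("Supporting facts:", facts)
```

Only two of the six statements are needed to answer the question, and that's exactly the extra supervision the supporting fact IDs provide.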

With the synthetic dataset, all the commonsense knowledge and reasoning required for the test set should be contained in the training set. That way, if a machine learning model then fails to solve the task, we know that the challenge is in the model itself, and not the data (or lack of data) it was exposed to.

How do we approach the problem?

One of the easiest ways to approach a task is to code a baseline solution. Baseline solutions are meant to provide the best "bang for the buck" - the minimal amount of work for the best result possible. In this situation, a recurrent neural network (RNN) is the baseline we can turn to.

Recurrent neural networks such as Long Short Term Memory (LSTM) and the Gated Recurrent Unit (GRU) are neural networks that can process a sequence of inputs, updating the network's internal state as they read more data. This enables them to learn long term dependencies such as bracket matching. As we encode words into a vector representation, we can consider a sentence as a sequence of word vectors, feeding them one at a time into our RNN.
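To illustrate what "updating the internal state" means, here's a minimal sketch of a plain (Elman-style) recurrent step in pure Python. It is deliberately simpler than the LSTM or GRU, with no gates and no training, and the sizes, random weights, and toy word vectors are all made up for illustration.

```python
import math
import random


def rnn_step(state, x, W_state, W_input, bias):
    """One step of a plain RNN: new_state = tanh(W_state @ state + W_input @ x + bias)."""
    new_state = []
    for i in range(len(state)):
        total = bias[i]
        total += sum(W_state[i][j] * state[j] for j in range(len(state)))
        total += sum(W_input[i][k] * x[k] for k in range(len(x)))
        new_state.append(math.tanh(total))
    return new_state


def encode(sequence, W_state, W_input, bias):
    """Feed vectors in one at a time and return the final state as a fixed-size summary."""
    state = [0.0] * len(bias)
    for x in sequence:
        state = rnn_step(state, x, W_state, W_input, bias)
    return state


random.seed(0)
HIDDEN, EMBED = 4, 3  # hypothetical toy sizes
W_state = [[random.uniform(-0.5, 0.5) for _ in range(HIDDEN)] for _ in range(HIDDEN)]
W_input = [[random.uniform(-0.5, 0.5) for _ in range(EMBED)] for _ in range(HIDDEN)]
bias = [0.0] * HIDDEN

# Made-up word vectors standing in for a learned embedding
short_sentence = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
long_sentence = short_sentence + [[0.0, 0.0, 1.0], [1.0, 1.0, 0.0]]

print(encode(short_sentence, W_state, W_input, bias))
print(encode(long_sentence, W_state, W_input, bias))
```

Whatever the length of the input, the final state has the same size, which is what lets us treat it as a fixed vector summary of the sequence.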

Instead of implementing these models ourselves, we can use an existing implementation.

Keras - a deep learning library for Python

Keras is a minimalist, highly modular neural network library in the spirit of Torch, written in Python, that uses Theano under the hood for optimized tensor manipulation on GPU and CPU.

As an open source project, it has strong documentation, an active community, and a good leader. Only three hours after submitting my pull request for this example code, François Chollet (fchollet) merged in the code. The rapid turnaround of the project and strong examples make it a good library to get going with deep learning. Keras also leverages the Theano library, a Python library for defining, optimizing, and evaluating mathematical expressions involving multi-dimensional arrays efficiently.

The challenge at hand

Our idea is as follows: each task has a story component and a query component. We will run an RNN over both of these components, converting the long sequence of words into a fixed vector representation. This fixed vector representation should hopefully encapsulate all of the relevant input. Finally, we feed these two fixed vector representations into a traditional dense neural network, where it can look at the encoded query, then at the encoded story, and hopefully answer the question correctly.

Word vectors

One additional component is the word vector representation. This is where we hope to convert a word into a fixed vector representation encapsulating extra knowledge about it. Word vectors hope to capture the meaning behind the word, enabling related words to be considered similar and thus act in similar fashions.

This might be important if, for example, we had the two sentences:
John put down the apple.
John dropped the apple.
but are only interested in answering the question "Does John have the apple?", where the nuance between putting something down and dropping it is unimportant.

Whilst we can learn good word vector representations for the small set of words in these tasks, they wouldn't carry broader knowledge. For a real task, where knowing extra information might be useful (such as frog ~= toad), we could use existing word vectors trained on billions of words, such as Stanford's GloVe.
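As a rough illustration of what "considered similar" means in practice, word vectors are commonly compared using cosine similarity. The sketch below uses hand-picked toy vectors, not real GloVe embeddings or anything learned on bAbi, purely to show the computation.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1.0 means they point the same way."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hand-picked toy vectors, NOT real trained embeddings
toy_vectors = {
    "frog": [0.9, 0.8, 0.1],
    "toad": [0.8, 0.9, 0.2],
    "office": [0.1, 0.2, 0.9],
}

print(cosine_similarity(toy_vectors["frog"], toy_vectors["toad"]))
print(cosine_similarity(toy_vectors["frog"], toy_vectors["office"]))
```

With well trained vectors, we'd hope the first score (frog vs toad) to be much higher than the second (frog vs office), just as it is for these toy values.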

The code

Luckily, the code for this is stunningly simple thanks to Keras. You can see the full code at babi_rnn.py, but the relevant recurrent network code fits in the handful of lines below.

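# RNN is a placeholder for a recurrent layer class such as LSTM or GRU (see the full babi_rnn.py)

# Encode the story: word indices -> embeddings -> a single fixed-size vector
sentrnn = Sequential()
sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, mask_zero=True))
sentrnn.add(RNN(EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, return_sequences=False))

# Encode the query in the same way
qrnn = Sequential()
qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE))
qrnn.add(RNN(EMBED_HIDDEN_SIZE, QUERY_HIDDEN_SIZE, return_sequences=False))

# Concatenate both encodings and predict the answer word with a softmax over the vocabulary
model = Sequential()
model.add(Merge([sentrnn, qrnn], mode='concat'))
model.add(Dense(SENT_HIDDEN_SIZE + QUERY_HIDDEN_SIZE, vocab_size, activation='softmax'))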
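The `Embedding` layers above expect sequences of integer word indices rather than raw text, and `mask_zero=True` treats index 0 as padding. Here's a hedged sketch of how stories and queries might be turned into padded index sequences. The helper names, the regex tokenizer, and the choice of left-padding are my assumptions for illustration and may differ from what babi_rnn.py actually does.

```python
import re


def tokenize(sentence):
    """Split a sentence into word and punctuation tokens."""
    return re.findall(r"\w+|[^\w\s]", sentence)


def build_vocab(token_lists):
    """Map each word to an index starting at 1, keeping 0 free for padding."""
    words = sorted({token for tokens in token_lists for token in tokens})
    return {word: index for index, word in enumerate(words, start=1)}


def vectorize(tokens, word_to_index, max_len):
    """Convert tokens to indices and left-pad with zeros (assumes len(tokens) <= max_len)."""
    indices = [word_to_index[token] for token in tokens]
    return [0] * (max_len - len(indices)) + indices


story = tokenize("Mary grabbed the football there. Mary journeyed to the office.")
query = tokenize("Where is the football?")

word_to_index = build_vocab([story, query])
vocab_size = len(word_to_index) + 1  # +1 for the padding index 0

print(vectorize(query, word_to_index, max_len=8))
print(vocab_size)
```

Reserving index 0 is why `vocab_size` ends up one larger than the number of distinct words.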

Final results

In this section, I compare the final results for the Keras based question answering system with the LSTM baseline provided by the Facebook paper.

The results are comparable (and occasionally superior) to those for the LSTM baseline provided in Weston et al.'s Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks, given only 1000 samples and without any hyperparameter tuning. The same model is also used across all tasks.

Unfortunately, the baseline is just that. Using traditional recurrent neural networks, such as the LSTM or GRU, won't give you substantially better performance even if you scale up the network tremendously. For better results, new neural network configurations have been suggested and used, such as Facebook's Memory Network (further improved in the paper presenting the bAbi dataset), Google's Neural Turing Machine, and MetaMind's Dynamic Memory Networks.

All of these models can take advantage of knowing where the supporting facts are, learning where to focus attention in the input, and performing multiple "lookups" to track down relevant information. I'm hoping to implement a simple version of one of these models in the near future.

For now, however, I'm content with my simple baseline.

| Task Number | FB LSTM Baseline (% accuracy) | Keras QA (% accuracy) |
|---|---|---|
| QA1 - Single Supporting Fact | 50 | 52.1 |
| QA2 - Two Supporting Facts | 20 | 37.0 |
| QA3 - Three Supporting Facts | 20 | 20.5 |
| QA4 - Two Arg. Relations | 61 | 62.9 |
| QA5 - Three Arg. Relations | 70 | 61.9 |
| QA6 - Yes/No Questions | 48 | 50.7 |
| QA7 - Counting | 49 | 78.9 |
| QA8 - Lists/Sets | 45 | 77.2 |
| QA9 - Simple Negation | 64 | 64.0 |
| QA10 - Indefinite Knowledge | 44 | 47.7 |
| QA11 - Basic Coreference | 72 | 74.9 |
| QA12 - Conjunction | 74 | 76.4 |
| QA13 - Compound Coreference | 94 | 94.4 |
| QA14 - Time Reasoning | 27 | 34.8 |
| QA15 - Basic Deduction | 21 | 32.4 |
| QA16 - Basic Induction | 23 | 50.6 |
| QA17 - Positional Reasoning | 51 | 49.1 |
| QA18 - Size Reasoning | 52 | 90.8 |
| QA19 - Path Finding | 8 | 9.0 |
| QA20 - Agent's Motivations | 91 | 90.7 |

Note: Dataset issues - duplication in Positional Reasoning (QA17) and Size Reasoning (QA18)

The results above show a large performance difference between the Facebook LSTM baseline and the Keras QA system on QA18 - jumping from 52 to 91. Whilst investigating I found that there were numerous duplicated statements and questions in the QA18 training and testing datasets. This is also an issue in QA17 and possibly others. Given that there are only 1000 train and test data points (which you can confirm by running grep "?" tasks_1-20_v1-2/en/qa18_size-reasoning_train.txt | wc -l), repetitions could cause serious issues.

I'll be emailing the maintainers of the dataset once I perform a full analysis in the hopes this will be fixed for Version 1.3 of the data.

# QA18 - top of training data
1 The container fits inside the suitcase.
2 The container fits inside the suitcase.
3 The chocolate fits inside the chest.
4 The box of chocolates fits inside the suitcase.
5 The chocolate fits inside the box.
6 The chocolate fits inside the container.
7 The box of chocolates fits inside the suitcase.
8 The container fits inside the suitcase.
9 The chocolate fits inside the box.
10 The suitcase is bigger than the chocolate.
11 Is the chocolate bigger than the suitcase? no 6 1
12 Does the suitcase fit in the chocolate? no 6 1
13 Does the suitcase fit in the chocolate? no 6 1
14 Does the suitcase fit in the chocolate? no 6 1
# QA17 - top of training data
1 The triangle is above the pink rectangle.
2 The blue square is to the left of the triangle.
3 Is the pink rectangle to the right of the blue square? yes 1 2
4 Is the blue square below the pink rectangle? no 2 1
5 Is the blue square to the right of the pink rectangle? no 2 1
6 Is the blue square below the pink rectangle? no 2 1
7 Is the blue square below the pink rectangle? no 2 1

Update - Full duplicate analysis

Having finished the duplicate analysis, there are issues in the dataset that need to be fixed. Luckily the dataset has already been released in a versioned state, though it is unfortunate that the papers published using the dataset do not report which versions they used, and historical versions are not available.

The duplicate analysis was performed by finding only unique (story, query, answer) tuples within the training set and the test set, then finding if there were any intersections between those unique tuples.
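A minimal sketch of that analysis is below. It assumes each sample has already been parsed into a (story, query, answer) triple, with the story held as a list of sentences. The function names and the tiny hand-written splits are mine, purely to exercise the logic, and aren't the script used to produce the numbers in this post.

```python
def unique_samples(samples):
    """Collapse samples into a set of hashable (story, query, answer) tuples."""
    return {(tuple(story), query, answer) for story, query, answer in samples}


def duplicate_report(train, test):
    """Count unique samples in each split and how many appear in both."""
    unique_train = unique_samples(train)
    unique_test = unique_samples(test)
    return {
        "train": len(unique_train),
        "test": len(unique_test),
        "intersection": len(unique_train & unique_test),
    }


# Tiny made-up splits, only to exercise the functions
fact = ["The container fits inside the suitcase."]
train = [
    (fact, "Does the suitcase fit in the container?", "no"),
    (fact, "Does the suitcase fit in the container?", "no"),  # duplicate within train
    (fact, "Does the container fit in the suitcase?", "yes"),
]
test = [
    (fact, "Does the container fit in the suitcase?", "yes"),  # also present in train
    (fact, "Is the suitcase bigger than the container?", "yes"),
]

print(duplicate_report(train, test))
# {'train': 2, 'test': 2, 'intersection': 1}
```

Converting each story to a tuple makes the sample hashable, so Python's set operations handle both the deduplication and the train/test intersection.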

The most extreme issue is that one of the tasks, QA4, has about 13% of the unique samples present in both training and testing.

Unique samples in QA4 - Two Arg Relations
Train length: 919
Test length: 933
Intersection: 124

Another issue was duplicates within the training and testing sets, which is particularly problematic in QA15, QA17, and QA18. This matters all the more given that the algorithms are trained on only 1000 samples.

QA15, QA17, and QA18 have numerous duplicates in training and testing

Unique samples in QA15 - Basic Deduction
Train length: 695
Test length: 678
Intersection: 0
Unique samples in QA17 - Positional Reasoning
Train length: 627
Test length: 632
Intersection: 11
Unique samples in QA18 - Size Reasoning
Train length: 654
Test length: 602
Intersection: 0
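To put those counts in perspective, the short sketch below converts them into percentages. It assumes each split holds exactly 1000 samples, as noted earlier, so anything beyond the unique count is a repeat. The helper name `percent` is just for illustration.

```python
TOTAL_SAMPLES = 1000  # samples per split, as stated in the article


def percent(part, whole):
    """Express part as a percentage of whole."""
    return 100.0 * part / whole


# (unique train, unique test, intersection) as reported above
unique_counts = {
    "QA4": (919, 933, 124),
    "QA15": (695, 678, 0),
    "QA17": (627, 632, 11),
    "QA18": (654, 602, 0),
}

for task, (train, test, overlap) in unique_counts.items():
    train_repeats = percent(TOTAL_SAMPLES - train, TOTAL_SAMPLES)
    test_repeats = percent(TOTAL_SAMPLES - test, TOTAL_SAMPLES)
    overlap_share = percent(overlap, test)
    print(f"{task}: {train_repeats:.1f}% of train and {test_repeats:.1f}% of test are repeats; "
          f"{overlap_share:.1f}% of unique test samples also occur in train")
```

Under that assumption, roughly a third of the QA15, QA17, and QA18 samples are repeats, whilst QA4's main problem is the overlap between training and testing.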

These issues become even more extreme when the versions of the bAbi tasks containing 10,000 samples are used.

For details, refer to the full results.