author     Thomas Mesnard <thomas.mesnard@ens.fr>    2016-03-01 00:27:15 +0100
committer  Thomas Mesnard <thomas.mesnard@ens.fr>    2016-03-02 09:28:39 +0100
commit     f31caf61be87850f3afcd367d6eb9521b2f613da (patch)
tree       2bcceeb702ef0d35bfdc925977797c40290b6966
download   deepmind-qa-f31caf61be87850f3afcd367d6eb9521b2f613da.tar.gz
           deepmind-qa-f31caf61be87850f3afcd367d6eb9521b2f613da.zip
Initial commit
-rw-r--r--   LICENSE                                21
-rw-r--r--   README.md                              21
-rw-r--r--   __init__.py                             0
-rw-r--r--   config/__init__.py                      0
-rw-r--r--   config/deep_bidir_lstm_2x128.py        37
-rw-r--r--   config/deepmind_attentive_reader.py    42
-rw-r--r--   config/deepmind_deep_lstm.py           33
-rw-r--r--   data.py                               177
-rw-r--r--   model/__init__.py                       0
-rw-r--r--   model/attentive_reader.py             152
-rw-r--r--   model/deep_bidir_lstm.py              109
-rw-r--r--   model/deep_lstm.py                     99
-rw-r--r--   paramsaveload.py                       37
-rwxr-xr-x   train.py                              112
14 files changed, 840 insertions(+), 0 deletions(-)
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..9bdd733
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2016 Thomas Mesnard
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..63f8142
--- /dev/null
+++ b/README.md
@@ -0,0 +1,21 @@
+DeepMind: Teaching Machines to Read and Comprehend
+===================================================
+
+This repository contains an implementation of the two models (the Deep LSTM and the Attentive Reader) described in *Teaching Machines to Read and Comprehend* by Karl Moritz Hermann et al., NIPS, 2015. It also contains an implementation of a Deep Bidirectional LSTM.
+
+Models are implemented using [Theano](https://github.com/Theano/Theano) and [Blocks](https://github.com/mila-udem/blocks). Datasets are implemented using [Fuel](https://github.com/mila-udem/fuel).
+
+The corresponding dataset is provided by [DeepMind](https://github.com/deepmind/rc-data); if their script does not work, you can also get the data from [http://cs.nyu.edu/~kcho/DMQA/](http://cs.nyu.edu/~kcho/DMQA/) by [Kyunghyun Cho](http://www.kyunghyuncho.me/).
+
+Reference
+=========
+[Teaching Machines to Read and Comprehend](https://papers.nips.cc/paper/5945-teaching-machines-to-read-and-comprehend.pdf), by Karl Moritz Hermann, Tomáš Kočiský, Edward Grefenstette, Lasse Espeholt, Will Kay, Mustafa Suleyman and Phil Blunsom, Neural Information Processing Systems, 2015.
+
+
+Credits
+=======
+[Thomas Mesnard](https://github.com/thomasmesnard)
+
+[Alex Auvolat](https://github.com/Alexis211)
+
+[Étienne Simon](https://github.com/ejls)
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/__init__.py
diff --git a/config/__init__.py b/config/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/config/__init__.py
diff --git a/config/deep_bidir_lstm_2x128.py b/config/deep_bidir_lstm_2x128.py
new file mode 100644
index 0000000..f07f43f
--- /dev/null
+++ b/config/deep_bidir_lstm_2x128.py
@@ -0,0 +1,37 @@
+from blocks.algorithms import BasicMomentum, AdaDelta, RMSProp, Adam, CompositeRule, StepClipping
+from blocks.initialization import IsotropicGaussian, Constant
+from blocks.bricks import Tanh
+
+from model.deep_bidir_lstm import Model
+
+
+batch_size = 32
+sort_batch_count = 20
+
+shuffle_questions = True
+shuffle_entities = True
+
+concat_ctx_and_question = True
+concat_question_before = True  # should not matter for a bidirectional network
+
+embed_size = 200
+
+lstm_size = [128, 128]
+skip_connections = True
+
+n_entities = 550
+out_mlp_hidden = []
+out_mlp_activations = []
+
+step_rule = CompositeRule([RMSProp(decay_rate=0.95, learning_rate=5e-5),
+ BasicMomentum(momentum=0.9)])
+
+dropout = 0.1
+w_noise = 0.05
+
+valid_freq = 1000
+save_freq = 1000
+print_freq = 100
+
+weights_init = IsotropicGaussian(0.01)
+biases_init = Constant(0.)
diff --git a/config/deepmind_attentive_reader.py b/config/deepmind_attentive_reader.py
new file mode 100644
index 0000000..84a6cf0
--- /dev/null
+++ b/config/deepmind_attentive_reader.py
@@ -0,0 +1,42 @@
+from blocks.bricks import Tanh
+from blocks.algorithms import BasicMomentum, AdaDelta, RMSProp, Adam, CompositeRule, StepClipping, Momentum
+from blocks.initialization import IsotropicGaussian, Constant
+
+from model.attentive_reader import Model
+
+
+batch_size = 32
+sort_batch_count = 20
+
+shuffle_questions = True
+
+concat_ctx_and_question = False
+
+n_entities = 550
+embed_size = 200
+
+ctx_lstm_size = [256]
+ctx_skip_connections = True
+
+question_lstm_size = [256]
+question_skip_connections = True
+
+attention_mlp_hidden = [100]
+attention_mlp_activations = [Tanh()]
+
+out_mlp_hidden = []
+out_mlp_activations = []
+
+step_rule = CompositeRule([RMSProp(decay_rate=0.95, learning_rate=5e-5),
+ BasicMomentum(momentum=0.9)])
+
+dropout = 0.2
+w_noise = 0.
+
+valid_freq = 1000
+save_freq = 1000
+print_freq = 100
+
+weights_init = IsotropicGaussian(0.01)
+biases_init = Constant(0.)
+
diff --git a/config/deepmind_deep_lstm.py b/config/deepmind_deep_lstm.py
new file mode 100644
index 0000000..10b5c9b
--- /dev/null
+++ b/config/deepmind_deep_lstm.py
@@ -0,0 +1,33 @@
+from blocks.algorithms import BasicMomentum, AdaDelta, RMSProp, Adam, CompositeRule, StepClipping
+from blocks.initialization import IsotropicGaussian, Constant
+
+from model.deep_lstm import Model
+
+
+batch_size = 32
+sort_batch_count = 20
+
+shuffle_questions = True
+
+concat_ctx_and_question = True
+concat_question_before = True
+
+embed_size = 200
+
+lstm_size = [256, 256]
+skip_connections = True
+
+out_mlp_hidden = []
+out_mlp_activations = []
+
+step_rule = CompositeRule([RMSProp(decay_rate=0.95, learning_rate=1e-4),
+ BasicMomentum(momentum=0.9)])
+
+dropout = 0.1
+
+valid_freq = 1000
+save_freq = 1000
+print_freq = 100
+
+weights_init = IsotropicGaussian(0.01)
+biases_init = Constant(0.)
diff --git a/data.py b/data.py
new file mode 100644
index 0000000..b3fa6d2
--- /dev/null
+++ b/data.py
@@ -0,0 +1,177 @@
+import logging
+import random
+import numpy
+
+import cPickle
+
+from picklable_itertools import iter_
+
+from fuel.datasets import Dataset
+from fuel.streams import DataStream
+from fuel.schemes import IterationScheme, ConstantScheme
+from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer
+
+import sys
+import os
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+class QADataset(Dataset):
+ def __init__(self, path, vocab_file, n_entities, need_sep_token, **kwargs):
+ self.provides_sources = ('context', 'question', 'answer', 'candidates')
+
+ self.path = path
+
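+        # Word ids 0 .. n_entities-1 are reserved for the anonymized entity
+        # markers; the corpus vocabulary, '<UNK>', '@placeholder' and the
+        # optional '<SEP>' token follow.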
+ self.vocab = ['@entity%d' % i for i in range(n_entities)] + \
+ [w.rstrip('\n') for w in open(vocab_file)] + \
+ ['<UNK>', '@placeholder'] + \
+ (['<SEP>'] if need_sep_token else [])
+
+ self.n_entities = n_entities
+ self.vocab_size = len(self.vocab)
+ self.reverse_vocab = {w: i for i, w in enumerate(self.vocab)}
+
+ super(QADataset, self).__init__(**kwargs)
+
+ def to_word_id(self, w, cand_mapping):
+ if w in cand_mapping:
+ return cand_mapping[w]
+ elif w[:7] == '@entity':
+ raise ValueError("Unmapped entity token: %s"%w)
+ elif w in self.reverse_vocab:
+ return self.reverse_vocab[w]
+ else:
+ return self.reverse_vocab['<UNK>']
+
+ def to_word_ids(self, s, cand_mapping):
+ return numpy.array([self.to_word_id(x, cand_mapping) for x in s.split(' ')], dtype=numpy.int32)
+
+ def get_data(self, state=None, request=None):
+ if request is None or state is not None:
+ raise ValueError("Expected a request (name of a question file) and no state.")
+
+ lines = [l.rstrip('\n') for l in open(os.path.join(self.path, request))]
+
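+        # Question file layout: line 3 is the context, line 5 the question,
+        # line 7 the answer; lines 9 onwards list the candidates as
+        # 'entity:text' pairs, of which only the entity token is kept.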
+ ctx = lines[2]
+ q = lines[4]
+ a = lines[6]
+ cand = [s.split(':')[0] for s in lines[8:]]
+
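+        # Remap this question's candidates to a random permutation of entity
+        # ids so the model cannot memorize a fixed id-to-answer association.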
+ entities = range(self.n_entities)
+ while len(cand) > len(entities):
+ logger.warning("Too many entities (%d) for question: %s, using duplicate entity identifiers"
+ %(len(cand), request))
+ entities = entities + entities
+ random.shuffle(entities)
+ cand_mapping = {t: k for t, k in zip(cand, entities)}
+
+ ctx = self.to_word_ids(ctx, cand_mapping)
+ q = self.to_word_ids(q, cand_mapping)
+ cand = numpy.array([self.to_word_id(x, cand_mapping) for x in cand], dtype=numpy.int32)
+ a = numpy.int32(self.to_word_id(a, cand_mapping))
+
+ if not a < self.n_entities:
+ raise ValueError("Invalid answer token %d"%a)
+ if not numpy.all(cand < self.n_entities):
+ raise ValueError("Invalid candidate in list %s"%repr(cand))
+ if not numpy.all(ctx < self.vocab_size):
+ raise ValueError("Context word id out of bounds: %d"%int(ctx.max()))
+ if not numpy.all(ctx >= 0):
+ raise ValueError("Context word id negative: %d"%int(ctx.min()))
+ if not numpy.all(q < self.vocab_size):
+ raise ValueError("Question word id out of bounds: %d"%int(q.max()))
+ if not numpy.all(q >= 0):
+ raise ValueError("Question word id negative: %d"%int(q.min()))
+
+ return (ctx, q, a, cand)
+
+class QAIterator(IterationScheme):
+ requests_examples = True
+ def __init__(self, path, shuffle=False, **kwargs):
+ self.path = path
+ self.shuffle = shuffle
+
+ super(QAIterator, self).__init__(**kwargs)
+
+ def get_request_iterator(self):
+ l = [f for f in os.listdir(self.path)
+ if os.path.isfile(os.path.join(self.path, f))]
+ if self.shuffle:
+ random.shuffle(l)
+ return iter_(l)
+
+# -------------- DATASTREAM SETUP --------------------
+
+
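+# Merges the context and the question into a single sequence, optionally
+# separated by a token, for configurations where both are fed to one network
+# (concat_ctx_and_question = True).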
+class ConcatCtxAndQuestion(Transformer):
+ produces_examples = True
+ def __init__(self, stream, concat_question_before, separator_token=None, **kwargs):
+ assert stream.sources == ('context', 'question', 'answer', 'candidates')
+ self.sources = ('question', 'answer', 'candidates')
+
+ self.sep = numpy.array([separator_token] if separator_token is not None else [],
+ dtype=numpy.int32)
+ self.concat_question_before = concat_question_before
+
+ super(ConcatCtxAndQuestion, self).__init__(stream, **kwargs)
+
+ def get_data(self, request=None):
+ if request is not None:
+ raise ValueError('Unsupported: request')
+
+ ctx, q, a, cand = next(self.child_epoch_iterator)
+
+ if self.concat_question_before:
+ return (numpy.concatenate([q, self.sep, ctx]), a, cand)
+ else:
+ return (numpy.concatenate([ctx, self.sep, q]), a, cand)
+
+class _balanced_batch_helper(object):
+ def __init__(self, key):
+ self.key = key
+ def __call__(self, data):
+ return data[self.key].shape[0]
+
+def setup_datastream(path, vocab_file, config):
+ ds = QADataset(path, vocab_file, config.n_entities, need_sep_token=config.concat_ctx_and_question)
+ it = QAIterator(path, shuffle=config.shuffle_questions)
+
+ stream = DataStream(ds, iteration_scheme=it)
+
+ if config.concat_ctx_and_question:
+ stream = ConcatCtxAndQuestion(stream, config.concat_question_before, ds.reverse_vocab['<SEP>'])
+
+    # Collect batch_size * sort_batch_count examples, sort them by length and
+    # re-batch, so that each batch contains sequences of similar length
+ stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size * config.sort_batch_count))
+ comparison = _balanced_batch_helper(stream.sources.index('question' if config.concat_ctx_and_question else 'context'))
+ stream = Mapping(stream, SortMapping(comparison))
+ stream = Unpack(stream)
+
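+    # Build fixed-size batches and pad the variable-length sequences; Padding
+    # also creates the corresponding *_mask sources used by the models.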
+ stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
+ stream = Padding(stream, mask_sources=['context', 'question', 'candidates'], mask_dtype='int32')
+
+ return ds, stream
+
+if __name__ == "__main__":
+ # Test
+ class DummyConfig:
+ def __init__(self):
+ self.shuffle_entities = True
+ self.shuffle_questions = False
+ self.concat_ctx_and_question = False
+ self.concat_question_before = False
+ self.batch_size = 2
+ self.sort_batch_count = 1000
+
+ ds, stream = setup_datastream(os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training"),
+ os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt"),
+ DummyConfig())
+ it = stream.get_epoch_iterator()
+
+ for i, d in enumerate(stream.get_epoch_iterator()):
+ print '--'
+ print d
+ if i > 2: break
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/model/__init__.py
diff --git a/model/attentive_reader.py b/model/attentive_reader.py
new file mode 100644
index 0000000..682e48d
--- /dev/null
+++ b/model/attentive_reader.py
@@ -0,0 +1,152 @@
+import theano
+from theano import tensor
+import numpy
+
+from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier
+from blocks.bricks.lookup import LookupTable
+from blocks.bricks.recurrent import LSTM
+
+from blocks.filter import VariableFilter
+from blocks.roles import WEIGHT
+from blocks.graph import ComputationGraph, apply_dropout, apply_noise
+
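+# Build a stack of bidirectional LSTM layers over `seq`. Each layer runs one
+# LSTM forward and one over the reversed sequence; with skip=True every layer
+# also receives the original input in addition to the previous layer's states.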
+def make_bidir_lstm_stack(seq, seq_dim, mask, sizes, skip=True, name=''):
+ bricks = []
+
+ curr_dim = [seq_dim]
+ curr_hidden = [seq]
+
+ hidden_list = []
+ for k, dim in enumerate(sizes):
+ fwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='%s_fwd_lstm_in_%d_%d'%(name,k,l)) for l, d in enumerate(curr_dim)]
+ fwd_lstm = LSTM(dim=dim, activation=Tanh(), name='%s_fwd_lstm_%d'%(name,k))
+
+ bwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='%s_bwd_lstm_in_%d_%d'%(name,k,l)) for l, d in enumerate(curr_dim)]
+ bwd_lstm = LSTM(dim=dim, activation=Tanh(), name='%s_bwd_lstm_%d'%(name,k))
+
+ bricks = bricks + [fwd_lstm, bwd_lstm] + fwd_lstm_ins + bwd_lstm_ins
+
+ fwd_tmp = sum(x.apply(v) for x, v in zip(fwd_lstm_ins, curr_hidden))
+ bwd_tmp = sum(x.apply(v) for x, v in zip(bwd_lstm_ins, curr_hidden))
+ fwd_hidden, _ = fwd_lstm.apply(fwd_tmp, mask=mask)
+ bwd_hidden, _ = bwd_lstm.apply(bwd_tmp[::-1], mask=mask[::-1])
+ hidden_list = hidden_list + [fwd_hidden, bwd_hidden]
+ if skip:
+ curr_hidden = [seq, fwd_hidden, bwd_hidden[::-1]]
+ curr_dim = [seq_dim, dim, dim]
+ else:
+ curr_hidden = [fwd_hidden, bwd_hidden[::-1]]
+ curr_dim = [dim, dim]
+
+ return bricks, hidden_list
+
+class Model():
+ def __init__(self, config, vocab_size):
+ question = tensor.imatrix('question')
+ question_mask = tensor.imatrix('question_mask')
+ context = tensor.imatrix('context')
+ context_mask = tensor.imatrix('context_mask')
+ answer = tensor.ivector('answer')
+ candidates = tensor.imatrix('candidates')
+ candidates_mask = tensor.imatrix('candidates_mask')
+
+ bricks = []
+
+ question = question.dimshuffle(1, 0)
+ question_mask = question_mask.dimshuffle(1, 0)
+ context = context.dimshuffle(1, 0)
+ context_mask = context_mask.dimshuffle(1, 0)
+
+        # Embed questions and context
+ embed = LookupTable(vocab_size, config.embed_size, name='question_embed')
+ bricks.append(embed)
+
+ qembed = embed.apply(question)
+ cembed = embed.apply(context)
+
+ qlstms, qhidden_list = make_bidir_lstm_stack(qembed, config.embed_size, question_mask.astype(theano.config.floatX),
+ config.question_lstm_size, config.question_skip_connections, 'q')
+ clstms, chidden_list = make_bidir_lstm_stack(cembed, config.embed_size, context_mask.astype(theano.config.floatX),
+ config.ctx_lstm_size, config.ctx_skip_connections, 'ctx')
+ bricks = bricks + qlstms + clstms
+
+        # Question encoding: concatenate the final forward/backward LSTM states
+ if config.question_skip_connections:
+ qenc_dim = 2*sum(config.question_lstm_size)
+ qenc = tensor.concatenate([h[-1,:,:] for h in qhidden_list], axis=1)
+ else:
+ qenc_dim = 2*config.question_lstm_size[-1]
+ qenc = tensor.concatenate([h[-1,:,:] for h in qhidden_list[-2:]], axis=1)
+ qenc.name = 'qenc'
+
+        # Context encoding: concatenate the forward/backward LSTM states at every position
+ if config.ctx_skip_connections:
+ cenc_dim = 2*sum(config.ctx_lstm_size)
+ cenc = tensor.concatenate(chidden_list, axis=2)
+ else:
+ cenc_dim = 2*config.ctx_lstm_size[-1]
+ cenc = tensor.concatenate(chidden_list[-2:], axis=2)
+ cenc.name = 'cenc'
+
+ # Attention mechanism MLP
+ attention_mlp = MLP(dims=config.attention_mlp_hidden + [1],
+ activations=config.attention_mlp_activations[1:] + [Identity()],
+ name='attention_mlp')
+ attention_qlinear = Linear(input_dim=qenc_dim, output_dim=config.attention_mlp_hidden[0], name='attq')
+ attention_clinear = Linear(input_dim=cenc_dim, output_dim=config.attention_mlp_hidden[0], use_bias=False, name='attc')
+ bricks += [attention_mlp, attention_qlinear, attention_clinear]
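+        # Project the context states and the question encoding into the
+        # attention space, sum them (broadcasting the question over time)
+        # and apply a tanh, giving one hidden vector per context position.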
+ layer1 = Tanh().apply(attention_clinear.apply(cenc.reshape((cenc.shape[0]*cenc.shape[1], cenc.shape[2])))
+ .reshape((cenc.shape[0],cenc.shape[1],config.attention_mlp_hidden[0]))
+ + attention_qlinear.apply(qenc)[None, :, :])
+ layer1.name = 'layer1'
+ att_weights = attention_mlp.apply(layer1.reshape((layer1.shape[0]*layer1.shape[1], layer1.shape[2])))
+ att_weights.name = 'att_weights_0'
+ att_weights = att_weights.reshape((layer1.shape[0], layer1.shape[1]))
+ att_weights.name = 'att_weights'
+
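+        # Softmax over time turns the scores into attention weights; the
+        # attended context is the weighted sum of the context states.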
+ attended = tensor.sum(cenc * tensor.nnet.softmax(att_weights.T).T[:, :, None], axis=0)
+ attended.name = 'attended'
+
+ # Now we can calculate our output
+ out_mlp = MLP(dims=[cenc_dim + qenc_dim] + config.out_mlp_hidden + [config.n_entities],
+ activations=config.out_mlp_activations + [Identity()],
+ name='out_mlp')
+ bricks += [out_mlp]
+ probs = out_mlp.apply(tensor.concatenate([attended, qenc], axis=1))
+ probs.name = 'probs'
+
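+        # Restrict predictions to the candidate entities of each example:
+        # build a (batch, n_entities) indicator of the candidates and push
+        # the scores of all other entities to a large negative value.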
+ is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :],
+ tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1)
+ probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs))
+
+ # Calculate prediction, cost and error rate
+ pred = probs.argmax(axis=1)
+ cost = Softmax().categorical_cross_entropy(answer, probs).mean()
+ error_rate = tensor.neq(answer, pred).mean()
+
+ # Apply dropout
+ cg = ComputationGraph([cost, error_rate])
+ if config.w_noise > 0:
+ noise_vars = VariableFilter(roles=[WEIGHT])(cg)
+ cg = apply_noise(cg, noise_vars, config.w_noise)
+ if config.dropout > 0:
+ cg = apply_dropout(cg, qhidden_list + chidden_list, config.dropout)
+ [cost_reg, error_rate_reg] = cg.outputs
+
+ # Other stuff
+ cost_reg.name = cost.name = 'cost'
+ error_rate_reg.name = error_rate.name = 'error_rate'
+
+ self.sgd_cost = cost_reg
+ self.monitor_vars = [[cost_reg], [error_rate_reg]]
+ self.monitor_vars_valid = [[cost], [error_rate]]
+
+ # Initialize bricks
+ for brick in bricks:
+ brick.weights_init = config.weights_init
+ brick.biases_init = config.biases_init
+ brick.initialize()
+
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :
diff --git a/model/deep_bidir_lstm.py b/model/deep_bidir_lstm.py
new file mode 100644
index 0000000..1e000c6
--- /dev/null
+++ b/model/deep_bidir_lstm.py
@@ -0,0 +1,109 @@
+import theano
+from theano import tensor
+import numpy
+
+from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier
+from blocks.bricks.lookup import LookupTable
+from blocks.bricks.recurrent import LSTM
+
+from blocks.filter import VariableFilter
+from blocks.roles import WEIGHT
+from blocks.graph import ComputationGraph, apply_dropout, apply_noise
+
+class Model():
+ def __init__(self, config, vocab_size):
+ question = tensor.imatrix('question')
+ question_mask = tensor.imatrix('question_mask')
+ answer = tensor.ivector('answer')
+ candidates = tensor.imatrix('candidates')
+ candidates_mask = tensor.imatrix('candidates_mask')
+
+ bricks = []
+
+
+ # set time as first dimension
+ question = question.dimshuffle(1, 0)
+ question_mask = question_mask.dimshuffle(1, 0)
+
+ # Embed questions
+ embed = LookupTable(vocab_size, config.embed_size, name='question_embed')
+ bricks.append(embed)
+ qembed = embed.apply(question)
+
+ # Create and apply LSTM stack
+ curr_dim = [config.embed_size]
+ curr_hidden = [qembed]
+
+ hidden_list = []
+ for k, dim in enumerate(config.lstm_size):
+ fwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='fwd_lstm_in_%d_%d'%(k,l)) for l, d in enumerate(curr_dim)]
+ fwd_lstm = LSTM(dim=dim, activation=Tanh(), name='fwd_lstm_%d'%k)
+
+ bwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='bwd_lstm_in_%d_%d'%(k,l)) for l, d in enumerate(curr_dim)]
+ bwd_lstm = LSTM(dim=dim, activation=Tanh(), name='bwd_lstm_%d'%k)
+
+ bricks = bricks + [fwd_lstm, bwd_lstm] + fwd_lstm_ins + bwd_lstm_ins
+
+ fwd_tmp = sum(x.apply(v) for x, v in zip(fwd_lstm_ins, curr_hidden))
+ bwd_tmp = sum(x.apply(v) for x, v in zip(bwd_lstm_ins, curr_hidden))
+ fwd_hidden, _ = fwd_lstm.apply(fwd_tmp, mask=question_mask.astype(theano.config.floatX))
+ bwd_hidden, _ = bwd_lstm.apply(bwd_tmp[::-1], mask=question_mask.astype(theano.config.floatX)[::-1])
+ hidden_list = hidden_list + [fwd_hidden, bwd_hidden]
+ if config.skip_connections:
+ curr_hidden = [qembed, fwd_hidden, bwd_hidden[::-1]]
+ curr_dim = [config.embed_size, dim, dim]
+ else:
+ curr_hidden = [fwd_hidden, bwd_hidden[::-1]]
+ curr_dim = [dim, dim]
+
+ # Create and apply output MLP
+ if config.skip_connections:
+ out_mlp = MLP(dims=[2*sum(config.lstm_size)] + config.out_mlp_hidden + [config.n_entities],
+ activations=config.out_mlp_activations + [Identity()],
+ name='out_mlp')
+ bricks.append(out_mlp)
+
+ probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list], axis=1))
+ else:
+ out_mlp = MLP(dims=[2*config.lstm_size[-1]] + config.out_mlp_hidden + [config.n_entities],
+ activations=config.out_mlp_activations + [Identity()],
+ name='out_mlp')
+ bricks.append(out_mlp)
+
+ probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list[-2:]], axis=1))
+
+ is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :],
+ tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1)
+ probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs))
+
+ # Calculate prediction, cost and error rate
+ pred = probs.argmax(axis=1)
+ cost = Softmax().categorical_cross_entropy(answer, probs).mean()
+ error_rate = tensor.neq(answer, pred).mean()
+
+ # Apply dropout
+ cg = ComputationGraph([cost, error_rate])
+ if config.w_noise > 0:
+ noise_vars = VariableFilter(roles=[WEIGHT])(cg)
+ cg = apply_noise(cg, noise_vars, config.w_noise)
+ if config.dropout > 0:
+ cg = apply_dropout(cg, hidden_list, config.dropout)
+ [cost_reg, error_rate_reg] = cg.outputs
+
+ # Other stuff
+ cost_reg.name = cost.name = 'cost'
+ error_rate_reg.name = error_rate.name = 'error_rate'
+
+ self.sgd_cost = cost_reg
+ self.monitor_vars = [[cost_reg], [error_rate_reg]]
+ self.monitor_vars_valid = [[cost], [error_rate]]
+
+ # Initialize bricks
+ for brick in bricks:
+ brick.weights_init = config.weights_init
+ brick.biases_init = config.biases_init
+ brick.initialize()
+
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :
diff --git a/model/deep_lstm.py b/model/deep_lstm.py
new file mode 100644
index 0000000..02cc034
--- /dev/null
+++ b/model/deep_lstm.py
@@ -0,0 +1,99 @@
+import theano
+from theano import tensor
+import numpy
+
+from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier
+from blocks.bricks.lookup import LookupTable
+from blocks.bricks.recurrent import LSTM
+
+from blocks.graph import ComputationGraph, apply_dropout
+
+
+class Model():
+ def __init__(self, config, vocab_size):
+ question = tensor.imatrix('question')
+ question_mask = tensor.imatrix('question_mask')
+ answer = tensor.ivector('answer')
+ candidates = tensor.imatrix('candidates')
+ candidates_mask = tensor.imatrix('candidates_mask')
+
+ bricks = []
+
+
+ # set time as first dimension
+ question = question.dimshuffle(1, 0)
+ question_mask = question_mask.dimshuffle(1, 0)
+
+ # Embed questions
+ embed = LookupTable(vocab_size, config.embed_size, name='question_embed')
+ bricks.append(embed)
+ qembed = embed.apply(question)
+
+ # Create and apply LSTM stack
+ curr_dim = config.embed_size
+ curr_hidden = qembed
+
+ hidden_list = []
+ for k, dim in enumerate(config.lstm_size):
+ lstm_in = Linear(input_dim=curr_dim, output_dim=4*dim, name='lstm_in_%d'%k)
+ lstm = LSTM(dim=dim, activation=Tanh(), name='lstm_%d'%k)
+ bricks = bricks + [lstm_in, lstm]
+
+ tmp = lstm_in.apply(curr_hidden)
+ hidden, _ = lstm.apply(tmp, mask=question_mask.astype(theano.config.floatX))
+ hidden_list.append(hidden)
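+            # With skip connections, the next layer sees both this layer's
+            # states and the word embeddings, concatenated along the feature
+            # axis.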
+ if config.skip_connections:
+ curr_hidden = tensor.concatenate([hidden, qembed], axis=2)
+ curr_dim = dim + config.embed_size
+ else:
+ curr_hidden = hidden
+ curr_dim = dim
+
+ # Create and apply output MLP
+ if config.skip_connections:
+ out_mlp = MLP(dims=[sum(config.lstm_size)] + config.out_mlp_hidden + [config.n_entities],
+ activations=config.out_mlp_activations + [Identity()],
+ name='out_mlp')
+ bricks.append(out_mlp)
+
+ probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list], axis=1))
+ else:
+ out_mlp = MLP(dims=[config.lstm_size[-1]] + config.out_mlp_hidden + [config.n_entities],
+ activations=config.out_mlp_activations + [Identity()],
+ name='out_mlp')
+ bricks.append(out_mlp)
+
+ probs = out_mlp.apply(hidden_list[-1][-1,:,:])
+
+ is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :],
+ tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1)
+ probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs))
+
+ # Calculate prediction, cost and error rate
+ pred = probs.argmax(axis=1)
+ cost = Softmax().categorical_cross_entropy(answer, probs).mean()
+ error_rate = tensor.neq(answer, pred).mean()
+
+ # Apply dropout
+ cg = ComputationGraph([cost, error_rate])
+ if config.dropout > 0:
+ cg = apply_dropout(cg, hidden_list, config.dropout)
+ [cost_reg, error_rate_reg] = cg.outputs
+
+ # Other stuff
+ cost_reg.name = cost.name = 'cost'
+ error_rate_reg.name = error_rate.name = 'error_rate'
+
+ self.sgd_cost = cost_reg
+ self.monitor_vars = [[cost_reg], [error_rate_reg]]
+ self.monitor_vars_valid = [[cost], [error_rate]]
+
+ # Initialize bricks
+ for brick in bricks:
+ brick.weights_init = config.weights_init
+ brick.biases_init = config.biases_init
+ brick.initialize()
+
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :
diff --git a/paramsaveload.py b/paramsaveload.py
new file mode 100644
index 0000000..9c05926
--- /dev/null
+++ b/paramsaveload.py
@@ -0,0 +1,37 @@
+import logging
+
+import numpy
+
+import cPickle
+
+from blocks.extensions import SimpleExtension
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger('extensions.SaveLoadParams')
+
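+# Extension that pickles the model parameters to `path` on the configured
+# callbacks and restores them before training if a previous dump exists.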
+class SaveLoadParams(SimpleExtension):
+ def __init__(self, path, model, **kwargs):
+ super(SaveLoadParams, self).__init__(**kwargs)
+
+ self.path = path
+ self.model = model
+
+ def do_save(self):
+ with open(self.path, 'w') as f:
+ logger.info('Saving parameters to %s...'%self.path)
+ cPickle.dump(self.model.get_parameter_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+
+ def do_load(self):
+ try:
+ with open(self.path, 'r') as f:
+ logger.info('Loading parameters from %s...'%self.path)
+ self.model.set_parameter_values(cPickle.load(f))
+ except IOError:
+ pass
+
+ def do(self, which_callback, *args):
+ if which_callback == 'before_training':
+ self.do_load()
+ else:
+ self.do_save()
+
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..84f9a4b
--- /dev/null
+++ b/train.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+
+import logging
+import numpy
+import sys
+import os
+import importlib
+
+import theano
+from theano import tensor
+
+from blocks.extensions import Printing, SimpleExtension, FinishAfter, ProgressBar
+from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
+from blocks.graph import ComputationGraph
+from blocks.main_loop import MainLoop
+from blocks.model import Model
+from blocks.algorithms import GradientDescent
+
+try:
+ from blocks.extras.extensions.plot import Plot
+ plot_avail = True
+except ImportError:
+ plot_avail = False
+ print "No plotting extension available."
+
+import data
+from paramsaveload import SaveLoadParams
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+sys.setrecursionlimit(500000)
+
+if __name__ == "__main__":
+ if len(sys.argv) != 2:
+ print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
+ sys.exit(1)
+ model_name = sys.argv[1]
+ config = importlib.import_module('.%s' % model_name, 'config')
+
+ # Build datastream
+ path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training")
+ valid_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/validation")
+ vocab_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt")
+
+ ds, train_stream = data.setup_datastream(path, vocab_path, config)
+ _, valid_stream = data.setup_datastream(valid_path, vocab_path, config)
+
+ dump_path = os.path.join("model_params", model_name+".pkl")
+
+ # Build model
+ m = config.Model(config, ds.vocab_size)
+
+ # Build the Blocks stuff for training
+ model = Model(m.sgd_cost)
+
+ algorithm = GradientDescent(cost=m.sgd_cost,
+ step_rule=config.step_rule,
+ parameters=model.parameters)
+
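+    # Training extensions: training-data monitoring, parameter save/load,
+    # validation monitoring, optional live plotting, and console printing.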
+ extensions = [
+ TrainingDataMonitoring(
+ [v for l in m.monitor_vars for v in l],
+ prefix='train',
+ every_n_batches=config.print_freq)
+ ]
+ if config.save_freq is not None and dump_path is not None:
+ extensions += [
+ SaveLoadParams(path=dump_path,
+ model=model,
+ before_training=True,
+ after_training=True,
+ after_epoch=True,
+ every_n_batches=config.save_freq)
+ ]
+ if valid_stream is not None and config.valid_freq != -1:
+ extensions += [
+ DataStreamMonitoring(
+ [v for l in m.monitor_vars_valid for v in l],
+ valid_stream,
+ prefix='valid',
+ every_n_batches=config.valid_freq),
+ ]
+ if plot_avail:
+ plot_channels = [['train_' + v.name for v in lt] + ['valid_' + v.name for v in lv]
+ for lt, lv in zip(m.monitor_vars, m.monitor_vars_valid)]
+ extensions += [
+ Plot(document='deepmind_qa_'+model_name,
+ channels=plot_channels,
+                 # server_url='http://localhost:5006/',  # change this if needed
+ every_n_batches=config.print_freq)
+ ]
+ extensions += [
+ Printing(every_n_batches=config.print_freq,
+ after_epoch=True),
+ ProgressBar()
+ ]
+
+ main_loop = MainLoop(
+ model=model,
+ data_stream=train_stream,
+ algorithm=algorithm,
+ extensions=extensions
+ )
+
+    # Run the model!
+ main_loop.run()
+ main_loop.profile.report()
+
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :