diff options
author | Thomas Mesnard <thomas.mesnard@ens.fr> | 2016-03-01 00:27:15 +0100 |
---|---|---|
committer | Thomas Mesnard <thomas.mesnard@ens.fr> | 2016-03-02 09:28:39 +0100 |
commit | f31caf61be87850f3afcd367d6eb9521b2f613da (patch) | |
tree | 2bcceeb702ef0d35bfdc925977797c40290b6966 | |
download | deepmind-qa-f31caf61be87850f3afcd367d6eb9521b2f613da.tar.gz deepmind-qa-f31caf61be87850f3afcd367d6eb9521b2f613da.zip |
Initial commit
-rw-r--r-- | LICENSE | 21 | ||||
-rw-r--r-- | README.md | 21 | ||||
-rw-r--r-- | __init__.py | 0 | ||||
-rw-r--r-- | config/__init__.py | 0 | ||||
-rw-r--r-- | config/deep_bidir_lstm_2x128.py | 37 | ||||
-rw-r--r-- | config/deepmind_attentive_reader.py | 42 | ||||
-rw-r--r-- | config/deepmind_deep_lstm.py | 33 | ||||
-rw-r--r-- | data.py | 177 | ||||
-rw-r--r-- | model/__init__.py | 0 | ||||
-rw-r--r-- | model/attentive_reader.py | 152 | ||||
-rw-r--r-- | model/deep_bidir_lstm.py | 109 | ||||
-rw-r--r-- | model/deep_lstm.py | 99 | ||||
-rw-r--r-- | paramsaveload.py | 37 | ||||
-rwxr-xr-x | train.py | 112 |
14 files changed, 840 insertions, 0 deletions
@@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Thomas Mesnard + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..63f8142 --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ +DeepMind : Teaching Machines to Read and Comprehend +========================================= + +This repository contains an implementation of the two models (the Deep LSTM and the Attentive Reader) described in *Teaching Machines to Read and Comprehend* by Karl Moritz Hermann and al., NIPS, 2015. This repository also contains an implementation of a Deep Bidirectional LSTM. + +Models are implemented using [Theano](https://github.com/Theano/Theano) and [Blocks](https://github.com/mila-udem/blocks). Datasets are implemented using [Fuel](https://github.com/mila-udem/fuel). + +The corresponding dataset is provided by [DeepMind](https://github.com/deepmind/rc-data) but if the script does not work you can check [http://cs.nyu.edu/~kcho/DMQA/](http://cs.nyu.edu/~kcho/DMQA/) by [Kyunghyun Cho](http://www.kyunghyuncho.me/). + +Reference +========= +[Teaching Machines to Read and Comprehend](https://papers.nips.cc/paper/5945-teaching-machines-to-read-and-comprehend.pdf), by Karl Moritz Hermann, Tomáš Kočiský, Edward Grefenstette, Lasse Espeholt, Will Kay, Mustafa Suleyman and Phil Blunsom, Neural Information Processing Systems, 2015. + + +Credits +======= +[Thomas Mesnard](https://github.com/thomasmesnard) + +[Alex Auvolat](https://github.com/Alexis211) + +[Étienne Simon](https://github.com/ejls)
\ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/__init__.py diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/config/__init__.py diff --git a/config/deep_bidir_lstm_2x128.py b/config/deep_bidir_lstm_2x128.py new file mode 100644 index 0000000..f07f43f --- /dev/null +++ b/config/deep_bidir_lstm_2x128.py @@ -0,0 +1,37 @@ +from blocks.algorithms import BasicMomentum, AdaDelta, RMSProp, Adam, CompositeRule, StepClipping +from blocks.initialization import IsotropicGaussian, Constant +from blocks.bricks import Tanh + +from model.deep_bidir_lstm import Model + + +batch_size = 32 +sort_batch_count = 20 + +shuffle_questions = True +shuffle_entities = True + +concat_ctx_and_question = True +concat_question_before = True ## should not matter for bidirectionnal network + +embed_size = 200 + +lstm_size = [128, 128] +skip_connections = True + +n_entities = 550 +out_mlp_hidden = [] +out_mlp_activations = [] + +step_rule = CompositeRule([RMSProp(decay_rate=0.95, learning_rate=5e-5), + BasicMomentum(momentum=0.9)]) + +dropout = 0.1 +w_noise = 0.05 + +valid_freq = 1000 +save_freq = 1000 +print_freq = 100 + +weights_init = IsotropicGaussian(0.01) +biases_init = Constant(0.) diff --git a/config/deepmind_attentive_reader.py b/config/deepmind_attentive_reader.py new file mode 100644 index 0000000..84a6cf0 --- /dev/null +++ b/config/deepmind_attentive_reader.py @@ -0,0 +1,42 @@ +from blocks.bricks import Tanh +from blocks.algorithms import BasicMomentum, AdaDelta, RMSProp, Adam, CompositeRule, StepClipping, Momentum +from blocks.initialization import IsotropicGaussian, Constant + +from model.attentive_reader import Model + + +batch_size = 32 +sort_batch_count = 20 + +shuffle_questions = True + +concat_ctx_and_question = False + +n_entities = 550 +embed_size = 200 + +ctx_lstm_size = [256] +ctx_skip_connections = True + +question_lstm_size = [256] +question_skip_connections = True + +attention_mlp_hidden = [100] +attention_mlp_activations = [Tanh()] + +out_mlp_hidden = [] +out_mlp_activations = [] + +step_rule = CompositeRule([RMSProp(decay_rate=0.95, learning_rate=5e-5), + BasicMomentum(momentum=0.9)]) + +dropout = 0.2 +w_noise = 0. + +valid_freq = 1000 +save_freq = 1000 +print_freq = 100 + +weights_init = IsotropicGaussian(0.01) +biases_init = Constant(0.) + diff --git a/config/deepmind_deep_lstm.py b/config/deepmind_deep_lstm.py new file mode 100644 index 0000000..10b5c9b --- /dev/null +++ b/config/deepmind_deep_lstm.py @@ -0,0 +1,33 @@ +from blocks.algorithms import BasicMomentum, AdaDelta, RMSProp, Adam, CompositeRule, StepClipping +from blocks.initialization import IsotropicGaussian, Constant + +from model.deep_lstm import Model + + +batch_size = 32 +sort_batch_count = 20 + +shuffle_questions = True + +concat_ctx_and_question = True +concat_question_before = True + +embed_size = 200 + +lstm_size = [256, 256] +skip_connections = True + +out_mlp_hidden = [] +out_mlp_activations = [] + +step_rule = CompositeRule([RMSProp(decay_rate=0.95, learning_rate=1e-4), + BasicMomentum(momentum=0.9)]) + +dropout = 0.1 + +valid_freq = 1000 +save_freq = 1000 +print_freq = 100 + +weights_init = IsotropicGaussian(0.01) +biases_init = Constant(0.) @@ -0,0 +1,177 @@ +import logging +import random +import numpy + +import cPickle + +from picklable_itertools import iter_ + +from fuel.datasets import Dataset +from fuel.streams import DataStream +from fuel.schemes import IterationScheme, ConstantScheme +from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer + +import sys +import os + +logging.basicConfig(level='INFO') +logger = logging.getLogger(__name__) + +class QADataset(Dataset): + def __init__(self, path, vocab_file, n_entities, need_sep_token, **kwargs): + self.provides_sources = ('context', 'question', 'answer', 'candidates') + + self.path = path + + self.vocab = ['@entity%d' % i for i in range(n_entities)] + \ + [w.rstrip('\n') for w in open(vocab_file)] + \ + ['<UNK>', '@placeholder'] + \ + (['<SEP>'] if need_sep_token else []) + + self.n_entities = n_entities + self.vocab_size = len(self.vocab) + self.reverse_vocab = {w: i for i, w in enumerate(self.vocab)} + + super(QADataset, self).__init__(**kwargs) + + def to_word_id(self, w, cand_mapping): + if w in cand_mapping: + return cand_mapping[w] + elif w[:7] == '@entity': + raise ValueError("Unmapped entity token: %s"%w) + elif w in self.reverse_vocab: + return self.reverse_vocab[w] + else: + return self.reverse_vocab['<UNK>'] + + def to_word_ids(self, s, cand_mapping): + return numpy.array([self.to_word_id(x, cand_mapping) for x in s.split(' ')], dtype=numpy.int32) + + def get_data(self, state=None, request=None): + if request is None or state is not None: + raise ValueError("Expected a request (name of a question file) and no state.") + + lines = [l.rstrip('\n') for l in open(os.path.join(self.path, request))] + + ctx = lines[2] + q = lines[4] + a = lines[6] + cand = [s.split(':')[0] for s in lines[8:]] + + entities = range(self.n_entities) + while len(cand) > len(entities): + logger.warning("Too many entities (%d) for question: %s, using duplicate entity identifiers" + %(len(cand), request)) + entities = entities + entities + random.shuffle(entities) + cand_mapping = {t: k for t, k in zip(cand, entities)} + + ctx = self.to_word_ids(ctx, cand_mapping) + q = self.to_word_ids(q, cand_mapping) + cand = numpy.array([self.to_word_id(x, cand_mapping) for x in cand], dtype=numpy.int32) + a = numpy.int32(self.to_word_id(a, cand_mapping)) + + if not a < self.n_entities: + raise ValueError("Invalid answer token %d"%a) + if not numpy.all(cand < self.n_entities): + raise ValueError("Invalid candidate in list %s"%repr(cand)) + if not numpy.all(ctx < self.vocab_size): + raise ValueError("Context word id out of bounds: %d"%int(ctx.max())) + if not numpy.all(ctx >= 0): + raise ValueError("Context word id negative: %d"%int(ctx.min())) + if not numpy.all(q < self.vocab_size): + raise ValueError("Question word id out of bounds: %d"%int(q.max())) + if not numpy.all(q >= 0): + raise ValueError("Question word id negative: %d"%int(q.min())) + + return (ctx, q, a, cand) + +class QAIterator(IterationScheme): + requests_examples = True + def __init__(self, path, shuffle=False, **kwargs): + self.path = path + self.shuffle = shuffle + + super(QAIterator, self).__init__(**kwargs) + + def get_request_iterator(self): + l = [f for f in os.listdir(self.path) + if os.path.isfile(os.path.join(self.path, f))] + if self.shuffle: + random.shuffle(l) + return iter_(l) + +# -------------- DATASTREAM SETUP -------------------- + + +class ConcatCtxAndQuestion(Transformer): + produces_examples = True + def __init__(self, stream, concat_question_before, separator_token=None, **kwargs): + assert stream.sources == ('context', 'question', 'answer', 'candidates') + self.sources = ('question', 'answer', 'candidates') + + self.sep = numpy.array([separator_token] if separator_token is not None else [], + dtype=numpy.int32) + self.concat_question_before = concat_question_before + + super(ConcatCtxAndQuestion, self).__init__(stream, **kwargs) + + def get_data(self, request=None): + if request is not None: + raise ValueError('Unsupported: request') + + ctx, q, a, cand = next(self.child_epoch_iterator) + + if self.concat_question_before: + return (numpy.concatenate([q, self.sep, ctx]), a, cand) + else: + return (numpy.concatenate([ctx, self.sep, q]), a, cand) + +class _balanced_batch_helper(object): + def __init__(self, key): + self.key = key + def __call__(self, data): + return data[self.key].shape[0] + +def setup_datastream(path, vocab_file, config): + ds = QADataset(path, vocab_file, config.n_entities, need_sep_token=config.concat_ctx_and_question) + it = QAIterator(path, shuffle=config.shuffle_questions) + + stream = DataStream(ds, iteration_scheme=it) + + if config.concat_ctx_and_question: + stream = ConcatCtxAndQuestion(stream, config.concat_question_before, ds.reverse_vocab['<SEP>']) + + # Sort sets of multiple batches to make batches of similar sizes + stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size * config.sort_batch_count)) + comparison = _balanced_batch_helper(stream.sources.index('question' if config.concat_ctx_and_question else 'context')) + stream = Mapping(stream, SortMapping(comparison)) + stream = Unpack(stream) + + stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size)) + stream = Padding(stream, mask_sources=['context', 'question', 'candidates'], mask_dtype='int32') + + return ds, stream + +if __name__ == "__main__": + # Test + class DummyConfig: + def __init__(self): + self.shuffle_entities = True + self.shuffle_questions = False + self.concat_ctx_and_question = False + self.concat_question_before = False + self.batch_size = 2 + self.sort_batch_count = 1000 + + ds, stream = setup_datastream(os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training"), + os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt"), + DummyConfig()) + it = stream.get_epoch_iterator() + + for i, d in enumerate(stream.get_epoch_iterator()): + print '--' + print d + if i > 2: break + +# vim: set sts=4 ts=4 sw=4 tw=0 et : diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/model/__init__.py diff --git a/model/attentive_reader.py b/model/attentive_reader.py new file mode 100644 index 0000000..682e48d --- /dev/null +++ b/model/attentive_reader.py @@ -0,0 +1,152 @@ +import theano +from theano import tensor +import numpy + +from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier +from blocks.bricks.lookup import LookupTable +from blocks.bricks.recurrent import LSTM + +from blocks.filter import VariableFilter +from blocks.roles import WEIGHT +from blocks.graph import ComputationGraph, apply_dropout, apply_noise + +def make_bidir_lstm_stack(seq, seq_dim, mask, sizes, skip=True, name=''): + bricks = [] + + curr_dim = [seq_dim] + curr_hidden = [seq] + + hidden_list = [] + for k, dim in enumerate(sizes): + fwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='%s_fwd_lstm_in_%d_%d'%(name,k,l)) for l, d in enumerate(curr_dim)] + fwd_lstm = LSTM(dim=dim, activation=Tanh(), name='%s_fwd_lstm_%d'%(name,k)) + + bwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='%s_bwd_lstm_in_%d_%d'%(name,k,l)) for l, d in enumerate(curr_dim)] + bwd_lstm = LSTM(dim=dim, activation=Tanh(), name='%s_bwd_lstm_%d'%(name,k)) + + bricks = bricks + [fwd_lstm, bwd_lstm] + fwd_lstm_ins + bwd_lstm_ins + + fwd_tmp = sum(x.apply(v) for x, v in zip(fwd_lstm_ins, curr_hidden)) + bwd_tmp = sum(x.apply(v) for x, v in zip(bwd_lstm_ins, curr_hidden)) + fwd_hidden, _ = fwd_lstm.apply(fwd_tmp, mask=mask) + bwd_hidden, _ = bwd_lstm.apply(bwd_tmp[::-1], mask=mask[::-1]) + hidden_list = hidden_list + [fwd_hidden, bwd_hidden] + if skip: + curr_hidden = [seq, fwd_hidden, bwd_hidden[::-1]] + curr_dim = [seq_dim, dim, dim] + else: + curr_hidden = [fwd_hidden, bwd_hidden[::-1]] + curr_dim = [dim, dim] + + return bricks, hidden_list + +class Model(): + def __init__(self, config, vocab_size): + question = tensor.imatrix('question') + question_mask = tensor.imatrix('question_mask') + context = tensor.imatrix('context') + context_mask = tensor.imatrix('context_mask') + answer = tensor.ivector('answer') + candidates = tensor.imatrix('candidates') + candidates_mask = tensor.imatrix('candidates_mask') + + bricks = [] + + question = question.dimshuffle(1, 0) + question_mask = question_mask.dimshuffle(1, 0) + context = context.dimshuffle(1, 0) + context_mask = context_mask.dimshuffle(1, 0) + + # Embed questions and cntext + embed = LookupTable(vocab_size, config.embed_size, name='question_embed') + bricks.append(embed) + + qembed = embed.apply(question) + cembed = embed.apply(context) + + qlstms, qhidden_list = make_bidir_lstm_stack(qembed, config.embed_size, question_mask.astype(theano.config.floatX), + config.question_lstm_size, config.question_skip_connections, 'q') + clstms, chidden_list = make_bidir_lstm_stack(cembed, config.embed_size, context_mask.astype(theano.config.floatX), + config.ctx_lstm_size, config.ctx_skip_connections, 'ctx') + bricks = bricks + qlstms + clstms + + # Calculate question encoding (concatenate layer1) + if config.question_skip_connections: + qenc_dim = 2*sum(config.question_lstm_size) + qenc = tensor.concatenate([h[-1,:,:] for h in qhidden_list], axis=1) + else: + qenc_dim = 2*config.question_lstm_size[-1] + qenc = tensor.concatenate([h[-1,:,:] for h in qhidden_list[-2:]], axis=1) + qenc.name = 'qenc' + + # Calculate context encoding (concatenate layer1) + if config.ctx_skip_connections: + cenc_dim = 2*sum(config.ctx_lstm_size) + cenc = tensor.concatenate(chidden_list, axis=2) + else: + cenc_dim = 2*config.ctx_lstm_size[-1] + cenc = tensor.concatenate(chidden_list[-2:], axis=2) + cenc.name = 'cenc' + + # Attention mechanism MLP + attention_mlp = MLP(dims=config.attention_mlp_hidden + [1], + activations=config.attention_mlp_activations[1:] + [Identity()], + name='attention_mlp') + attention_qlinear = Linear(input_dim=qenc_dim, output_dim=config.attention_mlp_hidden[0], name='attq') + attention_clinear = Linear(input_dim=cenc_dim, output_dim=config.attention_mlp_hidden[0], use_bias=False, name='attc') + bricks += [attention_mlp, attention_qlinear, attention_clinear] + layer1 = Tanh().apply(attention_clinear.apply(cenc.reshape((cenc.shape[0]*cenc.shape[1], cenc.shape[2]))) + .reshape((cenc.shape[0],cenc.shape[1],config.attention_mlp_hidden[0])) + + attention_qlinear.apply(qenc)[None, :, :]) + layer1.name = 'layer1' + att_weights = attention_mlp.apply(layer1.reshape((layer1.shape[0]*layer1.shape[1], layer1.shape[2]))) + att_weights.name = 'att_weights_0' + att_weights = att_weights.reshape((layer1.shape[0], layer1.shape[1])) + att_weights.name = 'att_weights' + + attended = tensor.sum(cenc * tensor.nnet.softmax(att_weights.T).T[:, :, None], axis=0) + attended.name = 'attended' + + # Now we can calculate our output + out_mlp = MLP(dims=[cenc_dim + qenc_dim] + config.out_mlp_hidden + [config.n_entities], + activations=config.out_mlp_activations + [Identity()], + name='out_mlp') + bricks += [out_mlp] + probs = out_mlp.apply(tensor.concatenate([attended, qenc], axis=1)) + probs.name = 'probs' + + is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :], + tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1) + probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs)) + + # Calculate prediction, cost and error rate + pred = probs.argmax(axis=1) + cost = Softmax().categorical_cross_entropy(answer, probs).mean() + error_rate = tensor.neq(answer, pred).mean() + + # Apply dropout + cg = ComputationGraph([cost, error_rate]) + if config.w_noise > 0: + noise_vars = VariableFilter(roles=[WEIGHT])(cg) + cg = apply_noise(cg, noise_vars, config.w_noise) + if config.dropout > 0: + cg = apply_dropout(cg, qhidden_list + chidden_list, config.dropout) + [cost_reg, error_rate_reg] = cg.outputs + + # Other stuff + cost_reg.name = cost.name = 'cost' + error_rate_reg.name = error_rate.name = 'error_rate' + + self.sgd_cost = cost_reg + self.monitor_vars = [[cost_reg], [error_rate_reg]] + self.monitor_vars_valid = [[cost], [error_rate]] + + # Initialize bricks + for brick in bricks: + brick.weights_init = config.weights_init + brick.biases_init = config.biases_init + brick.initialize() + + + +# vim: set sts=4 ts=4 sw=4 tw=0 et : diff --git a/model/deep_bidir_lstm.py b/model/deep_bidir_lstm.py new file mode 100644 index 0000000..1e000c6 --- /dev/null +++ b/model/deep_bidir_lstm.py @@ -0,0 +1,109 @@ +import theano +from theano import tensor +import numpy + +from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier +from blocks.bricks.lookup import LookupTable +from blocks.bricks.recurrent import LSTM + +from blocks.filter import VariableFilter +from blocks.roles import WEIGHT +from blocks.graph import ComputationGraph, apply_dropout, apply_noise + +class Model(): + def __init__(self, config, vocab_size): + question = tensor.imatrix('question') + question_mask = tensor.imatrix('question_mask') + answer = tensor.ivector('answer') + candidates = tensor.imatrix('candidates') + candidates_mask = tensor.imatrix('candidates_mask') + + bricks = [] + + + # set time as first dimension + question = question.dimshuffle(1, 0) + question_mask = question_mask.dimshuffle(1, 0) + + # Embed questions + embed = LookupTable(vocab_size, config.embed_size, name='question_embed') + bricks.append(embed) + qembed = embed.apply(question) + + # Create and apply LSTM stack + curr_dim = [config.embed_size] + curr_hidden = [qembed] + + hidden_list = [] + for k, dim in enumerate(config.lstm_size): + fwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='fwd_lstm_in_%d_%d'%(k,l)) for l, d in enumerate(curr_dim)] + fwd_lstm = LSTM(dim=dim, activation=Tanh(), name='fwd_lstm_%d'%k) + + bwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='bwd_lstm_in_%d_%d'%(k,l)) for l, d in enumerate(curr_dim)] + bwd_lstm = LSTM(dim=dim, activation=Tanh(), name='bwd_lstm_%d'%k) + + bricks = bricks + [fwd_lstm, bwd_lstm] + fwd_lstm_ins + bwd_lstm_ins + + fwd_tmp = sum(x.apply(v) for x, v in zip(fwd_lstm_ins, curr_hidden)) + bwd_tmp = sum(x.apply(v) for x, v in zip(bwd_lstm_ins, curr_hidden)) + fwd_hidden, _ = fwd_lstm.apply(fwd_tmp, mask=question_mask.astype(theano.config.floatX)) + bwd_hidden, _ = bwd_lstm.apply(bwd_tmp[::-1], mask=question_mask.astype(theano.config.floatX)[::-1]) + hidden_list = hidden_list + [fwd_hidden, bwd_hidden] + if config.skip_connections: + curr_hidden = [qembed, fwd_hidden, bwd_hidden[::-1]] + curr_dim = [config.embed_size, dim, dim] + else: + curr_hidden = [fwd_hidden, bwd_hidden[::-1]] + curr_dim = [dim, dim] + + # Create and apply output MLP + if config.skip_connections: + out_mlp = MLP(dims=[2*sum(config.lstm_size)] + config.out_mlp_hidden + [config.n_entities], + activations=config.out_mlp_activations + [Identity()], + name='out_mlp') + bricks.append(out_mlp) + + probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list], axis=1)) + else: + out_mlp = MLP(dims=[2*config.lstm_size[-1]] + config.out_mlp_hidden + [config.n_entities], + activations=config.out_mlp_activations + [Identity()], + name='out_mlp') + bricks.append(out_mlp) + + probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list[-2:]], axis=1)) + + is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :], + tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1) + probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs)) + + # Calculate prediction, cost and error rate + pred = probs.argmax(axis=1) + cost = Softmax().categorical_cross_entropy(answer, probs).mean() + error_rate = tensor.neq(answer, pred).mean() + + # Apply dropout + cg = ComputationGraph([cost, error_rate]) + if config.w_noise > 0: + noise_vars = VariableFilter(roles=[WEIGHT])(cg) + cg = apply_noise(cg, noise_vars, config.w_noise) + if config.dropout > 0: + cg = apply_dropout(cg, hidden_list, config.dropout) + [cost_reg, error_rate_reg] = cg.outputs + + # Other stuff + cost_reg.name = cost.name = 'cost' + error_rate_reg.name = error_rate.name = 'error_rate' + + self.sgd_cost = cost_reg + self.monitor_vars = [[cost_reg], [error_rate_reg]] + self.monitor_vars_valid = [[cost], [error_rate]] + + # Initialize bricks + for brick in bricks: + brick.weights_init = config.weights_init + brick.biases_init = config.biases_init + brick.initialize() + + + +# vim: set sts=4 ts=4 sw=4 tw=0 et : diff --git a/model/deep_lstm.py b/model/deep_lstm.py new file mode 100644 index 0000000..02cc034 --- /dev/null +++ b/model/deep_lstm.py @@ -0,0 +1,99 @@ +import theano +from theano import tensor +import numpy + +from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier +from blocks.bricks.lookup import LookupTable +from blocks.bricks.recurrent import LSTM + +from blocks.graph import ComputationGraph, apply_dropout + + +class Model(): + def __init__(self, config, vocab_size): + question = tensor.imatrix('question') + question_mask = tensor.imatrix('question_mask') + answer = tensor.ivector('answer') + candidates = tensor.imatrix('candidates') + candidates_mask = tensor.imatrix('candidates_mask') + + bricks = [] + + + # set time as first dimension + question = question.dimshuffle(1, 0) + question_mask = question_mask.dimshuffle(1, 0) + + # Embed questions + embed = LookupTable(vocab_size, config.embed_size, name='question_embed') + bricks.append(embed) + qembed = embed.apply(question) + + # Create and apply LSTM stack + curr_dim = config.embed_size + curr_hidden = qembed + + hidden_list = [] + for k, dim in enumerate(config.lstm_size): + lstm_in = Linear(input_dim=curr_dim, output_dim=4*dim, name='lstm_in_%d'%k) + lstm = LSTM(dim=dim, activation=Tanh(), name='lstm_%d'%k) + bricks = bricks + [lstm_in, lstm] + + tmp = lstm_in.apply(curr_hidden) + hidden, _ = lstm.apply(tmp, mask=question_mask.astype(theano.config.floatX)) + hidden_list.append(hidden) + if config.skip_connections: + curr_hidden = tensor.concatenate([hidden, qembed], axis=2) + curr_dim = dim + config.embed_size + else: + curr_hidden = hidden + curr_dim = dim + + # Create and apply output MLP + if config.skip_connections: + out_mlp = MLP(dims=[sum(config.lstm_size)] + config.out_mlp_hidden + [config.n_entities], + activations=config.out_mlp_activations + [Identity()], + name='out_mlp') + bricks.append(out_mlp) + + probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list], axis=1)) + else: + out_mlp = MLP(dims=[config.lstm_size[-1]] + config.out_mlp_hidden + [config.n_entities], + activations=config.out_mlp_activations + [Identity()], + name='out_mlp') + bricks.append(out_mlp) + + probs = out_mlp.apply(hidden_list[-1][-1,:,:]) + + is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :], + tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1) + probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs)) + + # Calculate prediction, cost and error rate + pred = probs.argmax(axis=1) + cost = Softmax().categorical_cross_entropy(answer, probs).mean() + error_rate = tensor.neq(answer, pred).mean() + + # Apply dropout + cg = ComputationGraph([cost, error_rate]) + if config.dropout > 0: + cg = apply_dropout(cg, hidden_list, config.dropout) + [cost_reg, error_rate_reg] = cg.outputs + + # Other stuff + cost_reg.name = cost.name = 'cost' + error_rate_reg.name = error_rate.name = 'error_rate' + + self.sgd_cost = cost_reg + self.monitor_vars = [[cost_reg], [error_rate_reg]] + self.monitor_vars_valid = [[cost], [error_rate]] + + # Initialize bricks + for brick in bricks: + brick.weights_init = config.weights_init + brick.biases_init = config.biases_init + brick.initialize() + + + +# vim: set sts=4 ts=4 sw=4 tw=0 et : diff --git a/paramsaveload.py b/paramsaveload.py new file mode 100644 index 0000000..9c05926 --- /dev/null +++ b/paramsaveload.py @@ -0,0 +1,37 @@ +import logging + +import numpy + +import cPickle + +from blocks.extensions import SimpleExtension + +logging.basicConfig(level='INFO') +logger = logging.getLogger('extensions.SaveLoadParams') + +class SaveLoadParams(SimpleExtension): + def __init__(self, path, model, **kwargs): + super(SaveLoadParams, self).__init__(**kwargs) + + self.path = path + self.model = model + + def do_save(self): + with open(self.path, 'w') as f: + logger.info('Saving parameters to %s...'%self.path) + cPickle.dump(self.model.get_parameter_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) + + def do_load(self): + try: + with open(self.path, 'r') as f: + logger.info('Loading parameters from %s...'%self.path) + self.model.set_parameter_values(cPickle.load(f)) + except IOError: + pass + + def do(self, which_callback, *args): + if which_callback == 'before_training': + self.do_load() + else: + self.do_save() + diff --git a/train.py b/train.py new file mode 100755 index 0000000..84f9a4b --- /dev/null +++ b/train.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python + +import logging +import numpy +import sys +import os +import importlib + +import theano +from theano import tensor + +from blocks.extensions import Printing, SimpleExtension, FinishAfter, ProgressBar +from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring +from blocks.graph import ComputationGraph +from blocks.main_loop import MainLoop +from blocks.model import Model +from blocks.algorithms import GradientDescent + +try: + from blocks.extras.extensions.plot import Plot + plot_avail = True +except ImportError: + plot_avail = False + print "No plotting extension available." + +import data +from paramsaveload import SaveLoadParams + +logging.basicConfig(level='INFO') +logger = logging.getLogger(__name__) + +sys.setrecursionlimit(500000) + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, 'Usage: %s config' % sys.argv[0] + sys.exit(1) + model_name = sys.argv[1] + config = importlib.import_module('.%s' % model_name, 'config') + + # Build datastream + path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training") + valid_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/validation") + vocab_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt") + + ds, train_stream = data.setup_datastream(path, vocab_path, config) + _, valid_stream = data.setup_datastream(valid_path, vocab_path, config) + + dump_path = os.path.join("model_params", model_name+".pkl") + + # Build model + m = config.Model(config, ds.vocab_size) + + # Build the Blocks stuff for training + model = Model(m.sgd_cost) + + algorithm = GradientDescent(cost=m.sgd_cost, + step_rule=config.step_rule, + parameters=model.parameters) + + extensions = [ + TrainingDataMonitoring( + [v for l in m.monitor_vars for v in l], + prefix='train', + every_n_batches=config.print_freq) + ] + if config.save_freq is not None and dump_path is not None: + extensions += [ + SaveLoadParams(path=dump_path, + model=model, + before_training=True, + after_training=True, + after_epoch=True, + every_n_batches=config.save_freq) + ] + if valid_stream is not None and config.valid_freq != -1: + extensions += [ + DataStreamMonitoring( + [v for l in m.monitor_vars_valid for v in l], + valid_stream, + prefix='valid', + every_n_batches=config.valid_freq), + ] + if plot_avail: + plot_channels = [['train_' + v.name for v in lt] + ['valid_' + v.name for v in lv] + for lt, lv in zip(m.monitor_vars, m.monitor_vars_valid)] + extensions += [ + Plot(document='deepmind_qa_'+model_name, + channels=plot_channels, + # server_url='http://localhost:5006/', # If you need, change this + every_n_batches=config.print_freq) + ] + extensions += [ + Printing(every_n_batches=config.print_freq, + after_epoch=True), + ProgressBar() + ] + + main_loop = MainLoop( + model=model, + data_stream=train_stream, + algorithm=algorithm, + extensions=extensions + ) + + # Run the model ! + main_loop.run() + main_loop.profile.report() + + + +# vim: set sts=4 ts=4 sw=4 tw=0 et : |