import logging
import random

import numpy

import cPickle

from picklable_itertools import iter_

from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.schemes import IterationScheme, ConstantScheme
from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer

import sys
import os

logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)


class QADataset(Dataset):
    def __init__(self, path, vocab_file, n_entities, need_sep_token, **kwargs):
        self.provides_sources = ('context', 'question', 'answer', 'candidates')

        self.path = path

        # Vocabulary layout: entity markers first (so any id < n_entities is an
        # entity), then the words from vocab_file, then the special tokens.
        self.vocab = ['@entity%d' % i for i in range(n_entities)] + \
                     [w.rstrip('\n') for w in open(vocab_file)] + \
                     ['<UNK>', '@placeholder'] + \
                     (['<SEP>'] if need_sep_token else [])

        self.n_entities = n_entities
        self.vocab_size = len(self.vocab)
        self.reverse_vocab = {w: i for i, w in enumerate(self.vocab)}

        super(QADataset, self).__init__(**kwargs)

    def to_word_id(self, w, cand_mapping):
        if w in cand_mapping:
            return cand_mapping[w]
        elif w[:7] == '@entity':
            raise ValueError("Unmapped entity token: %s" % w)
        elif w in self.reverse_vocab:
            return self.reverse_vocab[w]
        else:
            return self.reverse_vocab['<UNK>']

    def to_word_ids(self, s, cand_mapping):
        return numpy.array([self.to_word_id(x, cand_mapping) for x in s.split(' ')],
                           dtype=numpy.int32)

    def get_data(self, state=None, request=None):
        if request is None or state is not None:
            raise ValueError("Expected a request (name of a question file) and no state.")

        # Question files follow the DeepMind Q&A layout: line 0 is the URL,
        # line 2 the context, line 4 the question, line 6 the answer and
        # lines 8+ the candidate entities ("@entityN:original text").
        lines = [l.rstrip('\n') for l in open(os.path.join(self.path, request))]

        ctx = lines[2]
        q = lines[4]
        a = lines[6]
        cand = [s.split(':')[0] for s in lines[8:]]

        # Assign a random entity id to each candidate so the model cannot rely
        # on a fixed entity ordering across questions.
        entities = range(self.n_entities)
        while len(cand) > len(entities):
            logger.warning("Too many entities (%d) for question: %s, "
                           "using duplicate entity identifiers"
                           % (len(cand), request))
            entities = entities + entities
        random.shuffle(entities)
        cand_mapping = {t: k for t, k in zip(cand, entities)}

        ctx = self.to_word_ids(ctx, cand_mapping)
        q = self.to_word_ids(q, cand_mapping)
        cand = numpy.array([self.to_word_id(x, cand_mapping) for x in cand],
                           dtype=numpy.int32)
        a = numpy.int32(self.to_word_id(a, cand_mapping))

        # Sanity checks on the produced word ids.
        if not a < self.n_entities:
            raise ValueError("Invalid answer token %d" % a)
        if not numpy.all(cand < self.n_entities):
            raise ValueError("Invalid candidate in list %s" % repr(cand))
        if not numpy.all(ctx < self.vocab_size):
            raise ValueError("Context word id out of bounds: %d" % int(ctx.max()))
        if not numpy.all(ctx >= 0):
            raise ValueError("Context word id negative: %d" % int(ctx.min()))
        if not numpy.all(q < self.vocab_size):
            raise ValueError("Question word id out of bounds: %d" % int(q.max()))
        if not numpy.all(q >= 0):
            raise ValueError("Question word id negative: %d" % int(q.min()))

        return (ctx, q, a, cand)


class QAIterator(IterationScheme):
    requests_examples = True

    def __init__(self, path, shuffle=False, **kwargs):
        self.path = path
        self.shuffle = shuffle

        super(QAIterator, self).__init__(**kwargs)

    def get_request_iterator(self):
        # One request per question file in the directory.
        l = [f for f in os.listdir(self.path)
             if os.path.isfile(os.path.join(self.path, f))]
        if self.shuffle:
            random.shuffle(l)
        return iter_(l)


# -------------- DATASTREAM SETUP --------------------

class ConcatCtxAndQuestion(Transformer):
    produces_examples = True

    def __init__(self, stream, concat_question_before, separator_token=None, **kwargs):
        assert stream.sources == ('context', 'question', 'answer', 'candidates')
        self.sources = ('question', 'answer', 'candidates')

        self.sep = numpy.array([separator_token] if separator_token is not None else [],
                               dtype=numpy.int32)
        self.concat_question_before = concat_question_before

        super(ConcatCtxAndQuestion, self).__init__(stream, **kwargs)

    def get_data(self, request=None):
        if request is not None:
            raise ValueError('Unsupported: request')

        ctx, q, a, cand = next(self.child_epoch_iterator)

        if self.concat_question_before:
            return (numpy.concatenate([q, self.sep, ctx]), a, cand)
        else:
            return (numpy.concatenate([ctx, self.sep, q]), a, cand)


class _balanced_batch_helper(object):
    def __init__(self, key):
        self.key = key

    def __call__(self, data):
        # Sort key: length of the chosen source (context, or the concatenation).
        return data[self.key].shape[0]


def setup_datastream(path, vocab_file, config):
    ds = QADataset(path, vocab_file, config.n_entities,
                   need_sep_token=config.concat_ctx_and_question)
    it = QAIterator(path, shuffle=config.shuffle_questions)

    stream = DataStream(ds, iteration_scheme=it)

    if config.concat_ctx_and_question:
        stream = ConcatCtxAndQuestion(stream, config.concat_question_before,
                                      ds.reverse_vocab['<SEP>'])

    # Sort sets of multiple batches to make batches of similar sizes
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(config.batch_size * config.sort_batch_count))
    comparison = _balanced_batch_helper(stream.sources.index(
        'question' if config.concat_ctx_and_question else 'context'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)

    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream, mask_sources=['context', 'question', 'candidates'],
                     mask_dtype='int32')

    return ds, stream


if __name__ == "__main__":
    # Test
    class DummyConfig:
        def __init__(self):
            self.n_entities = 550  # assumed value, large enough for the CNN question set
            self.shuffle_entities = True
            self.shuffle_questions = False
            self.concat_ctx_and_question = False
            self.concat_question_before = False
            self.batch_size = 2
            self.sort_batch_count = 1000

    ds, stream = setup_datastream(os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training"),
                                  os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt"),
                                  DummyConfig())
    it = stream.get_epoch_iterator()

    for i, d in enumerate(stream.get_epoch_iterator()):
        print '--'
        print d
        if i > 2:
            break

# vim: set sts=4 ts=4 sw=4 tw=0 et :
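
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original pipeline).  Assuming Fuel's Padding
# transformer appends a '<source>_mask' entry after every padded source, the
# final stream yields per batch:
#   context, context_mask, question, question_mask, answer,
#   candidates, candidates_mask
# A hypothetical consumer (names like `train_path`, `vocab_path` and `cfg`
# are placeholders) could then look like:
#
#   ds, stream = setup_datastream(train_path, vocab_path, cfg)
#   for batch in stream.get_epoch_iterator(as_dict=True):
#       ctx = batch['context']            # int32, shape (batch, max_ctx_len)
#       ctx_mask = batch['context_mask']  # 1 for real tokens, 0 for padding
#       ans = batch['answer']             # entity ids, all < cfg.n_entities
# ---------------------------------------------------------------------------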