aboutsummaryrefslogblamecommitdiff
path: root/data.py
blob: b3fa6d27306b54efc25b555a1d039948d1fd9bbb (plain) (tree)
















































































































































































                                                                                                                          
import logging
import random
import numpy

import cPickle

from picklable_itertools import iter_

from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.schemes import IterationScheme, ConstantScheme
from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer

import sys
import os

# Configure root logging at INFO level. Because this runs at module level, it
# happens as soon as the module is imported.
# NOTE(review): this configures the root logger for the whole process as an
# import side effect; consider moving it under `if __name__ == "__main__"`.
logging.basicConfig(level='INFO')
# Module logger; QADataset.get_data uses it to warn about questions that have
# more candidate entities than available entity ids.
logger = logging.getLogger(__name__)

class QADataset(Dataset):
    """Fuel dataset that reads one question file per request.

    Each request is the name of a file inside ``path``. The file is parsed by
    fixed line positions (see ``get_data``) into a context, a question, an
    answer and a list of candidate entity tokens. On every read the candidate
    tokens are mapped to randomly shuffled ids in ``[0, n_entities)``.

    Vocabulary layout (a word's id is its index in ``self.vocab``):

    * ``@entity0`` .. ``@entity{n_entities-1}`` (ids ``0 .. n_entities-1``)
    * one entry per line of ``vocab_file`` (trailing newline stripped)
    * ``<UNK>`` and ``@placeholder``
    * ``<SEP>``, only when ``need_sep_token`` is true
    """

    def __init__(self, path, vocab_file, n_entities, need_sep_token, **kwargs):
        """
        :param path: directory containing the question files.
        :param vocab_file: text file listing one vocabulary word per line.
        :param n_entities: number of entity ids reserved at the start of the
            vocabulary.
        :param need_sep_token: append a ``<SEP>`` token to the vocabulary.
        :param kwargs: forwarded to ``fuel.datasets.Dataset``.
        """
        self.provides_sources = ('context', 'question', 'answer', 'candidates')

        self.path = path

        # Close the vocabulary file deterministically instead of leaking the
        # handle until garbage collection.
        with open(vocab_file) as f:
            words = [w.rstrip('\n') for w in f]

        self.vocab = ['@entity%d' % i for i in range(n_entities)] + \
                     words + \
                     ['<UNK>', '@placeholder'] + \
                     (['<SEP>'] if need_sep_token else [])

        self.n_entities = n_entities
        self.vocab_size = len(self.vocab)
        self.reverse_vocab = {w: i for i, w in enumerate(self.vocab)}

        super(QADataset, self).__init__(**kwargs)

    def to_word_id(self, w, cand_mapping):
        """Return the word id of token ``w``.

        Candidate tokens get their per-question id from ``cand_mapping``. Any
        other ``@entity`` token is an error. Remaining tokens are looked up in
        the vocabulary and fall back to ``<UNK>``.

        :raises ValueError: for an ``@entity`` token absent from ``cand_mapping``.
        """
        if w in cand_mapping:
            return cand_mapping[w]
        if w.startswith('@entity'):
            raise ValueError("Unmapped entity token: %s" % w)
        return self.reverse_vocab.get(w, self.reverse_vocab['<UNK>'])

    def to_word_ids(self, s, cand_mapping):
        """Encode the single-space separated string ``s`` as an int32 id array."""
        return numpy.array([self.to_word_id(x, cand_mapping) for x in s.split(' ')],
                           dtype=numpy.int32)

    def get_data(self, state=None, request=None):
        """Load and encode the question file named ``request``.

        The file is read by 0-based line index:

        * line 2 is the context,
        * line 4 is the question,
        * line 6 is the answer,
        * lines 8 onward are candidate entities written as ``token:...``.
          Only the part before the first ``:`` is kept, and blank lines are
          ignored.

        :returns: ``(context, question, answer, candidates)``. Context,
            question and candidates are ``int32`` arrays; answer is a
            ``numpy.int32`` scalar.
        :raises ValueError: on a bad request/state, when no entity ids are
            available, on an unmapped entity, or on any out-of-range word id.
        """
        if request is None or state is not None:
            raise ValueError("Expected a request (name of a question file) and no state.")

        # Close the question file instead of leaking one handle per example.
        with open(os.path.join(self.path, request)) as f:
            lines = [line.rstrip('\n') for line in f]

        ctx = lines[2]
        q = lines[4]
        a = lines[6]
        # Skip blank lines (e.g. trailing newlines). Otherwise '' would become
        # a candidate and would also capture empty tokens in the text.
        cand = [s.split(':')[0] for s in lines[8:] if s]

        # list() so the ids can be shuffled in place on any Python version.
        entities = list(range(self.n_entities))
        if cand and not entities:
            # Doubling an empty id list below would loop forever.
            raise ValueError("No entity identifiers available (n_entities=%d) for question: %s"
                             % (self.n_entities, request))
        if len(cand) > len(entities):
            logger.warning("Too many entities (%d) for question: %s, using duplicate entity identifiers",
                           len(cand), request)
            while len(cand) > len(entities):
                entities = entities + entities
        # Fresh random entity ids on every read, so the model cannot memorise
        # a fixed token-to-id assignment.
        random.shuffle(entities)
        cand_mapping = dict(zip(cand, entities))

        ctx = self.to_word_ids(ctx, cand_mapping)
        q = self.to_word_ids(q, cand_mapping)
        cand = numpy.array([self.to_word_id(x, cand_mapping) for x in cand], dtype=numpy.int32)
        a = numpy.int32(self.to_word_id(a, cand_mapping))

        if not a < self.n_entities:
            raise ValueError("Invalid answer token %d" % a)
        if not numpy.all(cand < self.n_entities):
            raise ValueError("Invalid candidate in list %s" % repr(cand))
        if not numpy.all(ctx < self.vocab_size):
            raise ValueError("Context word id out of bounds: %d" % int(ctx.max()))
        if not numpy.all(ctx >= 0):
            raise ValueError("Context word id negative: %d" % int(ctx.min()))
        if not numpy.all(q < self.vocab_size):
            raise ValueError("Question word id out of bounds: %d" % int(q.max()))
        if not numpy.all(q >= 0):
            raise ValueError("Question word id negative: %d" % int(q.min()))

        return (ctx, q, a, cand)

class QAIterator(IterationScheme):
    """Iteration scheme whose requests are the names of the files in ``path``.

    Each request is a bare file name (relative to ``path``), as expected by
    ``QADataset.get_data``. Subdirectories and other non-regular files are
    skipped.
    """
    requests_examples = True

    def __init__(self, path, shuffle=False, **kwargs):
        """
        :param path: directory to list.
        :param shuffle: if true, shuffle the file order with ``random.shuffle``
            at the start of every epoch.
        """
        self.path = path
        self.shuffle = shuffle

        super(QAIterator, self).__init__(**kwargs)

    def get_request_iterator(self):
        """Return an iterator over the file names for one epoch."""
        # os.listdir order is arbitrary and filesystem-dependent. Sorting makes
        # the unshuffled order, and a seeded shuffle, reproducible.
        names = sorted(f for f in os.listdir(self.path)
                       if os.path.isfile(os.path.join(self.path, f)))
        if self.shuffle:
            random.shuffle(names)
        return iter_(names)

# -------------- DATASTREAM SETUP --------------------


class ConcatCtxAndQuestion(Transformer):
    """Merge each example's context and question into a single sequence.

    Consumes examples with sources ``('context', 'question', 'answer',
    'candidates')`` and produces ``('question', 'answer', 'candidates')``.
    The new ``'question'`` is the concatenation of the question and context
    word ids, joined by the optional separator token.
    """
    produces_examples = True

    def __init__(self, stream, concat_question_before, separator_token=None, **kwargs):
        """
        :param stream: child stream with sources
            ``('context', 'question', 'answer', 'candidates')`` in that order.
        :param concat_question_before: if true, emit question + sep + context;
            otherwise emit context + sep + question.
        :param separator_token: word id inserted between the two parts, or
            ``None`` to insert nothing.
        :raises ValueError: if ``stream`` does not have the expected sources.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        expected_sources = ('context', 'question', 'answer', 'candidates')
        if tuple(stream.sources) != expected_sources:
            raise ValueError("ConcatCtxAndQuestion expects sources %r, got %r"
                             % (expected_sources, tuple(stream.sources)))
        self.sources = ('question', 'answer', 'candidates')

        self.sep = numpy.array([separator_token] if separator_token is not None else [],
                               dtype=numpy.int32)
        self.concat_question_before = concat_question_before

        super(ConcatCtxAndQuestion, self).__init__(stream, **kwargs)

    def get_data(self, request=None):
        """Return ``(merged_sequence, answer, candidates)`` for the next example."""
        if request is not None:
            raise ValueError('Unsupported: request')

        ctx, q, a, cand = next(self.child_epoch_iterator)

        first, second = (q, ctx) if self.concat_question_before else (ctx, q)
        return (numpy.concatenate([first, self.sep, second]), a, cand)
        
class _balanced_batch_helper(object):
    """Sort key giving the leading-dimension length of one source of an example.

    Called with a single example tuple; it returns ``example[key].shape[0]``.
    It is presumably a class rather than a lambda so that the stream holding
    it stays picklable. TODO confirm.
    """

    def __init__(self, key):
        # Position of the measured source within each example tuple.
        self.key = key

    def __call__(self, data):
        measured = data[self.key]
        length = measured.shape[0]
        return length

def setup_datastream(path, vocab_file, config):
    """Build the question dataset and its batched, padded fuel stream.

    Reads these attributes of ``config``: ``n_entities``,
    ``concat_ctx_and_question``, ``concat_question_before``,
    ``shuffle_questions``, ``batch_size`` and ``sort_batch_count``.

    :param path: directory of question files (used by both dataset and scheme).
    :param vocab_file: vocabulary file passed to ``QADataset``.
    :returns: ``(dataset, stream)``. The stream yields padded batches with
        int32 masks.
    """
    concat = config.concat_ctx_and_question
    dataset = QADataset(path, vocab_file, config.n_entities, need_sep_token=concat)
    scheme = QAIterator(path, shuffle=config.shuffle_questions)

    stream = DataStream(dataset, iteration_scheme=scheme)
    if concat:
        sep_id = dataset.reverse_vocab['<SEP>']
        stream = ConcatCtxAndQuestion(stream, config.concat_question_before, sep_id)

    # After concatenation, the merged sequence is exposed as 'question'.
    sort_source = 'question' if concat else 'context'

    # Take a large group of examples, sort it by sequence length and split it
    # back into examples. Consecutive batches then hold similarly sized
    # sequences, which reduces padding.
    group_size = config.batch_size * config.sort_batch_count
    stream = Batch(stream, iteration_scheme=ConstantScheme(group_size))
    length_key = _balanced_batch_helper(stream.sources.index(sort_source))
    stream = Mapping(stream, SortMapping(length_key))
    stream = Unpack(stream)

    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream, mask_sources=['context', 'question', 'candidates'], mask_dtype='int32')

    return dataset, stream

if __name__ == "__main__":
    # Smoke test: build the stream over the CNN training questions and print
    # the first few padded batches. Requires the DATAPATH environment variable
    # to name the directory that contains deepmind-qa/.
    class DummyConfig:
        def __init__(self):
            # setup_datastream reads config.n_entities, so it must be set. Any
            # positive value works, because QADataset reuses ids when a
            # question has more entities. 550 is an assumed comfortable bound.
            self.n_entities = 550
            self.shuffle_entities = True
            self.shuffle_questions = False
            self.concat_ctx_and_question = False
            self.concat_question_before = False
            self.batch_size = 2
            self.sort_batch_count = 1000

    datapath = os.getenv("DATAPATH")
    if datapath is None:
        sys.exit("The DATAPATH environment variable must be set to run this test.")

    ds, stream = setup_datastream(os.path.join(datapath, "deepmind-qa/cnn/questions/training"),
                                  os.path.join(datapath, "deepmind-qa/cnn/stats/training/vocab.txt"),
                                  DummyConfig())

    for i, d in enumerate(stream.get_epoch_iterator()):
        print('--')
        print(d)
        if i > 2:
            break

# vim: set sts=4 ts=4 sw=4 tw=0 et :