# Provenance (recovered from a flattened cgit v1.2.3 patch):
#   commit  c9ba2abc7172b4657216e0fcc638098060d7f753
#   author  Thomas Mesnard
#   date    Wed, 23 Dec 2015 20:27:49 +0100
#   subject At least it compiles
#   change  new file dummy_dataset.py (85 insertions)
"""Synthetic sequence-to-sequence toy dataset exposed through Fuel.

Each example is a pair ``(input, output)`` of integer lists.  ``output`` is a
random sequence of row indices into a small fixed table; ``input`` is the
concatenation of the selected rows after randomly dropping and repeating
individual symbols.
"""

import logging
import os
import random
import sys

import cPickle
import numpy

from picklable_itertools import iter_

from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.schemes import IterationScheme, ConstantScheme, SequentialExampleScheme
from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer

logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)


class DummyDataset(Dataset):
    """In-memory dataset of randomly generated ``(input, output)`` pairs.

    Parameters
    ----------
    nb_examples
        Number of examples to generate.
    rng_seed
        Seed passed to ``random.seed``.  Note that this reseeds the
        module-level ``random`` generator.
    min_out_len
        Smallest possible length of an ``output`` sequence (inclusive).
    max_out_len
        Upper bound on the ``output`` length.  It is passed as the ``stop``
        argument of ``random.randrange`` and is therefore exclusive.
    **kwargs
        Forwarded unchanged to ``Dataset.__init__``.
    """

    def __init__(self, nb_examples, rng_seed, min_out_len, max_out_len, **kwargs):
        self.provides_sources = ('input', 'output')

        random.seed(rng_seed)

        # Each output label selects one of these rows; the row's symbols are
        # what gets (noisily) written to the input sequence.
        table = [
            [0, 1, 2, 3, 4],
            [0, 1, 2, 1, 0],
            [4, 3, 2, 3, 4],
            [4, 3, 2, 1, 0],
        ]
        prob0 = 0.7  # probability that a symbol is emitted the first time
        prob = 0.2   # probability of emitting one more copy of the symbol

        # The order of the random draws below is kept exactly as in the
        # original so that a given seed still yields the same data.
        self.data = []
        for _n in range(nb_examples):
            inputs = []
            outputs = []
            out_len = random.randrange(min_out_len, max_out_len)
            for _p in range(out_len):
                label = random.randrange(len(table))
                outputs.append(label)
                for symbol in table[label]:
                    # NOTE(review): the source patch had lost its indentation;
                    # the repeat loop is reconstructed as a sibling of the
                    # first-emission test (not nested in it) -- confirm.
                    if random.uniform(0, 1) < prob0:
                        inputs.append(symbol)
                    while random.uniform(0, 1) < prob:
                        inputs.append(symbol)
            self.data.append((inputs, outputs))

        super(DummyDataset, self).__init__(**kwargs)

    def get_data(self, state=None, request=None):
        """Return the stored ``(input, output)`` pair(s) selected by *request*.

        ``state`` is accepted for interface compatibility and is not used.
        ``request`` is used directly to index the internal example list.

        Raises
        ------
        ValueError
            If *request* is ``None``.
        """
        if request is None:
            raise ValueError("Request required")

        return self.data[request]


# -------------- DATASTREAM SETUP --------------------

def setup_datastream(batch_size, **kwargs):
    """Build a ``DummyDataset`` and a batched, padded stream over it.

    Parameters
    ----------
    batch_size
        Batch size handed to ``ConstantScheme`` for the ``Batch`` transformer.
    **kwargs
        Constructor arguments for ``DummyDataset``.  ``nb_examples`` must be
        present because it is also used to size the example iteration scheme.

    Returns
    -------
    tuple
        ``(dataset, stream)`` where ``stream`` is the ``Padding`` transformer
        wrapping the batched example stream.
    """
    ds = DummyDataset(**kwargs)

    example_scheme = SequentialExampleScheme(kwargs['nb_examples'])
    stream = DataStream(ds, iteration_scheme=example_scheme)

    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
    stream = Padding(stream, mask_sources=['input', 'output'])

    return ds, stream


if __name__ == "__main__":

    # batch_size is a required argument of setup_datastream; the original
    # demo omitted it and failed with a TypeError before producing any data.
    ds, stream = setup_datastream(batch_size=2,
                                  nb_examples=5,
                                  rng_seed=123,
                                  min_out_len=3,
                                  max_out_len=6)

    for i, d in enumerate(stream.get_epoch_iterator()):
        # Single-argument parenthesised prints behave the same under the
        # Python 2 print statement and the Python 3 print function.
        print('--')
        print(d)

        if i > 2:
            break

# vim: set sts=4 ts=4 sw=4 tw=0 et :