aboutsummaryrefslogtreecommitdiff
path: root/dummy_dataset.py
diff options
context:
space:
mode:
authorThomas Mesnard <thomas.mesnard@ens.fr>2015-12-23 20:27:49 +0100
committerThomas Mesnard <thomas.mesnard@ens.fr>2015-12-23 20:27:49 +0100
commitc9ba2abc7172b4657216e0fcc638098060d7f753 (patch)
tree47a89a2c6e16da4e9ce963092d280880b7338744 /dummy_dataset.py
parenta1394aad6fca2dd560eb45a9b2e4cbc7be4c2bf7 (diff)
downloadpgm-ctc-c9ba2abc7172b4657216e0fcc638098060d7f753.tar.gz
pgm-ctc-c9ba2abc7172b4657216e0fcc638098060d7f753.zip
At least it compiles
Diffstat (limited to 'dummy_dataset.py')
-rw-r--r--dummy_dataset.py85
1 files changed, 85 insertions, 0 deletions
diff --git a/dummy_dataset.py b/dummy_dataset.py
new file mode 100644
index 0000000..a65a946
--- /dev/null
+++ b/dummy_dataset.py
@@ -0,0 +1,85 @@
+
+import logging
+import random
+import numpy
+
+import cPickle
+
+from picklable_itertools import iter_
+
+from fuel.datasets import Dataset
+from fuel.streams import DataStream
+from fuel.schemes import IterationScheme, ConstantScheme, SequentialExampleScheme
+from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer
+
+import sys
+import os
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+class DummyDataset(Dataset):
+ def __init__(self, nb_examples, rng_seed, min_out_len, max_out_len, **kwargs):
+ self.provides_sources = ('input', 'output')
+
+ random.seed(rng_seed)
+
+ table = [
+ [0, 1, 2, 3, 4],
+ [0, 1, 2, 1, 0],
+ [4, 3, 2, 3, 4],
+ [4, 3, 2, 1, 0]
+ ]
+ prob0 = 0.7
+ prob = 0.2
+
+ self.data = []
+ for n in range(nb_examples):
+ o = []
+ i = []
+ l = random.randrange(min_out_len, max_out_len)
+ for p in range(l):
+ o.append(random.randrange(len(table)))
+ for x in table[o[-1]]:
+ q = 0
+ if random.uniform(0, 1) < prob0:
+ i.append(x)
+ while random.uniform(0, 1) < prob:
+ i.append(x)
+ self.data.append((i, o))
+
+ super(DummyDataset, self).__init__(**kwargs)
+
+
+ def get_data(self, state=None, request=None):
+ if request is None:
+ raise ValueError("Request required")
+
+ return self.data[request]
+
+# -------------- DATASTREAM SETUP --------------------
+
+def setup_datastream(batch_size, **kwargs):
+ ds = DummyDataset(**kwargs)
+ stream = DataStream(ds, iteration_scheme=SequentialExampleScheme(kwargs['nb_examples']))
+
+ stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
+ stream = Padding(stream, mask_sources=['input', 'output'])
+
+ return ds, stream
+
+if __name__ == "__main__":
+
+ ds, stream = setup_datastream(nb_examples=5,
+ rng_seed=123,
+ min_out_len=3,
+ max_out_len=6)
+
+ for i, d in enumerate(stream.get_epoch_iterator()):
+ print '--'
+ print d
+
+
+ if i > 2: break
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :