import logging
import random
import numpy
import cPickle
from picklable_itertools import iter_
from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.schemes import IterationScheme, ConstantScheme, SequentialExampleScheme
from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer
import sys
import os
# Root logging is configured at INFO level as a side effect of importing
# this module.
# NOTE(review): basicConfig at import time affects every program that imports
# this file; consider moving it under the ``__main__`` guard.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DummyDataset(Dataset):
    """Synthetic sequence-to-sequence toy dataset held fully in memory.

    Each example is an ``(input, output)`` tuple of integer lists:

    * ``output`` is a sequence of pattern ids, each drawn uniformly from
      ``range(len(_PATTERNS))``. Its length is drawn with
      ``randrange(min_out_len, max_out_len)``, so the upper bound is
      exclusive.
    * ``input`` is built by walking the symbols of each chosen pattern in
      order. Each symbol is appended once with probability ``_PROB_EMIT``.
      It is then appended again for as long as successive draws fall below
      ``_PROB_REPEAT``. A symbol can therefore be dropped or repeated.

    Parameters
    ----------
    nb_examples : int
        Number of examples generated eagerly in the constructor.
    rng_seed
        Seed for a private ``random.Random`` instance. Generation is
        reproducible for a given seed and leaves the global ``random``
        module state untouched.
    min_out_len, max_out_len : int
        Bounds for the output length. ``randrange`` raises ``ValueError``
        unless ``min_out_len < max_out_len`` (when ``nb_examples > 0``).
    **kwargs
        Forwarded to the base ``Dataset`` constructor.
    """

    # Output symbols index into this table. Each row lists the input symbols
    # that the corresponding output symbol expands to.
    _PATTERNS = (
        (0, 1, 2, 3, 4),
        (0, 1, 2, 1, 0),
        (4, 3, 2, 3, 4),
        (4, 3, 2, 1, 0),
    )
    # Probability that a pattern symbol is emitted once.
    _PROB_EMIT = 0.7
    # Probability of each further repetition of a pattern symbol.
    _PROB_REPEAT = 0.2

    def __init__(self, nb_examples, rng_seed, min_out_len, max_out_len, **kwargs):
        # Set before the base constructor runs; it presumably inspects
        # provides_sources (TODO confirm against fuel.datasets.Dataset).
        self.provides_sources = ('input', 'output')
        # A private generator yields exactly the same sequence that seeding
        # the global ``random`` module would. It avoids clobbering the
        # process-wide RNG state as a hidden side effect of construction.
        rng = random.Random(rng_seed)
        self.data = [self._generate_example(rng, min_out_len, max_out_len)
                     for _ in range(nb_examples)]
        super(DummyDataset, self).__init__(**kwargs)

    def _generate_example(self, rng, min_out_len, max_out_len):
        """Draw one ``(input, output)`` example from ``rng``.

        The order of RNG calls is significant: it determines the data
        produced for a given seed.
        """
        inputs = []
        outputs = []
        for _ in range(rng.randrange(min_out_len, max_out_len)):
            pattern_id = rng.randrange(len(self._PATTERNS))
            outputs.append(pattern_id)
            for symbol in self._PATTERNS[pattern_id]:
                if rng.random() < self._PROB_EMIT:
                    inputs.append(symbol)
                # Repeats are drawn independently of whether the symbol was
                # emitted above.
                while rng.random() < self._PROB_REPEAT:
                    inputs.append(symbol)
        return (inputs, outputs)

    def get_data(self, state=None, request=None):
        """Return the stored example selected by ``request``.

        Parameters
        ----------
        state
            Unused.
        request
            Index into the in-memory example list. Presumably an integer
            example index, as produced by the ``SequentialExampleScheme``
            used in ``setup_datastream``.

        Returns
        -------
        tuple
            ``(input, output)`` lists, in the order of ``provides_sources``.

        Raises
        ------
        ValueError
            If ``request`` is None.
        """
        if request is None:
            raise ValueError("Request required")
        return self.data[request]
# -------------- DATASTREAM SETUP --------------------
def setup_datastream(batch_size, **kwargs):
    """Create a ``DummyDataset`` plus a batched, padded stream over it.

    Parameters
    ----------
    batch_size : int
        Examples per batch, passed to ``ConstantScheme``.
    **kwargs
        Forwarded verbatim to ``DummyDataset``. It must include
        ``nb_examples``, which also sizes the ``SequentialExampleScheme``.

    Returns
    -------
    tuple
        ``(dataset, stream)``. Examples are read in order, grouped into
        batches, and passed through Fuel's ``Padding`` transformer with
        ``mask_sources`` set to both the ``input`` and ``output`` sources.
    """
    dataset = DummyDataset(**kwargs)
    examples = DataStream(
        dataset,
        iteration_scheme=SequentialExampleScheme(kwargs['nb_examples']),
    )
    batches = Batch(examples, iteration_scheme=ConstantScheme(batch_size))
    padded = Padding(batches, mask_sources=['input', 'output'])
    return dataset, padded
if __name__ == "__main__":
    # Smoke test: build a tiny dataset and print the first few padded
    # batches. batch_size is a required argument of setup_datastream; 2 is
    # an arbitrary demo value that shows padding across examples.
    ds, stream = setup_datastream(batch_size=2,
                                  nb_examples=5,
                                  rng_seed=123,
                                  min_out_len=3,
                                  max_out_len=6)
    for i, d in enumerate(stream.get_epoch_iterator()):
        # Single-argument print() behaves the same on Python 2 and 3.
        print('--')
        print(d)
        # Print at most four batches.
        if i > 2:
            break
# vim: set sts=4 ts=4 sw=4 tw=0 et :