aboutsummaryrefslogblamecommitdiff
path: root/dummy_dataset.py
blob: 5053c5a63100aaa3ce7aaa9feca4784afb450228 (plain) (tree)



















































































                                                                                            
import logging
import random
import numpy

import cPickle

from picklable_itertools import iter_

from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.schemes import IterationScheme, ConstantScheme, SequentialExampleScheme
from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer

import sys
import os

# Configure root logging at INFO (a no-op if the root logger already has
# handlers) and create the logger used by this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DummyDataset(Dataset):
    """Synthetic in-memory dataset of ``(input, output)`` integer sequences.

    Each output is a random sequence of pattern indices.  The matching
    input is built by concatenating the five-symbol patterns selected by
    those indices, where every symbol is randomly kept or dropped and then
    randomly repeated.
    """

    def __init__(self, nb_examples, rng_seed, min_out_len, max_out_len, **kwargs):
        """Generate ``nb_examples`` examples and store them in ``self.data``.

        Parameters
        ----------
        nb_examples : int
            Number of examples to generate.
        rng_seed
            Seed handed to ``random.seed`` before generation.
        min_out_len : int
            Smallest possible output length (inclusive).
        max_out_len : int
            Upper bound on the output length.  It is exclusive, because the
            length is drawn with ``random.randrange(min_out_len,
            max_out_len)``; an empty range makes ``randrange`` raise
            ``ValueError``.
        **kwargs
            Forwarded unchanged to the ``Dataset`` constructor.
        """
        # Set before calling the parent constructor, which runs at the end.
        self.provides_sources = ('input', 'output')

        # NOTE: this reseeds the process-wide ``random`` module, so it also
        # affects any other code drawing from that generator.
        random.seed(rng_seed)

        # An output symbol k stands for the input pattern patterns[k].
        patterns = [
            [0, 1, 2, 3, 4],
            [0, 1, 2, 1, 0],
            [4, 3, 2, 3, 4],
            [4, 3, 2, 1, 0],
        ]
        keep_prob = 0.7    # chance that a pattern symbol is emitted once
        repeat_prob = 0.2  # chance of each additional repetition

        self.data = []
        for _ in range(nb_examples):
            inputs = []
            outputs = []
            out_len = random.randrange(min_out_len, max_out_len)
            for _ in range(out_len):
                pattern_idx = random.randrange(len(patterns))
                outputs.append(pattern_idx)
                for symbol in patterns[pattern_idx]:
                    # A symbol can be dropped entirely (both draws fail) or
                    # appear several times in a row.
                    if random.uniform(0, 1) < keep_prob:
                        inputs.append(symbol)
                    while random.uniform(0, 1) < repeat_prob:
                        inputs.append(symbol)
            self.data.append((inputs, outputs))

        super(DummyDataset, self).__init__(**kwargs)

    def get_data(self, state=None, request=None):
        """Return the stored ``(input, output)`` pair selected by ``request``.

        Parameters
        ----------
        state
            Unused by this dataset.
        request
            Index into ``self.data``.

        Raises
        ------
        ValueError
            If ``request`` is ``None``.
        """
        if request is None:
            raise ValueError("Request required")

        return self.data[request]

# -------------- DATASTREAM SETUP --------------------

def setup_datastream(batch_size, **kwargs):
    """Build a ``DummyDataset`` and a padded, batched stream over it.

    ``kwargs`` are passed to ``DummyDataset``; ``kwargs['nb_examples']`` is
    also used as the number of examples the sequential scheme iterates
    over.  Examples are grouped into batches of ``batch_size`` and then
    padded, with masks requested for both the 'input' and 'output' sources.

    Returns the tuple ``(dataset, stream)``.
    """
    dataset = DummyDataset(**kwargs)

    example_scheme = SequentialExampleScheme(kwargs['nb_examples'])
    examples = DataStream(dataset, iteration_scheme=example_scheme)

    batch_scheme = ConstantScheme(batch_size)
    batches = Batch(examples, iteration_scheme=batch_scheme)
    padded = Padding(batches, mask_sources=['input', 'output'])

    return dataset, padded

if __name__ == "__main__":
    # Smoke test: build a tiny stream and print its first few batches.
    # setup_datastream() requires batch_size; 2 is an arbitrary small value
    # that splits the 5 examples into more than one batch.
    ds, stream = setup_datastream(batch_size=2,
                                  nb_examples=5,
                                  rng_seed=123,
                                  min_out_len=3,
                                  max_out_len=6)

    for i, d in enumerate(stream.get_epoch_iterator()):
        # A single parenthesised argument prints the same whether ``print``
        # is a statement or a function.
        print('--')
        print(d)

        # Stop after at most four batches.
        if i > 2:
            break

# vim: set sts=4 ts=4 sw=4 tw=0 et :