# cgit view: summary | refs | log | blame | commit | diff
# path: root/datastream.py
# blob: 5d9441fc633b752c11c38511d269d521c31aedcb (plain) (tree)











































































                                                                                    
import logging
import random
import numpy

import cPickle

from picklable_itertools import iter_

from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.schemes import IterationScheme
from fuel.transformers import Transformer

import sys
import os

# Configure root logging at INFO level as an import-time side effect, then
# create the logger used by this module.
logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)


class BinaryFileDataset(Dataset):
    """Dataset that serves arbitrary byte ranges read from a binary file.

    The dataset exposes a single source named ``'bytes'``. Requests are
    ``(begin, length)`` pairs addressing a slice of the file.

    NOTE: the file handle opened in ``__init__`` is kept on ``self.f`` and
    is never closed by this class.
    """

    def __init__(self, filename, **kwargs):
        """Open ``filename`` for binary reading and initialise the dataset.

        Args:
            filename: Path of the file to read from.
            **kwargs: Forwarded unchanged to the ``Dataset`` initialiser.
        """
        # Assigned before the base initialiser runs -- presumably
        # Dataset.__init__ inspects it; verify against the Fuel version used.
        self.provides_sources = ('bytes',)

        self.f = open(filename, "rb")

        super(BinaryFileDataset, self).__init__(**kwargs)

    def get_data(self, state=None, request=None):
        """Read one block of the file.

        Args:
            state: Ignored.
            request: A ``(begin, length)`` pair: the offset to seek to and
                the number of bytes to read.

        Returns:
            A 1-tuple holding the data read from the file.

        Raises:
            ValueError: If ``request`` is None.
        """
        if request is None:
            raise ValueError("Expected a request: begin, length")

        offset, length = request
        handle = self.f
        handle.seek(offset)
        chunk = handle.read(length)
        return (chunk,)

    def num_examples(self):
        """Return the current size of the underlying file in bytes."""
        descriptor = self.f.fileno()
        file_stats = os.fstat(descriptor)
        return file_stats.st_size

class RandomBlockIterator(IterationScheme):
    """Iteration scheme producing randomly placed, fixed-length requests.

    Every request is a ``(begin, length)`` pair where ``length`` is
    ``seq_len`` and ``begin`` is drawn uniformly so that the block ends at
    or before ``item_range``.
    """

    def __init__(self, item_range, seq_len, num_seqs_per_epoch, **kwargs):
        """Store the sampling parameters.

        Args:
            item_range: Total number of addressable items.
            seq_len: Length of every requested block.
            num_seqs_per_epoch: Number of requests generated per epoch.
            **kwargs: Forwarded unchanged to the ``IterationScheme``
                initialiser.
        """
        self.item_range = item_range
        self.seq_len = seq_len
        self.num_seqs = num_seqs_per_epoch

        super(RandomBlockIterator, self).__init__(**kwargs)

    def get_request_iterator(self):
        """Return an iterator over ``num_seqs`` random ``(begin, length)`` pairs."""
        requests = []
        for _ in xrange(self.num_seqs):
            # randrange excludes its upper bound, hence the +1: the last
            # admissible start offset is item_range - seq_len.
            begin = random.randrange(0, self.item_range - self.seq_len + 1)
            requests.append((begin, self.seq_len))
        return iter_(requests)

class BytesToIndices(Transformer):
    """Transformer converting each raw byte string into an array of indices.

    The first element of every example pulled from the wrapped stream is
    turned into a one-dimensional ``int16`` array holding one integer per
    byte (or per character, for non-bytes input).
    """

    def __init__(self, stream, **kwargs):
        """Wrap ``stream``.

        Args:
            stream: The data stream whose examples are converted.
            **kwargs: Forwarded unchanged to the ``Transformer`` initialiser.
        """
        self.sources = ('bytes',)
        super(BytesToIndices, self).__init__(stream, **kwargs)

    def get_data(self, request=None):
        """Fetch the next example and convert it to integer indices.

        Args:
            request: Must be None; explicit requests are not supported.

        Returns:
            A 1-tuple containing an ``int16`` numpy array.

        Raises:
            ValueError: If ``request`` is not None.
        """
        if request is not None:
            raise ValueError('Unsupported: request')
        data = next(self.child_epoch_iterator)
        raw = data[0]
        if isinstance(raw, (bytes, bytearray)):
            # Reinterpret the buffer as unsigned bytes in one vectorised
            # step instead of calling ord() per byte. This also works on
            # Python 3, where iterating bytes yields ints and ord() fails.
            # astype() returns a fresh, writable int16 copy.
            indices = numpy.frombuffer(raw, dtype='uint8').astype('int16')
        else:
            # Any other sequence of single characters keeps the original
            # per-element conversion.
            indices = numpy.array([ord(ch) for ch in raw], dtype='int16')
        return (indices,)

def setup_datastream(filename, seq_len, num_seqs_per_epoch=100):
    """Build a stream of random fixed-length byte blocks from a file.

    Args:
        filename: Path of the binary file to sample from.
        seq_len: Number of bytes in every sampled block.
        num_seqs_per_epoch: Number of blocks drawn per epoch.

    Returns:
        A ``BytesToIndices`` transformer wrapping a ``DataStream`` over the
        file, iterated with a ``RandomBlockIterator``.
    """
    dataset = BinaryFileDataset(filename)
    # The file size bounds the offsets the scheme may sample.
    scheme = RandomBlockIterator(dataset.num_examples(), seq_len,
                                 num_seqs_per_epoch)
    return BytesToIndices(DataStream(dataset, iteration_scheme=scheme))

if __name__ == "__main__":
    # Manual smoke test: build a stream over a sample file with a sequence
    # length of 100 and print the first example it yields.
    stream = setup_datastream("data/logcompil.txt", 100)
    epoch_iterator = stream.get_epoch_iterator()
    print(next(epoch_iterator))