summaryrefslogblamecommitdiff
path: root/datastream.py
blob: 8025945781845b7a4a959f994126e9112c1790eb (plain) (tree)






























































                                                                                    





























                                                                                         
                                    
                                                                  

                                                
                                                              




                          





                                                              
 
import logging
import random
import numpy

import cPickle

from picklable_itertools import iter_

from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.schemes import IterationScheme
from fuel.transformers import Transformer

import sys
import os

# Import-time logging setup: root logger at INFO, plus a logger named after
# this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class BinaryFileDataset(Dataset):
    """Dataset whose examples are the raw bytes of a single file.

    A request is a ``(begin, length)`` pair and yields a 1-tuple holding the
    ``length`` bytes starting at byte offset ``begin`` (fewer when the range
    runs past the end of the file).  The file is opened in binary mode when
    the dataset is constructed and the handle is kept on ``self.f``; nothing
    in this class closes it.
    """
    def __init__(self, filename, **kwargs):
        # Assigned before the base-class initializer, as in the original.
        self.provides_sources = ('bytes',)

        self.f = open(filename, "rb")

        super(BinaryFileDataset, self).__init__(**kwargs)

    def get_data(self, state=None, request=None):
        """Return ``(bytes_read,)`` for a ``(begin, length)`` request.

        Raises ValueError when no request is supplied.
        """
        if request is None:
            raise ValueError("Expected a request: begin, length")

        begin, length = request
        handle = self.f
        handle.seek(begin)
        chunk = handle.read(length)
        return (chunk,)

    def num_examples(self):
        """Return the size of the underlying file in bytes."""
        info = os.fstat(self.f.fileno())
        return info.st_size

class RandomBlockIterator(IterationScheme):
    """Iteration scheme requesting random fixed-length blocks.

    Each epoch produces ``num_seqs_per_epoch`` requests ``(begin, seq_len)``
    where ``begin`` is drawn uniformly from ``[0, item_range - seq_len]``, so
    every requested block lies entirely inside ``[0, item_range)``.

    Parameters
    ----------
    item_range : int
        Number of addressable items (``setup_datastream`` passes the file
        size in bytes).
    seq_len : int
        Length of every requested block; must satisfy
        ``1 <= seq_len <= item_range``.
    num_seqs_per_epoch : int
        Number of requests generated per epoch.
    rng : optional
        Object with a ``randrange`` method (e.g. ``random.Random(seed)``) used
        to draw block starts, for reproducible epochs.  When omitted, the
        module-level ``random`` functions are used, as before.

    Raises
    ------
    ValueError
        If ``seq_len`` is below 1 or larger than ``item_range``.
    """
    def __init__(self, item_range, seq_len, num_seqs_per_epoch, rng=None,
                 **kwargs):
        # A negative length would make a file read return everything up to
        # EOF; zero-length blocks carry no data.
        if seq_len < 1:
            raise ValueError("seq_len must be at least 1, got %r" % (seq_len,))
        # Fail here with a clear message instead of later, inside randrange,
        # with "empty range for randrange()".
        if seq_len > item_range:
            raise ValueError("seq_len (%r) exceeds item_range (%r)"
                             % (seq_len, item_range))

        self.seq_len = seq_len
        self.num_seqs = num_seqs_per_epoch
        self.item_range = item_range
        # Store None rather than the ``random`` module itself: modules cannot
        # be pickled, and iteration schemes may need to be.
        self.rng = rng

        super(RandomBlockIterator, self).__init__(**kwargs)

    def get_request_iterator(self):
        """Return an iterator over this epoch's ``(begin, seq_len)`` requests."""
        rng = random if self.rng is None else self.rng
        max_begin = self.item_range - self.seq_len
        # The whole epoch is drawn up front and wrapped with
        # picklable_itertools.iter_.
        requests = [(rng.randrange(0, max_begin + 1), self.seq_len)
                    for _ in xrange(self.num_seqs)]
        return iter_(requests)

class BytesToIndices(Transformer):
    """Convert each raw byte string from the child stream to an index array.

    The first source of every child example is turned into a 1-D ``int16``
    array containing one value (0-255) per byte.
    """
    def __init__(self, stream, **kwargs):
        self.sources = ('bytes',)
        super(BytesToIndices, self).__init__(stream, **kwargs)

    def get_data(self, request=None):
        """Return ``(indices,)`` for the next child example.

        Raises ValueError if a request is passed (requests are unsupported).
        """
        if request is not None:
            raise ValueError('Unsupported: request')
        data = next(self.child_epoch_iterator)
        # View the byte string as unsigned bytes in one vectorized step
        # (instead of a per-character ord() loop), then widen to int16.
        # astype() copies, so the result is writable like the original.
        # This also works when the child yields Python 3 ``bytes``, whose
        # iteration produces ints that ord() rejects.
        return numpy.frombuffer(data[0], dtype=numpy.uint8).astype('int16'),

class ParallelSequences(Transformer):
    """Stack child sequences and emit them as fixed-width column chunks.

    Each refill draws ``num_seqs`` examples from the child stream and stacks
    their first source row-wise into an array of shape ``(num_seqs, length)``.
    All drawn sequences must therefore share the same length.  Successive
    calls to ``get_data`` return consecutive column slices of width
    ``seq_div_size``; the last slice is narrower when ``length`` is not a
    multiple of it.  After the last slice, the next group is drawn.

    Parameters
    ----------
    stream :
        Child stream whose examples' first source is a 1-D array.
    num_seqs : int
        Number of sequences stacked side by side; must be >= 1.
    seq_div_size : int
        Width of each emitted chunk; must be >= 1 (a width of 0 would return
        the same empty slice forever and the epoch would never end).

    Raises
    ------
    ValueError
        If ``num_seqs`` or ``seq_div_size`` is smaller than 1.
    """
    def __init__(self, stream, num_seqs, seq_div_size, **kwargs):
        if num_seqs < 1:
            raise ValueError('num_seqs must be >= 1, got %r' % (num_seqs,))
        if seq_div_size < 1:
            raise ValueError('seq_div_size must be >= 1, got %r'
                             % (seq_div_size,))

        self.sources = ('bytes',)

        self.num_seqs = num_seqs
        self.div_size = seq_div_size

        # Currently stacked group, and the column where the next chunk starts.
        self.tmp = None
        self.i = 0

        super(ParallelSequences, self).__init__(stream, **kwargs)

    def get_epoch_iterator(self, **kwargs):
        """Start a new epoch with an empty buffer.

        Without this reset, chunks left over from an epoch abandoned part-way
        would be emitted at the start of the next epoch.
        """
        self.tmp = None
        self.i = 0
        return super(ParallelSequences, self).get_epoch_iterator(**kwargs)

    def get_data(self, request=None):
        """Return ``(chunk,)``, the next ``(num_seqs, <=seq_div_size)`` slice.

        Raises ValueError if a request is passed (requests are unsupported).
        """
        if request is not None:
            raise ValueError('Unsupported: request')

        if self.tmp is None or self.i >= self.tmp.shape[1]:
            # Draw a fresh group. If the child epoch runs out part-way, the
            # StopIteration raised by next() propagates out of this call.
            self.tmp = numpy.concatenate(
                [next(self.child_epoch_iterator)[0][None, :]
                 for _ in xrange(self.num_seqs)],
                axis=0)
            self.i = 0

        ret = self.tmp[:, self.i:self.i + self.div_size]
        self.i += self.div_size

        return ret,

            

def setup_datastream(filename, num_seqs, seq_len, seq_div_size):
    """Build the chunked byte stream over ``filename``.

    The pipeline runs in four stages:

    1. Read ``num_seqs`` random ``seq_len``-byte blocks per epoch from the file.
    2. Convert each block to an int16 index array.
    3. Stack the blocks ``num_seqs`` at a time.
    4. Emit the stacked group in column chunks of width ``seq_div_size``.

    Returns the outermost ``ParallelSequences`` stream.
    """
    dataset = BinaryFileDataset(filename)
    scheme = RandomBlockIterator(dataset.num_examples(), seq_len, num_seqs)
    raw_stream = DataStream(dataset, iteration_scheme=scheme)
    index_stream = BytesToIndices(raw_stream)
    return ParallelSequences(index_stream, num_seqs, seq_div_size)

if __name__ == "__main__":
    # Test
    stream = setup_datastream("data/logcompil.txt", 2, 60, 20)
    it = stream.get_epoch_iterator()
    for d, in stream.get_epoch_iterator():
        print '--'
        for u in range(d.shape[0]):
            print ''.join(chr(i) for i in d[u])