import logging
import random
import numpy
import cPickle
from picklable_itertools import iter_
from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.schemes import IterationScheme
from fuel.transformers import Transformer
import sys
import os
logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)
class BinaryFileDataset(Dataset):
    """Fuel dataset exposing the raw contents of a binary file.

    The single source is ``'bytes'``. A request is a ``(begin, length)``
    pair, and the corresponding example is the byte string read from the
    file starting at offset ``begin``.

    NOTE(review): the file handle opened in ``__init__`` is kept for the
    lifetime of the object and is never explicitly closed by this class.
    """

    def __init__(self, filename, **kwargs):
        # The source names are assigned before the base initializer runs
        # (presumably because the base class inspects them -- confirm).
        self.provides_sources = ('bytes',)
        self.f = open(filename, "rb")
        super(BinaryFileDataset, self).__init__(**kwargs)

    def get_data(self, state=None, request=None):
        """Read one block of bytes from the file.

        ``state`` is ignored; the shared handle ``self.f`` is used instead.

        :param request: ``(begin, length)`` -- byte offset and byte count.
        :returns: a 1-tuple holding the bytes read (fewer than ``length``
            if the end of the file is reached).
        :raises ValueError: if no request is supplied.
        """
        if request is None:
            raise ValueError("Expected a request: begin, length")
        begin, length = request
        handle = self.f
        handle.seek(begin)
        chunk = handle.read(length)
        return (chunk,)

    def num_examples(self):
        """Return the current size of the open file, in bytes."""
        file_stats = os.fstat(self.f.fileno())
        return file_stats.st_size
class RandomBlockIterator(IterationScheme):
    """Iteration scheme producing random ``(begin, length)`` block requests.

    Every epoch yields ``num_seqs_per_epoch`` requests ``(begin, seq_len)``.
    Each ``begin`` is drawn independently with :func:`random.randrange` from
    ``0 .. item_range - seq_len`` (inclusive), so every block lies entirely
    inside ``[0, item_range)``. Blocks may overlap.

    Parameters
    ----------
    item_range : int
        Size of the range blocks are drawn from (e.g. a file size in bytes).
    seq_len : int
        Length of every requested block; must satisfy
        ``1 <= seq_len <= item_range``.
    num_seqs_per_epoch : int
        Number of requests generated per epoch.

    Raises
    ------
    ValueError
        If ``seq_len`` is outside ``[1, item_range]``.
    """
    requests_examples = True

    def __init__(self, item_range, seq_len, num_seqs_per_epoch, **kwargs):
        # Fail fast with a readable message. Otherwise seq_len > item_range
        # only surfaces later as randrange's opaque "empty range" error,
        # and seq_len < 1 yields meaningless zero/negative-length requests.
        if not 1 <= seq_len <= item_range:
            raise ValueError(
                "seq_len must be in [1, item_range=%r], got %r"
                % (item_range, seq_len))
        self.seq_len = seq_len
        self.num_seqs = num_seqs_per_epoch
        self.item_range = item_range
        super(RandomBlockIterator, self).__init__(**kwargs)

    def get_request_iterator(self):
        """Draw this epoch's requests eagerly and return an iterator over them.

        The list is wrapped with ``picklable_itertools.iter_``.
        """
        max_begin = self.item_range - self.seq_len
        requests = [(random.randrange(0, max_begin + 1), self.seq_len)
                    for _ in xrange(self.num_seqs)]
        return iter_(requests)
class BytesToIndices(Transformer):
    """Transformer turning a raw byte string into an array of byte values.

    Each example pulled from the child stream is indexed with ``[0]``. That
    element, a byte string, is converted to a 1-D ``int16`` array holding
    the value (0-255) of every byte. The output source is ``'bytes'``.
    """

    def __init__(self, stream, **kwargs):
        self.sources = ('bytes',)
        super(BytesToIndices, self).__init__(stream, **kwargs)

    def get_data(self, request=None):
        """Return the next example as a 1-tuple ``(byte_values,)``.

        :raises ValueError: if a ``request`` is given; only sequential
            iteration over the child stream is supported.
        """
        if request is not None:
            raise ValueError('Unsupported: request')
        raw = next(self.child_epoch_iterator)[0]
        # bytearray exposes each byte as an int on both Python 2 (str) and
        # Python 3 (bytes), where ord() on a bytes item would fail. numpy
        # converts the whole buffer at once instead of a per-byte loop.
        return numpy.array(bytearray(raw), dtype='int16'),
class ParallelSequences(Transformer):
    """Transformer that groups sequences and emits them in column slices.

    It pulls ``num_seqs`` consecutive examples from the child stream. From
    each it takes element ``[0]`` (an array) and stacks these along a new
    leading axis, giving an array of shape ``(num_seqs,) + element.shape``.
    It then returns successive slices ``[:, i:i + seq_div_size]`` of that
    array. The last slice is narrower when the sequence length is not a
    multiple of ``seq_div_size``. When a stacked group is exhausted, the
    next ``num_seqs`` examples are fetched.

    Parameters
    ----------
    stream : data stream
        Child stream whose examples' first element are arrays of one
        common shape (required for stacking).
    num_seqs : int
        Number of sequences stacked together; must be >= 1.
    seq_div_size : int
        Width of each emitted slice along axis 1; must be >= 1.

    Raises
    ------
    ValueError
        If ``num_seqs`` or ``seq_div_size`` is smaller than 1.
    """

    def __init__(self, stream, num_seqs, seq_div_size, **kwargs):
        # num_seqs < 1 would make the stacking step fail with an opaque
        # numpy error; seq_div_size < 1 would never advance the slice
        # cursor and return empty slices forever.
        if num_seqs < 1:
            raise ValueError("num_seqs must be >= 1, got %r" % (num_seqs,))
        if seq_div_size < 1:
            raise ValueError(
                "seq_div_size must be >= 1, got %r" % (seq_div_size,))
        self.sources = ('bytes',)
        self.num_seqs = num_seqs
        self.div_size = seq_div_size
        self.tmp = None  # current stacked group, or None before the first
        self.i = 0       # column offset of the next slice within self.tmp
        super(ParallelSequences, self).__init__(stream, **kwargs)

    def get_epoch_iterator(self, **kwargs):
        """Reset the slicing state, then defer to the base implementation.

        Without this reset, a group left partially consumed when an epoch
        is abandoned would be served at the start of the next epoch.
        """
        self.tmp = None
        self.i = 0
        return super(ParallelSequences, self).get_epoch_iterator(**kwargs)

    def get_data(self, request=None):
        """Return the next slice of the current group as a 1-tuple.

        :raises ValueError: if a ``request`` is given.
        """
        if request is not None:
            raise ValueError('Unsupported: request')
        if self.tmp is None or self.i >= self.tmp.shape[1]:
            # Current group used up: stack the next num_seqs examples along
            # a new leading axis.
            self.tmp = numpy.concatenate(
                [next(self.child_epoch_iterator)[0][None, :]
                 for _ in xrange(self.num_seqs)],
                axis=0)
            self.i = 0
        chunk = self.tmp[:, self.i:self.i + self.div_size]
        self.i += self.div_size
        return chunk,
def setup_datastream(filename, num_seqs, seq_len, seq_div_size):
    """Build the stream of parallel byte-value slices for ``filename``.

    The pipeline is::

        BinaryFileDataset -> DataStream(RandomBlockIterator)
            -> BytesToIndices -> ParallelSequences

    ``num_seqs`` is used both as the number of random blocks drawn per epoch
    and as the number of sequences stacked together. Each epoch therefore
    produces a single group of ``num_seqs`` blocks of length ``seq_len``,
    emitted in slices of width ``seq_div_size``.
    """
    dataset = BinaryFileDataset(filename)
    scheme = RandomBlockIterator(dataset.num_examples(), seq_len, num_seqs)
    raw_blocks = DataStream(dataset, iteration_scheme=scheme)
    byte_values = BytesToIndices(raw_blocks)
    return ParallelSequences(byte_values, num_seqs, seq_div_size)
if __name__ == "__main__":
    # Smoke test: print one epoch of slices as text, one row per sequence.
    # The data file may be passed as the first command-line argument; the
    # original hard-coded path remains the default.
    filename = sys.argv[1] if len(sys.argv) > 1 else "data/logcompil.txt"
    stream = setup_datastream(filename, 2, 60, 20)
    for d, in stream.get_epoch_iterator():
        print('--')
        for row in d:
            print(''.join(chr(i) for i in row))