Diffstat (limited to 'datastream.py')
-rw-r--r--  datastream.py  84
1 file changed, 84 insertions, 0 deletions
diff --git a/datastream.py b/datastream.py
new file mode 100644
index 0000000..5d9441f
--- /dev/null
+++ b/datastream.py
@@ -0,0 +1,84 @@
+import logging
+import os
+import random
+
+import numpy
+
+from picklable_itertools import iter_
+
+from fuel.datasets import Dataset
+from fuel.streams import DataStream
+from fuel.schemes import IterationScheme
+from fuel.transformers import Transformer
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+
+class BinaryFileDataset(Dataset):
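+    """Dataset that serves raw byte slices read directly from a single binary file."""
+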
+    def __init__(self, filename, **kwargs):
+        self.provides_sources = ('bytes',)
+
+        self.f = open(filename, "rb")
+
+        super(BinaryFileDataset, self).__init__(**kwargs)
+
+    def get_data(self, state=None, request=None):
+        if request is None:
+            raise ValueError("Expected a request: begin, length")
+
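+        # A request is a (begin, length) pair: seek to the byte offset and read that many bytes.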
+        bg, ln = request
+        self.f.seek(bg)
+        return (self.f.read(ln),)
+
+    def num_examples(self):
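+        # Size of the file in bytes, i.e. the range of valid byte offsets.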
+        return os.fstat(self.f.fileno()).st_size
+
+class RandomBlockIterator(IterationScheme):
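+    """Iteration scheme yielding (offset, length) requests for randomly placed fixed-length blocks."""
+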
+    def __init__(self, item_range, seq_len, num_seqs_per_epoch, **kwargs):
+        self.seq_len = seq_len
+        self.num_seqs = num_seqs_per_epoch
+        self.item_range = item_range
+
+        super(RandomBlockIterator, self).__init__(**kwargs)
+
+    def get_request_iterator(self):
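+        # Draw num_seqs random (offset, seq_len) pairs, each fully contained in item_range bytes.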
+        requests = [(random.randrange(0, self.item_range - self.seq_len + 1), self.seq_len)
+                    for _ in xrange(self.num_seqs)]
+        return iter_(requests)
+
+class BytesToIndices(Transformer):
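+    """Transformer that turns a raw byte string into a numpy int16 array of byte values."""
+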
+    def __init__(self, stream, **kwargs):
+        self.sources = ('bytes',)
+        super(BytesToIndices, self).__init__(stream, **kwargs)
+
+    def get_data(self, request=None):
+        if request is not None:
+            raise ValueError('Unsupported: request')
+        data = next(self.child_epoch_iterator)
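+        # Map each byte to its integer value; the trailing comma returns a one-element tuple.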
+        return numpy.array([ord(i) for i in data[0]], dtype='int16'),
+
+def setup_datastream(filename, seq_len, num_seqs_per_epoch=100):
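+    """Chain the dataset, random block scheme, and byte-to-index transformer into one stream."""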
+    ds = BinaryFileDataset(filename)
+    it = RandomBlockIterator(ds.num_examples(), seq_len, num_seqs_per_epoch)
+    stream = DataStream(ds, iteration_scheme=it)
+    stream = BytesToIndices(stream)
+
+    return stream
+
+if __name__ == "__main__":
+    # Test
+    stream = setup_datastream("data/logcompil.txt", 100)
+    print(next(stream.get_epoch_iterator()))
+