summaryrefslogtreecommitdiff
path: root/datastream.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-06-09 16:27:06 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-06-09 16:27:06 -0400
commitc5e1cd9c8c896096ad1630909a655b06eb398abb (patch)
tree836ef331e1e2ec96b6f634e53eb42e7e781319f0 /datastream.py
parent6a5ed2e43a5885eeb3c5e202ed5bb473f6065401 (diff)
downloadtext-rnn-c5e1cd9c8c896096ad1630909a655b06eb398abb.tar.gz
text-rnn-c5e1cd9c8c896096ad1630909a655b06eb398abb.zip
Now learning something
Diffstat (limited to 'datastream.py')
-rw-r--r--datastream.py42
1 files changed, 38 insertions, 4 deletions
diff --git a/datastream.py b/datastream.py
index 5d9441f..8025945 100644
--- a/datastream.py
+++ b/datastream.py
@@ -61,16 +61,50 @@ class BytesToIndices(Transformer):
data = next(self.child_epoch_iterator)
return numpy.array([ord(i) for i in data[0]], dtype='int16'),
-def setup_datastream(filename, seq_len, num_seqs_per_epoch=100):
+class ParallelSequences(Transformer):
+ def __init__(self, stream, num_seqs, seq_div_size, **kwargs):
+ self.sources = ('bytes',)
+
+ self.num_seqs = num_seqs
+ self.div_size = seq_div_size
+
+ self.tmp = None
+ self.i = 0
+
+ super(ParallelSequences, self).__init__(stream, **kwargs)
+
+ def get_data(self, request=None):
+ if request is not None:
+ raise ValueError('Unsupported: request')
+
+ if self.tmp is None or self.i >= self.tmp.shape[1]:
+ self.tmp = numpy.concatenate([next(self.child_epoch_iterator)[0][None, :]
+ for _ in xrange(self.num_seqs)],
+ axis=0)
+ self.i = 0
+
+ ret = self.tmp[:, self.i:self.i + self.div_size]
+ self.i += self.div_size
+
+ return ret,
+
+
+
+def setup_datastream(filename, num_seqs, seq_len, seq_div_size):
ds = BinaryFileDataset(filename)
- it = RandomBlockIterator(ds.num_examples(), seq_len, num_seqs_per_epoch)
+ it = RandomBlockIterator(ds.num_examples(), seq_len, num_seqs)
stream = DataStream(ds, iteration_scheme=it)
stream = BytesToIndices(stream)
+ stream = ParallelSequences(stream, num_seqs, seq_div_size)
return stream
if __name__ == "__main__":
# Test
- stream = setup_datastream("data/logcompil.txt", 100)
- print(next(stream.get_epoch_iterator()))
+ stream = setup_datastream("data/logcompil.txt", 2, 60, 20)
+ it = stream.get_epoch_iterator()
+ for d, in stream.get_epoch_iterator():
+ print '--'
+ for u in range(d.shape[0]):
+ print ''.join(chr(i) for i in d[u])