diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-09 16:27:06 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-09 16:27:06 -0400 |
commit | c5e1cd9c8c896096ad1630909a655b06eb398abb (patch) | |
tree | 836ef331e1e2ec96b6f634e53eb42e7e781319f0 /datastream.py | |
parent | 6a5ed2e43a5885eeb3c5e202ed5bb473f6065401 (diff) | |
download | text-rnn-c5e1cd9c8c896096ad1630909a655b06eb398abb.tar.gz text-rnn-c5e1cd9c8c896096ad1630909a655b06eb398abb.zip |
Now learning something
Diffstat (limited to 'datastream.py')
-rw-r--r-- | datastream.py | 42 |
1 files changed, 38 insertions, 4 deletions
diff --git a/datastream.py b/datastream.py index 5d9441f..8025945 100644 --- a/datastream.py +++ b/datastream.py @@ -61,16 +61,50 @@ class BytesToIndices(Transformer): data = next(self.child_epoch_iterator) return numpy.array([ord(i) for i in data[0]], dtype='int16'), -def setup_datastream(filename, seq_len, num_seqs_per_epoch=100): +class ParallelSequences(Transformer): + def __init__(self, stream, num_seqs, seq_div_size, **kwargs): + self.sources = ('bytes',) + + self.num_seqs = num_seqs + self.div_size = seq_div_size + + self.tmp = None + self.i = 0 + + super(ParallelSequences, self).__init__(stream, **kwargs) + + def get_data(self, request=None): + if request is not None: + raise ValueError('Unsupported: request') + + if self.tmp is None or self.i >= self.tmp.shape[1]: + self.tmp = numpy.concatenate([next(self.child_epoch_iterator)[0][None, :] + for _ in xrange(self.num_seqs)], + axis=0) + self.i = 0 + + ret = self.tmp[:, self.i:self.i + self.div_size] + self.i += self.div_size + + return ret, + + + +def setup_datastream(filename, num_seqs, seq_len, seq_div_size): ds = BinaryFileDataset(filename) - it = RandomBlockIterator(ds.num_examples(), seq_len, num_seqs_per_epoch) + it = RandomBlockIterator(ds.num_examples(), seq_len, num_seqs) stream = DataStream(ds, iteration_scheme=it) stream = BytesToIndices(stream) + stream = ParallelSequences(stream, num_seqs, seq_div_size) return stream if __name__ == "__main__": # Test - stream = setup_datastream("data/logcompil.txt", 100) - print(next(stream.get_epoch_iterator())) + stream = setup_datastream("data/logcompil.txt", 2, 60, 20) + it = stream.get_epoch_iterator() + for d, in stream.get_epoch_iterator(): + print '--' + for u in range(d.shape[0]): + print ''.join(chr(i) for i in d[u]) |