diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-25 17:31:17 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-25 17:31:17 -0400 |
commit | 2a20bc827a8c1c9b6e74ef4e1234788207be45b8 (patch) | |
tree | 2051f9f81324bbf202a3bc3fa8c59a68288d2566 /model/stream.py | |
parent | d3dbbe68642320b9a225b1d3515f0181916df8ad (diff) | |
download | taxi-2a20bc827a8c1c9b6e74ef4e1234788207be45b8.tar.gz taxi-2a20bc827a8c1c9b6e74ef4e1234788207be45b8.zip |
Add batch sorting
Diffstat (limited to 'model/stream.py')
-rw-r--r-- | model/stream.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/model/stream.py b/model/stream.py index 88b1d7f..61ff1c3 100644 --- a/model/stream.py +++ b/model/stream.py @@ -33,17 +33,13 @@ class StreamRec(object): stream = transformers.TaxiExcludeEmptyTrips(stream) stream = transformers.taxi_add_datetime(stream) - stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size) - stream = Padding(stream, mask_sources=['latitude', 'longitude']) - stream = transformers.Select(stream, req_vars) - stream = MultiProcessing(stream) return stream @@ -54,9 +50,13 @@ class StreamRec(object): stream = transformers.taxi_add_datetime(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) + stream = transformers.balanced_batch(stream, key='latitude', + batch_size=self.config.batch_size, + batch_sort_size=self.config.batch_sort_size) stream = Padding(stream, mask_sources=['latitude', 'longitude']) stream = transformers.Select(stream, req_vars) + stream = MultiProcessing(stream) + return stream def test(self, req_vars): |