about | summary | refs | log | tree | commit | diff
path: root/model/stream.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/stream.py')
-rw-r--r-- model/stream.py 10
1 files changed, 5 insertions, 5 deletions
diff --git a/model/stream.py b/model/stream.py
index 88b1d7f..61ff1c3 100644
--- a/model/stream.py
+++ b/model/stream.py
@@ -33,17 +33,13 @@ class StreamRec(object):
stream = transformers.TaxiExcludeEmptyTrips(stream)
stream = transformers.taxi_add_datetime(stream)
-
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
stream = transformers.balanced_batch(stream, key='latitude',
batch_size=self.config.batch_size,
batch_sort_size=self.config.batch_sort_size)
-
stream = Padding(stream, mask_sources=['latitude', 'longitude'])
-
stream = transformers.Select(stream, req_vars)
-
stream = MultiProcessing(stream)
return stream
@@ -54,9 +50,13 @@ class StreamRec(object):
stream = transformers.taxi_add_datetime(stream)
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
- stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
+ stream = transformers.balanced_batch(stream, key='latitude',
+ batch_size=self.config.batch_size,
+ batch_sort_size=self.config.batch_sort_size)
stream = Padding(stream, mask_sources=['latitude', 'longitude'])
stream = transformers.Select(stream, req_vars)
+ stream = MultiProcessing(stream)
+
return stream
def test(self, req_vars):