aboutsummaryrefslogtreecommitdiff
path: root/model/stream.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-27 12:59:39 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-27 12:59:39 -0400
commit0021c3fb99d1cd3f8792a8cf5c35548815536428 (patch)
tree9edb909def7652579a1b6a40ecb311a61b40455b /model/stream.py
parent2a20bc827a8c1c9b6e74ef4e1234788207be45b8 (diff)
downloadtaxi-0021c3fb99d1cd3f8792a8cf5c35548815536428.tar.gz
taxi-0021c3fb99d1cd3f8792a8cf5c35548815536428.zip
Config files
Diffstat (limited to 'model/stream.py')
-rw-r--r--model/stream.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/model/stream.py b/model/stream.py
index 61ff1c3..d69b962 100644
--- a/model/stream.py
+++ b/model/stream.py
@@ -1,4 +1,4 @@
-from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing
+from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing, Filter
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme
@@ -31,6 +31,12 @@ class StreamRec(object):
elif not data.tvt:
stream = transformers.add_destination(stream)
+ if hasattr(self.config, 'train_max_len'):
+ idx = stream.sources.index('latitude')
+ def max_len_filter(x):
+ return len(x[idx]) <= self.config.train_max_len
+ stream = Filter(stream, max_len_filter)
+
stream = transformers.TaxiExcludeEmptyTrips(stream)
stream = transformers.taxi_add_datetime(stream)
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))