From 0021c3fb99d1cd3f8792a8cf5c35548815536428 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 27 Jul 2015 12:59:39 -0400 Subject: Config files --- model/stream.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'model/stream.py') diff --git a/model/stream.py b/model/stream.py index 61ff1c3..d69b962 100644 --- a/model/stream.py +++ b/model/stream.py @@ -1,4 +1,4 @@ -from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing +from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing, Filter from fuel.streams import DataStream from fuel.schemes import ConstantScheme, ShuffledExampleScheme @@ -31,6 +31,12 @@ class StreamRec(object): elif not data.tvt: stream = transformers.add_destination(stream) + if hasattr(self.config, 'train_max_len'): + idx = stream.sources.index('latitude') + def max_len_filter(x): + return len(x[idx]) <= self.config.train_max_len + stream = Filter(stream, max_len_filter) + stream = transformers.TaxiExcludeEmptyTrips(stream) stream = transformers.taxi_add_datetime(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) -- cgit v1.2.3