aboutsummaryrefslogtreecommitdiff
path: root/model
diff options
context:
space:
mode:
Diffstat (limited to 'model')
-rw-r--r--model/mlp.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/model/mlp.py b/model/mlp.py
index b1e9163..6abc86f 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -61,11 +61,11 @@ class Stream(object):
else:
stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
- stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream)
+ stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
- stream = transformers.TaxiAddDateTime(stream)
- stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
+ stream = transformers.taxi_add_datetime(stream)
+ stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
stream = transformers.Select(stream, tuple(req_vars))
stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
@@ -77,17 +77,17 @@ class Stream(object):
def valid(self, req_vars):
stream = TaxiStream(self.config.valid_set, 'valid.hdf5')
- stream = transformers.TaxiAddDateTime(stream)
- stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
+ stream = transformers.taxi_add_datetime(stream)
+ stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
stream = transformers.Select(stream, tuple(req_vars))
return Batch(stream, iteration_scheme=ConstantScheme(1000))
def test(self, req_vars):
stream = TaxiStream('test')
- stream = transformers.TaxiAddDateTime(stream)
- stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
- stream = transformers.TaxiRemoveTestOnlyClients(stream)
+ stream = transformers.taxi_add_datetime(stream)
+ stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
+ stream = transformers.taxi_remove_test_only_clients(stream)
return Batch(stream, iteration_scheme=ConstantScheme(1))