aboutsummaryrefslogtreecommitdiff
path: root/model
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-06-11 16:26:29 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-06-11 16:26:29 -0400
commitb44f7113c64568a20c9f93ca17577f17d7695dcb (patch)
tree82f82c2b8bd20e8d8015e86078ba05cf15c43d06 /model
parente6215fdd8b64c91210268cc8e929b19c22a53660 (diff)
downloadtaxi-b44f7113c64568a20c9f93ca17577f17d7695dcb.tar.gz
taxi-b44f7113c64568a20c9f93ca17577f17d7695dcb.zip
Use Mapping instead of extending Transformer
Diffstat (limited to 'model')
-rw-r--r--model/mlp.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/model/mlp.py b/model/mlp.py
index b1e9163..6abc86f 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -61,11 +61,11 @@ class Stream(object):
else:
stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
- stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream)
+ stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
- stream = transformers.TaxiAddDateTime(stream)
- stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
+ stream = transformers.taxi_add_datetime(stream)
+ stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
stream = transformers.Select(stream, tuple(req_vars))
stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
@@ -77,17 +77,17 @@ class Stream(object):
def valid(self, req_vars):
stream = TaxiStream(self.config.valid_set, 'valid.hdf5')
- stream = transformers.TaxiAddDateTime(stream)
- stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
+ stream = transformers.taxi_add_datetime(stream)
+ stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
stream = transformers.Select(stream, tuple(req_vars))
return Batch(stream, iteration_scheme=ConstantScheme(1000))
def test(self, req_vars):
stream = TaxiStream('test')
- stream = transformers.TaxiAddDateTime(stream)
- stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
- stream = transformers.TaxiRemoveTestOnlyClients(stream)
+ stream = transformers.taxi_add_datetime(stream)
+ stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
+ stream = transformers.taxi_remove_test_only_clients(stream)
return Batch(stream, iteration_scheme=ConstantScheme(1))