From b44f7113c64568a20c9f93ca17577f17d7695dcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Thu, 11 Jun 2015 16:26:29 -0400 Subject: Use Mapping instead of extending Transformer --- model/mlp.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'model') diff --git a/model/mlp.py b/model/mlp.py index b1e9163..6abc86f 100644 --- a/model/mlp.py +++ b/model/mlp.py @@ -61,11 +61,11 @@ class Stream(object): else: stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream) + stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) stream = transformers.Select(stream, tuple(req_vars)) stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) @@ -77,17 +77,17 @@ class Stream(object): def valid(self, req_vars): stream = TaxiStream(self.config.valid_set, 'valid.hdf5') - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) stream = transformers.Select(stream, tuple(req_vars)) return Batch(stream, iteration_scheme=ConstantScheme(1000)) def test(self, req_vars): stream = TaxiStream('test') - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) - stream = transformers.TaxiRemoveTestOnlyClients(stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) + stream = transformers.taxi_remove_test_only_clients(stream) return Batch(stream, iteration_scheme=ConstantScheme(1)) -- cgit v1.2.3