diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-06-11 16:26:29 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-06-11 16:26:29 -0400 |
commit | b44f7113c64568a20c9f93ca17577f17d7695dcb (patch) | |
tree | 82f82c2b8bd20e8d8015e86078ba05cf15c43d06 /model/mlp.py | |
parent | e6215fdd8b64c91210268cc8e929b19c22a53660 (diff) | |
download | taxi-b44f7113c64568a20c9f93ca17577f17d7695dcb.tar.gz taxi-b44f7113c64568a20c9f93ca17577f17d7695dcb.zip |
Use Mapping instead of extending Transformer
Diffstat (limited to 'model/mlp.py')
-rw-r--r-- | model/mlp.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/model/mlp.py b/model/mlp.py index b1e9163..6abc86f 100644 --- a/model/mlp.py +++ b/model/mlp.py @@ -61,11 +61,11 @@ class Stream(object): else: stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream) + stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) stream = transformers.Select(stream, tuple(req_vars)) stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) @@ -77,17 +77,17 @@ class Stream(object): def valid(self, req_vars): stream = TaxiStream(self.config.valid_set, 'valid.hdf5') - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) stream = transformers.Select(stream, tuple(req_vars)) return Batch(stream, iteration_scheme=ConstantScheme(1000)) def test(self, req_vars): stream = TaxiStream('test') - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) - stream = transformers.TaxiRemoveTestOnlyClients(stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) + stream = transformers.taxi_remove_test_only_clients(stream) return Batch(stream, iteration_scheme=ConstantScheme(1)) |