diff options
Diffstat (limited to 'model/mlp.py')
-rw-r--r-- | model/mlp.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/model/mlp.py b/model/mlp.py index b1e9163..6abc86f 100644 --- a/model/mlp.py +++ b/model/mlp.py @@ -61,11 +61,11 @@ class Stream(object): else: stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream) + stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) stream = transformers.Select(stream, tuple(req_vars)) stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) @@ -77,17 +77,17 @@ class Stream(object): def valid(self, req_vars): stream = TaxiStream(self.config.valid_set, 'valid.hdf5') - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) stream = transformers.Select(stream, tuple(req_vars)) return Batch(stream, iteration_scheme=ConstantScheme(1000)) def test(self, req_vars): stream = TaxiStream('test') - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) - stream = transformers.TaxiRemoveTestOnlyClients(stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) + stream = transformers.taxi_remove_test_only_clients(stream) return Batch(stream, iteration_scheme=ConstantScheme(1)) |