From 13fc171f60ae1981c7ad4f2a302a8a85c29addc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Thu, 23 Jul 2015 21:20:32 -0400 Subject: Use new tvt dataset with option --tvt --- model/mlp.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'model/mlp.py') diff --git a/model/mlp.py b/model/mlp.py index d24b2cc..0336be1 100644 --- a/model/mlp.py +++ b/model/mlp.py @@ -63,17 +63,18 @@ class Stream(object): self.config = config def train(self, req_vars): - valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) - valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] - - stream = TaxiDataset('train') + stream = TaxiDataset('train', data.traintest_ds) if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training: stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme()) else: stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + if not data.tvt: + valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) + valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] + stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) if hasattr(self.config, 'shuffle_batch_size'): @@ -92,7 +93,7 @@ class Stream(object): return stream def valid(self, req_vars): - stream = TaxiStream(self.config.valid_set, 'valid.hdf5') + stream = TaxiStream(data.valid_set, data.valid_ds) stream = transformers.taxi_add_datetime(stream) stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) @@ -100,7 +101,7 @@ class Stream(object): return Batch(stream, iteration_scheme=ConstantScheme(1000)) def test(self, req_vars): - stream = TaxiStream('test') + stream = TaxiStream('test', data.traintest_ds) stream = transformers.taxi_add_datetime(stream) stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) -- cgit v1.2.3