From 13fc171f60ae1981c7ad4f2a302a8a85c29addc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Thu, 23 Jul 2015 21:20:32 -0400 Subject: Use new tvt dataset with option --tvt --- model/memory_network.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'model/memory_network.py') diff --git a/model/memory_network.py b/model/memory_network.py index 14d1e07..e7ba51c 100644 --- a/model/memory_network.py +++ b/model/memory_network.py @@ -53,20 +53,20 @@ class StreamBase(object): @property def valid_dataset(self): - return TaxiDataset(self.config.valid_set, 'valid.hdf5') + return TaxiDataset(data.valid_set, data.valid_ds) @property def valid_trips_ids(self): - valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) + valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) return valid.get_data(None, slice(0, valid.num_examples))[0] @property def train_dataset(self): - return TaxiDataset('train') + return TaxiDataset('train', data.traintest_ds) @property def test_dataset(self): - return TaxiDataset('test') + return TaxiDataset('test', data.traintest_ds) class StreamSimple(StreamBase): @@ -96,7 +96,8 @@ class StreamSimple(StreamBase): prefix_stream = DataStream(self.train_dataset, iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples)) - prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) + if not data.tvt: + prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream) prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits) -- cgit v1.2.3