diff options
Diffstat (limited to 'model/memory_network.py')
-rw-r--r-- | model/memory_network.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/model/memory_network.py b/model/memory_network.py index 14d1e07..e7ba51c 100644 --- a/model/memory_network.py +++ b/model/memory_network.py @@ -53,20 +53,20 @@ class StreamBase(object): @property def valid_dataset(self): - return TaxiDataset(self.config.valid_set, 'valid.hdf5') + return TaxiDataset(data.valid_set, data.valid_ds) @property def valid_trips_ids(self): - valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) + valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) return valid.get_data(None, slice(0, valid.num_examples))[0] @property def train_dataset(self): - return TaxiDataset('train') + return TaxiDataset('train', data.traintest_ds) @property def test_dataset(self): - return TaxiDataset('test') + return TaxiDataset('test', data.traintest_ds) class StreamSimple(StreamBase): @@ -96,7 +96,8 @@ class StreamSimple(StreamBase): prefix_stream = DataStream(self.train_dataset, iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples)) - prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) + if not data.tvt: + prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream) prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits) |