diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-07-23 21:20:32 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-07-23 21:20:32 -0400 |
commit | 13fc171f60ae1981c7ad4f2a302a8a85c29addc5 (patch) | |
tree | abc29e6a877a2f971b0be9715c112d8eee8b0eb4 /model/memory_network.py | |
parent | 8d31f9240056ec110cf63bde79d7661321d8ca7a (diff) | |
download | taxi-13fc171f60ae1981c7ad4f2a302a8a85c29addc5.tar.gz taxi-13fc171f60ae1981c7ad4f2a302a8a85c29addc5.zip |
Use new tvt dataset with option --tvt
Diffstat (limited to 'model/memory_network.py')
-rw-r--r-- | model/memory_network.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/model/memory_network.py b/model/memory_network.py index 14d1e07..e7ba51c 100644 --- a/model/memory_network.py +++ b/model/memory_network.py @@ -53,20 +53,20 @@ class StreamBase(object): @property def valid_dataset(self): - return TaxiDataset(self.config.valid_set, 'valid.hdf5') + return TaxiDataset(data.valid_set, data.valid_ds) @property def valid_trips_ids(self): - valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) + valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) return valid.get_data(None, slice(0, valid.num_examples))[0] @property def train_dataset(self): - return TaxiDataset('train') + return TaxiDataset('train', data.traintest_ds) @property def test_dataset(self): - return TaxiDataset('test') + return TaxiDataset('test', data.traintest_ds) class StreamSimple(StreamBase): @@ -96,7 +96,8 @@ class StreamSimple(StreamBase): prefix_stream = DataStream(self.train_dataset, iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples)) - prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) + if not data.tvt: + prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream) prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits) |