aboutsummaryrefslogtreecommitdiff
path: root/model/memory_network.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-07-23 21:20:32 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-07-23 21:20:32 -0400
commit13fc171f60ae1981c7ad4f2a302a8a85c29addc5 (patch)
treeabc29e6a877a2f971b0be9715c112d8eee8b0eb4 /model/memory_network.py
parent8d31f9240056ec110cf63bde79d7661321d8ca7a (diff)
downloadtaxi-13fc171f60ae1981c7ad4f2a302a8a85c29addc5.tar.gz
taxi-13fc171f60ae1981c7ad4f2a302a8a85c29addc5.zip
Use new tvt dataset with option --tvt
Diffstat (limited to 'model/memory_network.py')
-rw-r--r--model/memory_network.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/model/memory_network.py b/model/memory_network.py
index 14d1e07..e7ba51c 100644
--- a/model/memory_network.py
+++ b/model/memory_network.py
@@ -53,20 +53,20 @@ class StreamBase(object):
@property
def valid_dataset(self):
- return TaxiDataset(self.config.valid_set, 'valid.hdf5')
+ return TaxiDataset(data.valid_set, data.valid_ds)
@property
def valid_trips_ids(self):
- valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
+ valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
return valid.get_data(None, slice(0, valid.num_examples))[0]
@property
def train_dataset(self):
- return TaxiDataset('train')
+ return TaxiDataset('train', data.traintest_ds)
@property
def test_dataset(self):
- return TaxiDataset('test')
+ return TaxiDataset('test', data.traintest_ds)
class StreamSimple(StreamBase):
@@ -96,7 +96,8 @@ class StreamSimple(StreamBase):
prefix_stream = DataStream(self.train_dataset,
iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
- prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
+ if not data.tvt:
+ prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
prefix_stream = transformers.TaxiGenerateSplits(prefix_stream,
max_splits=self.config.max_splits)