diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-07-23 21:20:32 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-07-23 21:20:32 -0400 |
commit | 13fc171f60ae1981c7ad4f2a302a8a85c29addc5 (patch) | |
tree | abc29e6a877a2f971b0be9715c112d8eee8b0eb4 /model/rnn.py | |
parent | 8d31f9240056ec110cf63bde79d7661321d8ca7a (diff) | |
download | taxi-13fc171f60ae1981c7ad4f2a302a8a85c29addc5.tar.gz taxi-13fc171f60ae1981c7ad4f2a302a8a85c29addc5.zip |
Use new tvt dataset with option --tvt
Diffstat (limited to 'model/rnn.py')
-rw-r--r-- | model/rnn.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/model/rnn.py b/model/rnn.py index 70d054a..7bdba3b 100644 --- a/model/rnn.py +++ b/model/rnn.py @@ -146,15 +146,15 @@ class Stream(object): self.config = config def train(self, req_vars): - valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) - valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] - - stream = TaxiDataset('train') + stream = TaxiDataset('train', data.traintest_ds) stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + if not data.tvt: + valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) + valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] + stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + stream = transformers.TaxiExcludeEmptyTrips(stream) stream = transformers.taxi_add_datetime(stream) - stream = transformers.add_destination(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size) @@ -163,7 +163,7 @@ class Stream(object): return stream def valid(self, req_vars): - stream = TaxiStream(self.config.valid_set, 'valid.hdf5') + stream = TaxiStream(data.valid_set, data.valid_ds) stream = transformers.taxi_add_datetime(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) @@ -173,7 +173,7 @@ class Stream(object): return stream def test(self, req_vars): - stream = TaxiStream('test') + stream = TaxiStream('test', data.traintest_ds) stream = transformers.taxi_add_datetime(stream) stream = transformers.taxi_remove_test_only_clients(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) |