diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-07-21 18:26:43 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-07-21 18:27:55 -0400 |
commit | e1673538607a7c8d784013b21b753f0c05c4cc34 (patch) | |
tree | f42e316e0c5bf67e3c9953aad6ba8fe9656829f2 /ext_test.py | |
parent | 58dcf7b17e9db6af53808994a7d39a759fcc5028 (diff) | |
download | taxi-e1673538607a7c8d784013b21b753f0c05c4cc34.tar.gz taxi-e1673538607a7c8d784013b21b753f0c05c4cc34.zip |
Genericize RNNs
Diffstat (limited to 'ext_test.py')
-rw-r--r-- | ext_test.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/ext_test.py b/ext_test.py index 6a3fa0a..3af637b 100644 --- a/ext_test.py +++ b/ext_test.py @@ -29,12 +29,21 @@ class RunOnTest(SimpleExtension): iter_no = repr(self.main_loop.log.status['iterations_done']) if 'valid_destination_cost' in self.main_loop.log.current_row: dvc = self.main_loop.log.current_row['valid_destination_cost'] - else: + elif 'valid_model_cost_cost' in self.main_loop.log.current_row: dvc = self.main_loop.log.current_row['valid_model_cost_cost'] + elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row: + dvc = self.main_loop.log.current_row['valid_model_valid_cost_cost'] + else: + raise RuntimeError("Unknown model type") + if 'valid_time_cost' in self.main_loop.log.current_row: tvc = self.main_loop.log.current_row['valid_time_cost'] - else: + elif 'valid_model_cost_cost' in self.main_loop.log.current_row: tvc = self.main_loop.log.current_row['valid_model_cost_cost'] + elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row: + tvc = self.main_loop.log.current_row['valid_model_valid_cost_cost'] + else: + raise RuntimeError("Unknown model type") if 'destination' in self.outputs: dest_outname = 'test-dest-%s-it%s-cost%.3f.csv' % (self.model_name, iter_no, dvc) |