aboutsummaryrefslogtreecommitdiff
path: root/ext_test.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-07-21 18:26:43 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-07-21 18:27:55 -0400
commite1673538607a7c8d784013b21b753f0c05c4cc34 (patch)
treef42e316e0c5bf67e3c9953aad6ba8fe9656829f2 /ext_test.py
parent58dcf7b17e9db6af53808994a7d39a759fcc5028 (diff)
downloadtaxi-e1673538607a7c8d784013b21b753f0c05c4cc34.tar.gz
taxi-e1673538607a7c8d784013b21b753f0c05c4cc34.zip
Genericize RNNs
Diffstat (limited to 'ext_test.py')
-rw-r--r--ext_test.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/ext_test.py b/ext_test.py
index 6a3fa0a..3af637b 100644
--- a/ext_test.py
+++ b/ext_test.py
@@ -29,12 +29,21 @@ class RunOnTest(SimpleExtension):
iter_no = repr(self.main_loop.log.status['iterations_done'])
if 'valid_destination_cost' in self.main_loop.log.current_row:
dvc = self.main_loop.log.current_row['valid_destination_cost']
- else:
+ elif 'valid_model_cost_cost' in self.main_loop.log.current_row:
dvc = self.main_loop.log.current_row['valid_model_cost_cost']
+ elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row:
+ dvc = self.main_loop.log.current_row['valid_model_valid_cost_cost']
+ else:
+ raise RuntimeError("Unknown model type")
+
if 'valid_time_cost' in self.main_loop.log.current_row:
tvc = self.main_loop.log.current_row['valid_time_cost']
- else:
+ elif 'valid_model_cost_cost' in self.main_loop.log.current_row:
tvc = self.main_loop.log.current_row['valid_model_cost_cost']
+ elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row:
+ tvc = self.main_loop.log.current_row['valid_model_valid_cost_cost']
+ else:
+ raise RuntimeError("Unknown model type")
if 'destination' in self.outputs:
dest_outname = 'test-dest-%s-it%s-cost%.3f.csv' % (self.model_name, iter_no, dvc)