aboutsummaryrefslogtreecommitdiff
path: root/ext_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'ext_test.py')
-rw-r--r--ext_test.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/ext_test.py b/ext_test.py
index 6a3fa0a..3af637b 100644
--- a/ext_test.py
+++ b/ext_test.py
@@ -29,12 +29,21 @@ class RunOnTest(SimpleExtension):
iter_no = repr(self.main_loop.log.status['iterations_done'])
if 'valid_destination_cost' in self.main_loop.log.current_row:
dvc = self.main_loop.log.current_row['valid_destination_cost']
- else:
+ elif 'valid_model_cost_cost' in self.main_loop.log.current_row:
dvc = self.main_loop.log.current_row['valid_model_cost_cost']
+ elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row:
+ dvc = self.main_loop.log.current_row['valid_model_valid_cost_cost']
+ else:
+ raise RuntimeError("Unknown model type")
+
if 'valid_time_cost' in self.main_loop.log.current_row:
tvc = self.main_loop.log.current_row['valid_time_cost']
- else:
+ elif 'valid_model_cost_cost' in self.main_loop.log.current_row:
tvc = self.main_loop.log.current_row['valid_model_cost_cost']
+ elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row:
+ tvc = self.main_loop.log.current_row['valid_model_valid_cost_cost']
+ else:
+ raise RuntimeError("Unknown model type")
if 'destination' in self.outputs:
dest_outname = 'test-dest-%s-it%s-cost%.3f.csv' % (self.model_name, iter_no, dvc)