diff options
Diffstat (limited to 'ext_test.py')
-rw-r--r-- | ext_test.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/ext_test.py b/ext_test.py index 6a3fa0a..3af637b 100644 --- a/ext_test.py +++ b/ext_test.py @@ -29,12 +29,21 @@ class RunOnTest(SimpleExtension): iter_no = repr(self.main_loop.log.status['iterations_done']) if 'valid_destination_cost' in self.main_loop.log.current_row: dvc = self.main_loop.log.current_row['valid_destination_cost'] - else: + elif 'valid_model_cost_cost' in self.main_loop.log.current_row: dvc = self.main_loop.log.current_row['valid_model_cost_cost'] + elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row: + dvc = self.main_loop.log.current_row['valid_model_valid_cost_cost'] + else: + raise RuntimeError("Unknown model type") + if 'valid_time_cost' in self.main_loop.log.current_row: tvc = self.main_loop.log.current_row['valid_time_cost'] - else: + elif 'valid_model_cost_cost' in self.main_loop.log.current_row: tvc = self.main_loop.log.current_row['valid_model_cost_cost'] + elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row: + tvc = self.main_loop.log.current_row['valid_model_valid_cost_cost'] + else: + raise RuntimeError("Unknown model type") if 'destination' in self.outputs: dest_outname = 'test-dest-%s-it%s-cost%.3f.csv' % (self.model_name, iter_no, dvc) |