From e1673538607a7c8d784013b21b753f0c05c4cc34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Tue, 21 Jul 2015 18:26:43 -0400 Subject: Genericize RNNs --- ext_test.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'ext_test.py') diff --git a/ext_test.py b/ext_test.py index 6a3fa0a..3af637b 100644 --- a/ext_test.py +++ b/ext_test.py @@ -29,12 +29,21 @@ class RunOnTest(SimpleExtension): iter_no = repr(self.main_loop.log.status['iterations_done']) if 'valid_destination_cost' in self.main_loop.log.current_row: dvc = self.main_loop.log.current_row['valid_destination_cost'] - else: + elif 'valid_model_cost_cost' in self.main_loop.log.current_row: dvc = self.main_loop.log.current_row['valid_model_cost_cost'] + elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row: + dvc = self.main_loop.log.current_row['valid_model_valid_cost_cost'] + else: + raise RuntimeError("Unknown model type") + if 'valid_time_cost' in self.main_loop.log.current_row: tvc = self.main_loop.log.current_row['valid_time_cost'] - else: + elif 'valid_model_cost_cost' in self.main_loop.log.current_row: tvc = self.main_loop.log.current_row['valid_model_cost_cost'] + elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row: + tvc = self.main_loop.log.current_row['valid_model_valid_cost_cost'] + else: + raise RuntimeError("Unknown model type") if 'destination' in self.outputs: dest_outname = 'test-dest-%s-it%s-cost%.3f.csv' % (self.model_name, iter_no, dvc) -- cgit v1.2.3