From 8d31f9240056ec110cf63bde79d7661321d8ca7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Thu, 23 Jul 2015 21:19:53 -0400 Subject: Run test only when validation score improve --- ext_test.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) (limited to 'ext_test.py') diff --git a/ext_test.py b/ext_test.py index d121b61..092f39c 100644 --- a/ext_test.py +++ b/ext_test.py @@ -25,6 +25,9 @@ class RunOnTest(SimpleExtension): self.function = cg.get_theano_function() + self.best_dvc = None + self.best_tvc = None + def do(self, which_callback, *args): iter_no = self.main_loop.log.status['iterations_done'] if 'valid_destination_cost' in self.main_loop.log.current_row: @@ -45,13 +48,19 @@ class RunOnTest(SimpleExtension): else: raise RuntimeError("Unknown model type") - if 'destination' in self.outputs: + output_dvc = (self.best_dvc is None or dvc < self.best_dvc) and 'destination' in self.outputs + output_tvc = (self.best_tvc is None or tvc < self.best_tvc) and 'duration' in self.outputs + + if not output_dvc and not output_tvc: + return + + if output_dvc: dest_outname = 'test-dest-%s-it%09d-cost%.3f.csv' % (self.model_name, iter_no, dvc) dest_outfile = open(os.path.join('output', dest_outname), 'w') dest_outcsv = csv.writer(dest_outfile) dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) logger.info("Generating output for test set: %s" % dest_outname) - if 'duration' in self.outputs: + if output_tvc: time_outname = 'test-time-%s-it%09d-cost%.3f.csv' % (self.model_name, iter_no, tvc) time_outfile = open(os.path.join('output', time_outname), 'w') time_outcsv = csv.writer(time_outfile) @@ -62,15 +71,15 @@ class RunOnTest(SimpleExtension): input_values = [d[k.name] for k in self.inputs] output_values = self.function(*input_values) for i in range(d['trip_id'].shape[0]): - if 'destination' in self.outputs: + if output_dvc: destination = output_values[self.outputs.index('destination')] dest_outcsv.writerow([d['trip_id'][i], destination[i, 0], destination[i, 1]]) - if 'duration' in self.outputs: + if output_tvc: duration = output_values[self.outputs.index('duration')] time_outcsv.writerow([d['trip_id'][i], int(round(duration[i]))]) - if 'destination' in self.outputs: + if output_dvc: dest_outfile.close() - if 'duration' in self.outputs: + if output_tvc: time_outfile.close() -- cgit v1.2.3