aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ext_test.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/ext_test.py b/ext_test.py
index d121b61..092f39c 100644
--- a/ext_test.py
+++ b/ext_test.py
@@ -25,6 +25,9 @@ class RunOnTest(SimpleExtension):
self.function = cg.get_theano_function()
+ self.best_dvc = None
+ self.best_tvc = None
+
def do(self, which_callback, *args):
iter_no = self.main_loop.log.status['iterations_done']
if 'valid_destination_cost' in self.main_loop.log.current_row:
@@ -45,13 +48,19 @@ class RunOnTest(SimpleExtension):
else:
raise RuntimeError("Unknown model type")
- if 'destination' in self.outputs:
+ output_dvc = (self.best_dvc is None or dvc < self.best_dvc) and 'destination' in self.outputs
+ output_tvc = (self.best_tvc is None or tvc < self.best_tvc) and 'duration' in self.outputs
+
+ if not output_dvc and not output_tvc:
+ return
+
+ if output_dvc:
dest_outname = 'test-dest-%s-it%09d-cost%.3f.csv' % (self.model_name, iter_no, dvc)
dest_outfile = open(os.path.join('output', dest_outname), 'w')
dest_outcsv = csv.writer(dest_outfile)
dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
logger.info("Generating output for test set: %s" % dest_outname)
- if 'duration' in self.outputs:
+ if output_tvc:
time_outname = 'test-time-%s-it%09d-cost%.3f.csv' % (self.model_name, iter_no, tvc)
time_outfile = open(os.path.join('output', time_outname), 'w')
time_outcsv = csv.writer(time_outfile)
@@ -62,15 +71,15 @@ class RunOnTest(SimpleExtension):
input_values = [d[k.name] for k in self.inputs]
output_values = self.function(*input_values)
for i in range(d['trip_id'].shape[0]):
- if 'destination' in self.outputs:
+ if output_dvc:
destination = output_values[self.outputs.index('destination')]
dest_outcsv.writerow([d['trip_id'][i], destination[i, 0], destination[i, 1]])
- if 'duration' in self.outputs:
+ if output_tvc:
duration = output_values[self.outputs.index('duration')]
time_outcsv.writerow([d['trip_id'][i], int(round(duration[i]))])
- if 'destination' in self.outputs:
+ if output_dvc:
dest_outfile.close()
- if 'duration' in self.outputs:
+ if output_tvc:
time_outfile.close()