From 6d946f29f7548c75e97f30c4356dbac200ee6cce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Mon, 18 May 2015 16:22:00 -0400 Subject: Refactor models, clean the code and separate training from testing. --- test.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100755 test.py (limited to 'test.py') diff --git a/test.py b/test.py new file mode 100755 index 0000000..bf39b15 --- /dev/null +++ b/test.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python + +import sys +import os +import importlib +import csv + +from blocks.dump import load_parameter_values +from blocks.model import Model + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, 'Usage: %s config' % sys.argv[0] + sys.exit(1) + model_name = sys.argv[1] + config = importlib.import_module('.%s' % model_name, 'config') + model_config = config.Model(config) + + stream = config.Stream(config) + inputs = stream.inputs() + outputs = model_config.predict.outputs + req_vars_test = model_config.predict.inputs + ['trip_id'] + test_stream = stream.test(req_vars_test) + + model = Model(model_config.predict(**inputs)) + parameters = load_parameter_values(os.path.join('model_data', model_name, 'params.npz')) + model.set_param_values(parameters) + + if 'destination' in outputs: + dest_outfile = open("output/test-dest-output-%s.csv" % model_name, "w") + dest_outcsv = csv.writer(dest_outfile) + dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) + if 'duration' in outputs: + time_outfile = open("output/test-time-output-%s.csv" % model_name, "w") + time_outcsv = csv.writer(time_outfile) + time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"]) + + function = model.get_theano_function() + for d in test_stream.get_epoch_iterator(as_dict=True): + input_values = [d[k.name] for k in model.inputs] + output_values = function(*input_values) + if 'destination' in outputs: + destination = output_values[outputs.index('destination')] + dest_outcsv.writerow([d['trip_id'][0], destination[0, 0], destination[0, 1]]) + if 'duration' in outputs: + duration = output_values[outputs.index('duration')] + time_outcsv.writerow([d['trip_id'][0], duration[0]]) -- cgit v1.2.3