diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-05-18 16:22:00 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-05-18 16:22:00 -0400 |
commit | 6d946f29f7548c75e97f30c4356dbac200ee6cce (patch) | |
tree | 387e586c7ad0c1a0167d21451c9a8c877cf3ef0e /test.py | |
parent | 1e6d08b0c9ac5983691b182631c71e9d46ee71cc (diff) | |
download | taxi-6d946f29f7548c75e97f30c4356dbac200ee6cce.tar.gz taxi-6d946f29f7548c75e97f30c4356dbac200ee6cce.zip |
Refactor models, clean the code and separate training from testing.
Diffstat (limited to 'test.py')
-rwxr-xr-x | test.py | 48 |
1 files changed, 48 insertions, 0 deletions
#!/usr/bin/env python
"""Run a trained model over the test stream and dump its predictions to CSV.

Usage: test.py config

``config`` is the name of a module inside the ``config`` package. That module
must provide ``Model`` and ``Stream`` classes; the trained parameters are read
from ``model_data/<config>/params.npz``. Depending on which outputs the model's
``predict`` application declares, predictions are written to
``output/test-dest-output-<config>.csv`` and/or
``output/test-time-output-<config>.csv``.
"""

import sys
import os
import importlib
import csv

from blocks.dump import load_parameter_values
from blocks.model import Model


def _open_csv(path, header):
    """Open *path* for writing and return ``(file, writer)`` with *header* written.

    The parent directory is created when missing so that a fresh checkout does
    not fail on the first ``open``.
    """
    directory = os.path.dirname(path)
    if directory and not os.path.isdir(directory):
        os.makedirs(directory)
    outfile = open(path, "w")
    try:
        writer = csv.writer(outfile)
        writer.writerow(header)
    except Exception:
        # Do not leak the handle if writing the header fails.
        outfile.close()
        raise
    return outfile, writer


def main(argv):
    """Predict on the test set for the configuration named in ``argv[1]``.

    Returns the process exit status: 1 when the command line is malformed,
    0 once every test example has been written out.
    """
    if len(argv) != 2:
        sys.stderr.write('Usage: %s config\n' % argv[0])
        return 1
    model_name = argv[1]
    config = importlib.import_module('.%s' % model_name, 'config')
    model_config = config.Model(config)

    stream = config.Stream(config)
    inputs = stream.inputs()
    outputs = model_config.predict.outputs
    # 'trip_id' is not a model input but is needed to label each output row.
    req_vars_test = model_config.predict.inputs + ['trip_id']
    test_stream = stream.test(req_vars_test)

    model = Model(model_config.predict(**inputs))
    parameters = load_parameter_values(os.path.join('model_data', model_name, 'params.npz'))
    model.set_param_values(parameters)

    # Position of each prediction in the compiled function's result, resolved
    # once here instead of on every example. None means "not produced".
    dest_index = outputs.index('destination') if 'destination' in outputs else None
    time_index = outputs.index('duration') if 'duration' in outputs else None

    open_files = []
    try:
        if dest_index is not None:
            dest_outfile, dest_outcsv = _open_csv(
                "output/test-dest-output-%s.csv" % model_name,
                ["TRIP_ID", "LATITUDE", "LONGITUDE"])
            open_files.append(dest_outfile)
        if time_index is not None:
            time_outfile, time_outcsv = _open_csv(
                "output/test-time-output-%s.csv" % model_name,
                ["TRIP_ID", "TRAVEL_TIME"])
            open_files.append(time_outfile)

        function = model.get_theano_function()
        for d in test_stream.get_epoch_iterator(as_dict=True):
            # Feed the values in the order the compiled function expects.
            input_values = [d[k.name] for k in model.inputs]
            output_values = function(*input_values)
            # NOTE(review): only element 0 of each batch is written, so the
            # test stream presumably yields one trip at a time -- confirm.
            if dest_index is not None:
                destination = output_values[dest_index]
                dest_outcsv.writerow([d['trip_id'][0], destination[0, 0], destination[0, 1]])
            if time_index is not None:
                duration = output_values[time_index]
                time_outcsv.writerow([d['trip_id'][0], duration[0]])
    finally:
        # Always flush and release the CSV files, even if prediction fails
        # part-way through, so the rows written so far are not lost.
        for outfile in open_files:
            outfile.close()
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))