diff options
Diffstat (limited to 'test.py')
-rwxr-xr-x | test.py | 54 |
1 files changed, 0 insertions, 54 deletions
diff --git a/test.py b/test.py deleted file mode 100755 index 4925b27..0000000 --- a/test.py +++ /dev/null @@ -1,54 +0,0 @@ -#!/usr/bin/env python - -import cPickle -import sys -import os -import importlib -import csv - -from blocks.model import Model - - -if __name__ == "__main__": - if len(sys.argv) != 2: - print >> sys.stderr, 'Usage: %s config' % sys.argv[0] - sys.exit(1) - model_name = sys.argv[1] - config = importlib.import_module('.%s' % model_name, 'config') - model_config = config.Model(config) - - stream = config.Stream(config) - inputs = stream.inputs() - outputs = model_config.predict.outputs - req_vars_test = model_config.predict.inputs + ['trip_id'] - test_stream = stream.test(req_vars_test) - - model = Model(model_config.predict(**inputs)) - with open(os.path.join('model_data', "{}.pkl".format(model_name))) as f: - parameters = cPickle.load(f) - model.set_param_values(parameters) - - if 'destination' in outputs: - dest_outfile = open(os.path.join('output', 'test-dest-output-%s.csv' % model_name), 'w') - dest_outcsv = csv.writer(dest_outfile) - dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) - if 'duration' in outputs: - time_outfile = open(os.path.join('output', 'test-time-output-%s.csv' % model_name), 'w') - time_outcsv = csv.writer(time_outfile) - time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"]) - - function = model.get_theano_function() - for d in test_stream.get_epoch_iterator(as_dict=True): - input_values = [d[k.name] for k in model.inputs] - output_values = function(*input_values) - if 'destination' in outputs: - destination = output_values[outputs.index('destination')] - dest_outcsv.writerow([d['trip_id'][0], destination[0, 0], destination[0, 1]]) - if 'duration' in outputs: - duration = output_values[outputs.index('duration')] - time_outcsv.writerow([d['trip_id'][0], int(round(duration[0]))]) - - if 'destination' in outputs: - dest_outfile.close() - if 'duration' in outputs: - time_outfile.close() |