aboutsummaryrefslogblamecommitdiff
path: root/test.py
blob: bf39b15b94445f832a509f41abe499a40e567a09 (plain) (tree)















































                                                                                            
#!/usr/bin/env python

import sys
import os
import importlib
import csv

from blocks.dump import load_parameter_values
from blocks.model import Model


if __name__ == "__main__":
    # Usage: test.py <model_name>
    #
    # Imports config.<model_name>, restores trained parameters from
    # model_data/<model_name>/params.npz, runs the model's prediction graph
    # over the test stream and writes one CSV per requested output:
    #   'destination' -> output/test-dest-output-<model_name>.csv
    #   'duration'    -> output/test-time-output-<model_name>.csv
    if len(sys.argv) != 2:
        # Equivalent to `print >> sys.stderr, ...` but valid on Python 2 and 3.
        sys.stderr.write('Usage: %s config\n' % sys.argv[0])
        sys.exit(1)
    model_name = sys.argv[1]
    config = importlib.import_module('.%s' % model_name, 'config')
    model_config = config.Model(config)

    stream = config.Stream(config)
    inputs = stream.inputs()
    outputs = model_config.predict.outputs
    # 'trip_id' is requested so every prediction can be labelled in the CSV.
    req_vars_test = model_config.predict.inputs + ['trip_id']
    test_stream = stream.test(req_vars_test)

    model = Model(model_config.predict(**inputs))
    parameters = load_parameter_values(os.path.join('model_data', model_name, 'params.npz'))
    model.set_param_values(parameters)

    output_dir = 'output'

    def _open_output_csv(kind, header):
        """Open output/test-<kind>-output-<model_name>.csv and write header.

        Creates the output directory first, since open() fails if it is
        missing. Returns the (file, csv.writer) pair.
        """
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        path = os.path.join(output_dir, 'test-%s-output-%s.csv' % (kind, model_name))
        outfile = open(path, 'w')
        writer = csv.writer(outfile)
        writer.writerow(header)
        return outfile, writer

    # Positions of the requested outputs in the compiled function's results;
    # computed once instead of on every batch.
    dest_index = outputs.index('destination') if 'destination' in outputs else None
    time_index = outputs.index('duration') if 'duration' in outputs else None

    dest_outfile = None
    time_outfile = None
    try:
        if dest_index is not None:
            dest_outfile, dest_outcsv = _open_output_csv(
                'dest', ["TRIP_ID", "LATITUDE", "LONGITUDE"])
        if time_index is not None:
            time_outfile, time_outcsv = _open_output_csv(
                'time', ["TRIP_ID", "TRAVEL_TIME"])

        predict_fn = model.get_theano_function()
        for d in test_stream.get_epoch_iterator(as_dict=True):
            input_values = [d[k.name] for k in model.inputs]
            output_values = predict_fn(*input_values)
            trip_ids = d['trip_id']
            # NOTE(review): assumes the first axis of each output is the batch
            # axis aligned with d['trip_id'] (the original [0, 0] / [0]
            # indexing implies this). Writing every row instead of only row 0
            # keeps batch-size-1 output identical and no longer drops rows
            # when a batch holds more than one trip.
            if dest_index is not None:
                destination = output_values[dest_index]
                for trip_id, row in zip(trip_ids, destination):
                    dest_outcsv.writerow([trip_id, row[0], row[1]])
            if time_index is not None:
                duration = output_values[time_index]
                for trip_id, value in zip(trip_ids, duration):
                    time_outcsv.writerow([trip_id, value])
    finally:
        # Flush and release the CSV files even if prediction fails midway.
        if dest_outfile is not None:
            dest_outfile.close()
        if time_outfile is not None:
            time_outfile.close()