aboutsummaryrefslogtreecommitdiff
path: root/test.py
blob: 1a59d49ed857433d2507a34b18fac3283f84acc9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#!/usr/bin/env python

import sys
import os
import importlib
import csv

from blocks.dump import load_parameter_values
from blocks.model import Model


if __name__ == "__main__":
    # Run a trained model over the test stream and dump its predictions.
    #
    # Usage: <script> <model_name>
    #   - imports the module config.<model_name> and builds its Model/Stream;
    #   - restores parameters from model_data/<model_name>/params.npz;
    #   - writes output/test-dest-output-<model_name>.csv when
    #     model_config.predict.outputs contains 'destination', and
    #     output/test-time-output-<model_name>.csv when it contains 'duration'.
    if len(sys.argv) != 2:
        # sys.stderr.write works on Python 2 and 3; the former
        # `print >> sys.stderr, ...` statement raises TypeError on Python 3.
        sys.stderr.write('Usage: %s config\n' % sys.argv[0])
        sys.exit(1)
    model_name = sys.argv[1]
    config = importlib.import_module('.%s' % model_name, 'config')
    model_config = config.Model(config)

    stream = config.Stream(config)
    inputs = stream.inputs()
    outputs = model_config.predict.outputs
    # trip_id is requested on top of the model inputs so that every
    # prediction row can be keyed by its trip.
    req_vars_test = model_config.predict.inputs + ['trip_id']
    test_stream = stream.test(req_vars_test)

    model = Model(model_config.predict(**inputs))
    parameters = load_parameter_values(
        os.path.join('model_data', model_name, 'params.npz'))
    model.set_param_values(parameters)

    write_dest = 'destination' in outputs
    write_time = 'duration' in outputs
    # Output positions do not change between batches; look them up once.
    dest_idx = outputs.index('destination') if write_dest else None
    time_idx = outputs.index('duration') if write_time else None

    out_dir = 'output'
    # Create the output directory on demand instead of failing in open().
    if (write_dest or write_time) and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    dest_outfile = None
    time_outfile = None
    try:
        if write_dest:
            dest_outfile = open(os.path.join(out_dir, 'test-dest-output-%s.csv' % model_name), 'w')
            dest_outcsv = csv.writer(dest_outfile)
            dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
        if write_time:
            time_outfile = open(os.path.join(out_dir, 'test-time-output-%s.csv' % model_name), 'w')
            time_outcsv = csv.writer(time_outfile)
            time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"])

        function = model.get_theano_function()
        for d in test_stream.get_epoch_iterator(as_dict=True):
            # Feed the batch to the compiled function in model.inputs order.
            input_values = [d[k.name] for k in model.inputs]
            output_values = function(*input_values)
            # Only element 0 of each batch is written.
            # NOTE(review): this assumes the test stream yields batches of
            # size 1; larger batches would silently drop rows -- confirm.
            trip_id = d['trip_id'][0]
            if write_dest:
                destination = output_values[dest_idx]
                dest_outcsv.writerow([trip_id, destination[0, 0], destination[0, 1]])
            if write_time:
                duration = output_values[time_idx]
                time_outcsv.writerow([trip_id, int(round(duration[0]))])
    finally:
        # Close (and thereby flush) whatever was opened, even if compiling
        # or running the model raised part-way through.
        for outfile in (dest_outfile, time_outfile):
            if outfile is not None:
                outfile.close()