diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-08 14:59:44 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-08 15:00:50 -0400 |
commit | 20a1a01cef9d61ce9dd09995f2c811ab5aca2a9d (patch) | |
tree | c2638b5607820e596b8d7cd46e5137b41b25c61f /train.py | |
parent | 0ecac7973fd02f44af9c8bc5765f7c159c94b23a (diff) | |
download | taxi-20a1a01cef9d61ce9dd09995f2c811ab5aca2a9d.tar.gz taxi-20a1a01cef9d61ce9dd09995f2c811ab5aca2a9d.zip |
Add model for a network that predicts both time and destination.
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 40 |
1 files changed, 23 insertions, 17 deletions
@@ -67,7 +67,6 @@ def setup_test_stream(req_vars): test = transformers.TaxiAddDateTime(test) test = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, test) - test = transformers.TaxiAddLast(config.n_begin_end_pts, test) test = transformers.Select(test, tuple(req_vars)) test_stream = Batch(test, iteration_scheme=ConstantScheme(1000)) @@ -96,13 +95,17 @@ def main(): # step_rule=AdaDelta(decay_rate=0.5), step_rule=Momentum(learning_rate=config.learning_rate, momentum=config.momentum), params=params) + + plot_vars = [['valid_' + x.name for x in model.monitor]] + # plot_vars = ['valid_cost'] + print "Plot: ", plot_vars extensions=[TrainingDataMonitoring(model.monitor, prefix='train', every_n_batches=1000), DataStreamMonitoring(model.monitor, valid_stream, prefix='valid', every_n_batches=1000), Printing(every_n_batches=1000), - Plot(model_name, channels=[['valid_cost']], every_n_batches=1000), + Plot(model_name, channels=plot_vars, every_n_batches=1000), # Checkpoint('model.pkl', every_n_batches=100), Dump('model_data/' + model_name, every_n_batches=1000), LoadFromDump('model_data/' + model_name), @@ -120,21 +123,24 @@ def main(): # Produce an output on the test data test_stream = setup_test_stream(req_vars_test) - outfile = open("output/test-output-%s.csv" % model_name, "w") - outcsv = csv.writer(outfile) - if model.pred_vars == ['travel_time']: - outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"]) - for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): - time = out['outputs'] - for i, trip in enumerate(out['trip_id']): - outcsv.writerow([trip, int(time[i])]) - else: - outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) - for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): - dest = out['outputs'] - for i, trip in enumerate(out['trip_id']): - outcsv.writerow([trip, repr(dest[i, 0]), repr(dest[i, 1])]) - outfile.close() + if 'destination_longitude' in model.pred_vars: + dest_outfile = open("output/test-dest-output-%s.csv" % model_name, "w") + dest_outcsv = csv.writer(dest_outfile) + dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) + if 'travel_time' in model.pred_vars: + time_outfile = open("output/test-time-output-%s.csv" % model_name, "w") + time_outcsv = csv.writer(time_outfile) + time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"]) + + for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): + outputs = out['outputs'] + for i, trip in enumerate(out['trip_id']): + if model.pred_vars == ['travel_time']: + time_outcsv.writerow([trip, int(outputs[i])]) + else: + dest_outcsv.writerow([trip, repr(outputs[i, 0]), repr(outputs[i, 1])]) + if 'travel_time' in model.pred_vars: + time_outcsv.writerow([trip, int(outputs[i, 2])]) if __name__ == "__main__": |