From 20a1a01cef9d61ce9dd09995f2c811ab5aca2a9d Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 8 May 2015 14:59:44 -0400 Subject: Add model for a network that predicts both time and destination. --- train.py | 40 +++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 17 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index 8d9f4ad..a70bb90 100755 --- a/train.py +++ b/train.py @@ -67,7 +67,6 @@ def setup_test_stream(req_vars): test = transformers.TaxiAddDateTime(test) test = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, test) - test = transformers.TaxiAddLast(config.n_begin_end_pts, test) test = transformers.Select(test, tuple(req_vars)) test_stream = Batch(test, iteration_scheme=ConstantScheme(1000)) @@ -96,13 +95,17 @@ def main(): # step_rule=AdaDelta(decay_rate=0.5), step_rule=Momentum(learning_rate=config.learning_rate, momentum=config.momentum), params=params) + + plot_vars = [['valid_' + x.name for x in model.monitor]] + # plot_vars = ['valid_cost'] + print "Plot: ", plot_vars extensions=[TrainingDataMonitoring(model.monitor, prefix='train', every_n_batches=1000), DataStreamMonitoring(model.monitor, valid_stream, prefix='valid', every_n_batches=1000), Printing(every_n_batches=1000), - Plot(model_name, channels=[['valid_cost']], every_n_batches=1000), + Plot(model_name, channels=plot_vars, every_n_batches=1000), # Checkpoint('model.pkl', every_n_batches=100), Dump('model_data/' + model_name, every_n_batches=1000), LoadFromDump('model_data/' + model_name), @@ -120,21 +123,24 @@ def main(): # Produce an output on the test data test_stream = setup_test_stream(req_vars_test) - outfile = open("output/test-output-%s.csv" % model_name, "w") - outcsv = csv.writer(outfile) - if model.pred_vars == ['travel_time']: - outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"]) - for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): - time = out['outputs'] - for i, trip in enumerate(out['trip_id']): - outcsv.writerow([trip, int(time[i])]) - else: - outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) - for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): - dest = out['outputs'] - for i, trip in enumerate(out['trip_id']): - outcsv.writerow([trip, repr(dest[i, 0]), repr(dest[i, 1])]) - outfile.close() + if 'destination_longitude' in model.pred_vars: + dest_outfile = open("output/test-dest-output-%s.csv" % model_name, "w") + dest_outcsv = csv.writer(dest_outfile) + dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) + if 'travel_time' in model.pred_vars: + time_outfile = open("output/test-time-output-%s.csv" % model_name, "w") + time_outcsv = csv.writer(time_outfile) + time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"]) + + for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']): + outputs = out['outputs'] + for i, trip in enumerate(out['trip_id']): + if model.pred_vars == ['travel_time']: + time_outcsv.writerow([trip, int(outputs[i])]) + else: + dest_outcsv.writerow([trip, repr(outputs[i, 0]), repr(outputs[i, 1])]) + if 'travel_time' in model.pred_vars: + time_outcsv.writerow([trip, int(outputs[i, 2])]) if __name__ == "__main__": -- cgit v1.2.3