aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
author Alex Auvolat <alex.auvolat@ens.fr> 2015-05-08 14:59:44 -0400
committer Alex Auvolat <alex.auvolat@ens.fr> 2015-05-08 15:00:50 -0400
commit 20a1a01cef9d61ce9dd09995f2c811ab5aca2a9d (patch)
tree c2638b5607820e596b8d7cd46e5137b41b25c61f /train.py
parent 0ecac7973fd02f44af9c8bc5765f7c159c94b23a (diff)
download taxi-20a1a01cef9d61ce9dd09995f2c811ab5aca2a9d.tar.gz
taxi-20a1a01cef9d61ce9dd09995f2c811ab5aca2a9d.zip
Add model for a network that predicts both time and destination.
Diffstat (limited to 'train.py')
-rwxr-xr-x train.py 40
1 files changed, 23 insertions, 17 deletions
diff --git a/train.py b/train.py
index 8d9f4ad..a70bb90 100755
--- a/train.py
+++ b/train.py
@@ -67,7 +67,6 @@ def setup_test_stream(req_vars):
test = transformers.TaxiAddDateTime(test)
test = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, test)
- test = transformers.TaxiAddLast(config.n_begin_end_pts, test)
test = transformers.Select(test, tuple(req_vars))
test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))
@@ -96,13 +95,17 @@ def main():
# step_rule=AdaDelta(decay_rate=0.5),
step_rule=Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
params=params)
+
+ plot_vars = [['valid_' + x.name for x in model.monitor]]
+ # plot_vars = ['valid_cost']
+ print "Plot: ", plot_vars
extensions=[TrainingDataMonitoring(model.monitor, prefix='train', every_n_batches=1000),
DataStreamMonitoring(model.monitor, valid_stream,
prefix='valid',
every_n_batches=1000),
Printing(every_n_batches=1000),
- Plot(model_name, channels=[['valid_cost']], every_n_batches=1000),
+ Plot(model_name, channels=plot_vars, every_n_batches=1000),
# Checkpoint('model.pkl', every_n_batches=100),
Dump('model_data/' + model_name, every_n_batches=1000),
LoadFromDump('model_data/' + model_name),
@@ -120,21 +123,24 @@ def main():
# Produce an output on the test data
test_stream = setup_test_stream(req_vars_test)
- outfile = open("output/test-output-%s.csv" % model_name, "w")
- outcsv = csv.writer(outfile)
- if model.pred_vars == ['travel_time']:
- outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"])
- for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']):
- time = out['outputs']
- for i, trip in enumerate(out['trip_id']):
- outcsv.writerow([trip, int(time[i])])
- else:
- outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
- for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']):
- dest = out['outputs']
- for i, trip in enumerate(out['trip_id']):
- outcsv.writerow([trip, repr(dest[i, 0]), repr(dest[i, 1])])
- outfile.close()
+ if 'destination_longitude' in model.pred_vars:
+ dest_outfile = open("output/test-dest-output-%s.csv" % model_name, "w")
+ dest_outcsv = csv.writer(dest_outfile)
+ dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
+ if 'travel_time' in model.pred_vars:
+ time_outfile = open("output/test-time-output-%s.csv" % model_name, "w")
+ time_outcsv = csv.writer(time_outfile)
+ time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"])
+
+ for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']):
+ outputs = out['outputs']
+ for i, trip in enumerate(out['trip_id']):
+ if model.pred_vars == ['travel_time']:
+ time_outcsv.writerow([trip, int(outputs[i])])
+ else:
+ dest_outcsv.writerow([trip, repr(outputs[i, 0]), repr(outputs[i, 1])])
+ if 'travel_time' in model.pred_vars:
+ time_outcsv.writerow([trip, int(outputs[i, 2])])
if __name__ == "__main__":