aboutsummaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-04-28 16:41:46 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-04-28 16:41:46 -0400
commitc195fd437b76d00ee780cef49903266165f001a7 (patch)
tree6010785da83baa49f7f89dc230e4ef0b0f1994f3 /model.py
parentd58b121de641c0122652bc3d6096a9d0e1048391 (diff)
downloadtaxi-c195fd437b76d00ee780cef49903266165f001a7.tar.gz
taxi-c195fd437b76d00ee780cef49903266165f001a7.zip
Support polylines with <5 points
Diffstat (limited to 'model.py')
-rw-r--r--model.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/model.py b/model.py
index b9cc1a8..405ad47 100644
--- a/model.py
+++ b/model.py
@@ -140,8 +140,8 @@ def main():
extensions=[DataStreamMonitoring([cost, hcost], valid_stream,
prefix='valid',
- every_n_batches=1),
- Printing(every_n_batches=1),
+ every_n_batches=1000),
+ Printing(every_n_batches=1000),
# Dump('taxi_model', every_n_batches=100),
# LoadFromDump('taxi_model'),
]
@@ -163,6 +163,7 @@ def main():
outfile = open("test-output.csv", "w")
outcsv = csv.writer(outfile)
+ outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']):
dest = out['outputs']
for i, trip in enumerate(out['trip_id']):