aboutsummaryrefslogtreecommitdiff
path: root/data.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-04-27 17:27:43 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-04-27 17:27:43 -0400
commit9a60f6c4e39c09187710608a9e225b6024b34364 (patch)
tree92e43401b6c6d3982081a35ec680b82856ec00c0 /data.py
parent107b3798cca35472e158d94f36a0bd08f3fe1fe8 (diff)
downloadtaxi-9a60f6c4e39c09187710608a9e225b6024b34364.tar.gz
taxi-9a60f6c4e39c09187710608a9e225b6024b34364.zip
Add validation set ; fix lat/lon
Diffstat (limited to 'data.py')
-rw-r--r--data.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/data.py b/data.py
index 5ebdcde..d2c4f77 100644
--- a/data.py
+++ b/data.py
@@ -131,10 +131,16 @@ taxi_columns = [
("polyline", lambda x: map(tuple, ast.literal_eval(x))),
]
+taxi_columns_valid = taxi_columns + [
+ ("destination_x", float),
+ ("destination_y", float),
+ ("time", int),
+]
+
train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
train_data=TaxiData(train_files, taxi_columns)
-valid_data=TaxiData(valid_files, taxi_columns)
+valid_data=TaxiData(valid_files, taxi_columns_valid)
def train_it():
return DataIterator(DataStream(train_data))