aboutsummaryrefslogtreecommitdiff
path: root/data.py
diff options
context:
space:
mode:
Diffstat (limited to 'data.py')
-rw-r--r--data.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/data.py b/data.py
index 5ebdcde..d2c4f77 100644
--- a/data.py
+++ b/data.py
@@ -131,10 +131,16 @@ taxi_columns = [
("polyline", lambda x: map(tuple, ast.literal_eval(x))),
]
+taxi_columns_valid = taxi_columns + [
+ ("destination_x", float),
+ ("destination_y", float),
+ ("time", int),
+]
+
train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
train_data=TaxiData(train_files, taxi_columns)
-valid_data=TaxiData(valid_files, taxi_columns)
+valid_data=TaxiData(valid_files, taxi_columns_valid)
def train_it():
return DataIterator(DataStream(train_data))