aboutsummaryrefslogtreecommitdiff
path: root/data.py
diff options
context:
space:
mode:
Diffstat (limited to 'data.py')
-rw-r--r--data.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/data.py b/data.py
index d03e10e..4590a7b 100644
--- a/data.py
+++ b/data.py
@@ -116,8 +116,10 @@ class TaxiData(Dataset):
line[8]=map(tuple, ast.literal_eval(line[8])) # polyline
return tuple(line)
-train_data=TaxiData(PREFIX+'/train.csv')
-test_data=TaxiData(PREFIX+'/test.csv')
+train_files=["%s/split/train-%02d.csv" % (PREFIX, i) for i in range(100)]
+valid_files=["%s/split/valid.csv" % (PREFIX,)]
+train_data=TaxiData(train_files)
+valid_data=TaxiData(valid_files)
def train_it():
return DataIterator(DataStream(train_data))