aboutsummaryrefslogtreecommitdiff
path: root/data.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-04-24 16:48:44 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-04-24 16:48:44 -0400
commitbd2826df73554207c88c5918d86fd9707d9e3753 (patch)
treed4e02bb0e7431369403dfa03e7d29a96b032e2e4 /data.py
parentd03fb73c8f95c9ec2f4dec25062e13223a709183 (diff)
downloadtaxi-bd2826df73554207c88c5918d86fd9707d9e3753.tar.gz
taxi-bd2826df73554207c88c5918d86fd9707d9e3753.zip
Connect model with data stream
Diffstat (limited to 'data.py')
-rw-r--r--data.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/data.py b/data.py
index d03e10e..4590a7b 100644
--- a/data.py
+++ b/data.py
@@ -116,8 +116,10 @@ class TaxiData(Dataset):
line[8]=map(tuple, ast.literal_eval(line[8])) # polyline
return tuple(line)
-train_data=TaxiData(PREFIX+'/train.csv')
-test_data=TaxiData(PREFIX+'/test.csv')
+train_files=["%s/split/train-%02d.csv" % (PREFIX, i) for i in range(100)]
+valid_files=["%s/split/valid.csv" % (PREFIX,)]
+train_data=TaxiData(train_files)
+valid_data=TaxiData(valid_files)
def train_it():
return DataIterator(DataStream(train_data))