diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-24 16:48:44 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-24 16:48:44 -0400 |
commit | bd2826df73554207c88c5918d86fd9707d9e3753 (patch) | |
tree | d4e02bb0e7431369403dfa03e7d29a96b032e2e4 /data.py | |
parent | d03fb73c8f95c9ec2f4dec25062e13223a709183 (diff) | |
download | taxi-bd2826df73554207c88c5918d86fd9707d9e3753.tar.gz taxi-bd2826df73554207c88c5918d86fd9707d9e3753.zip |
Connect model with data stream
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 6 |
1 files changed, 4 insertions, 2 deletions
@@ -116,8 +116,10 @@ class TaxiData(Dataset): line[8]=map(tuple, ast.literal_eval(line[8])) # polyline return tuple(line) -train_data=TaxiData(PREFIX+'/train.csv') -test_data=TaxiData(PREFIX+'/test.csv') +train_files=["%s/split/train-%02d.csv" % (PREFIX, i) for i in range(100)] +valid_files=["%s/split/valid.csv" % (PREFIX,)] +train_data=TaxiData(train_files) +valid_data=TaxiData(valid_files) def train_it(): return DataIterator(DataStream(train_data)) |