diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 17:27:43 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 17:27:43 -0400 |
commit | 9a60f6c4e39c09187710608a9e225b6024b34364 (patch) | |
tree | 92e43401b6c6d3982081a35ec680b82856ec00c0 /data.py | |
parent | 107b3798cca35472e158d94f36a0bd08f3fe1fe8 (diff) | |
download | taxi-9a60f6c4e39c09187710608a9e225b6024b34364.tar.gz taxi-9a60f6c4e39c09187710608a9e225b6024b34364.zip |
Add validation set ; fix lat/lon
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 8 |
1 files changed, 7 insertions, 1 deletions
@@ -131,10 +131,16 @@ taxi_columns = [ ("polyline", lambda x: map(tuple, ast.literal_eval(x))), ] +taxi_columns_valid = taxi_columns + [ + ("destination_x", float), + ("destination_y", float), + ("time", int), +] + train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)] valid_files=["%s/split/valid.csv" % (DATA_PATH,)] train_data=TaxiData(train_files, taxi_columns) -valid_data=TaxiData(valid_files, taxi_columns) +valid_data=TaxiData(valid_files, taxi_columns_valid) def train_it(): return DataIterator(DataStream(train_data)) |