diff options
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 8 |
1 files changed, 7 insertions, 1 deletions
@@ -131,10 +131,16 @@ taxi_columns = [ ("polyline", lambda x: map(tuple, ast.literal_eval(x))), ] +taxi_columns_valid = taxi_columns + [ + ("destination_x", float), + ("destination_y", float), + ("time", int), +] + train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)] valid_files=["%s/split/valid.csv" % (DATA_PATH,)] train_data=TaxiData(train_files, taxi_columns) -valid_data=TaxiData(valid_files, taxi_columns) +valid_data=TaxiData(valid_files, taxi_columns_valid) def train_it(): return DataIterator(DataStream(train_data)) |