diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/rfc4180.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/data/rfc4180.py b/data/rfc4180.py index b6fe5b1..db86830 100644 --- a/data/rfc4180.py +++ b/data/rfc4180.py @@ -1,6 +1,7 @@ import ast import csv import numpy +import os from fuel.datasets import Dataset from fuel.streams import DataStream @@ -90,15 +91,16 @@ taxi_columns_valid = taxi_columns + [ ("time", lambda l: int(l[11])), ] -train_file="%s/train.csv" % data.path -valid_file="%s/valid2-cut.csv" % data.path -test_file="%s/test.csv" % data.path +train_file = os.path.join(data.path, 'train.csv') +valid_file = os.path.join(data.path, 'valid2-cut.csv') +test_file = os.path.join(data.path, 'test.csv') train_data=TaxiData(train_file, taxi_columns, has_header=True) valid_data = TaxiData(valid_file, taxi_columns_valid) test_data = TaxiData(test_file, taxi_columns, has_header=True) -valid_trips = [l for l in open("%s/valid2-cut-ids.txt" % data.path)] +with open(os.path.join(data.path, 'valid2-cut-ids.txt')) as f: + valid_trips = [l for l in f] def train_it(): return DataIterator(DataStream(train_data)) |