aboutsummaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/rfc4180.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/data/rfc4180.py b/data/rfc4180.py
index b6fe5b1..db86830 100644
--- a/data/rfc4180.py
+++ b/data/rfc4180.py
@@ -1,6 +1,7 @@
import ast
import csv
import numpy
+import os
from fuel.datasets import Dataset
from fuel.streams import DataStream
@@ -90,15 +91,16 @@ taxi_columns_valid = taxi_columns + [
("time", lambda l: int(l[11])),
]
-train_file="%s/train.csv" % data.path
-valid_file="%s/valid2-cut.csv" % data.path
-test_file="%s/test.csv" % data.path
+train_file = os.path.join(data.path, 'train.csv')
+valid_file = os.path.join(data.path, 'valid2-cut.csv')
+test_file = os.path.join(data.path, 'test.csv')
train_data=TaxiData(train_file, taxi_columns, has_header=True)
valid_data = TaxiData(valid_file, taxi_columns_valid)
test_data = TaxiData(test_file, taxi_columns, has_header=True)
-valid_trips = [l for l in open("%s/valid2-cut-ids.txt" % data.path)]
+with open(os.path.join(data.path, 'valid2-cut-ids.txt')) as f:
+ valid_trips = [l for l in f]
def train_it():
return DataIterator(DataStream(train_data))