Diffstat (limited to 'data/rfc4180.py')
-rw-r--r--  data/rfc4180.py  120
1 file changed, 120 insertions, 0 deletions
diff --git a/data/rfc4180.py b/data/rfc4180.py
new file mode 100644
index 0000000..b6fe5b1
--- /dev/null
+++ b/data/rfc4180.py
@@ -0,0 +1,120 @@
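+# Expose the taxi trip CSV files (RFC 4180) as Fuel datasets.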
+import ast
+import csv
+import numpy
+
+from fuel.datasets import Dataset
+from fuel.streams import DataStream
+from fuel.iterator import DataIterator
+
+import data
+from data.hdf5 import origin_call_normalize, taxi_id_normalize
+
+
+class TaxiData(Dataset):
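+    """Dataset streaming one example per row of one or more CSV files.
+
+    `columns` is a list of (source_name, constructor) pairs, one per source.
+    """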
+    example_iteration_scheme = None
+
+    class State:
+        __slots__ = ('file', 'index', 'reader')
+
+    def __init__(self, pathes, columns, has_header=False):
+        if not isinstance(pathes, list):
+            pathes = [pathes]
+        assert len(pathes) > 0
+        self.columns = columns
+        self.provides_sources = tuple(name for name, _ in columns)
+        self.pathes = pathes
+        self.has_header = has_header
+        super(TaxiData, self).__init__()
+
+    def open(self):
+        state = self.State()
+        state.file = open(self.pathes[0])
+        state.index = 0
+        state.reader = csv.reader(state.file)
+        if self.has_header:
+            next(state.reader)  # skip the header row
+        return state
+
+    def close(self, state):
+        state.file.close()
+
+    def reset(self, state):
+        if state.index == 0:
+            state.file.seek(0)
+        else:
+            state.index = 0
+            state.file.close()
+            state.file = open(self.pathes[0])
+        state.reader = csv.reader(state.file)
+        if self.has_header:
+            next(state.reader)  # skip the header row again
+        return state
+
+    def get_data(self, state, request=None):
+        if request is not None:
+            raise ValueError
+        try:
+            line = next(state.reader)
+        except (ValueError, StopIteration):
+            # Reached the end of the current file: advance to the next one,
+            # or stop if this was the last.
+            state.file.close()
+            state.index += 1
+            if state.index >= len(self.pathes):
+                raise StopIteration
+            state.file = open(self.pathes[state.index])
+            state.reader = csv.reader(state.file)
+            if self.has_header:
+                next(state.reader)
+            return self.get_data(state)
+
+        values = []
+        for _, constructor in self.columns:
+            values.append(constructor(line))
+        return tuple(values)
+
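+# Each column is a (source_name, constructor) pair; the constructor builds the
+# source value from a raw CSV row (a list of strings).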
+taxi_columns = [
+    ("trip_id", lambda l: l[0]),
+    ("call_type", lambda l: ord(l[1]) - ord('A')),
+    ("origin_call", lambda l: 0 if l[2] in ('', 'NA') else origin_call_normalize(int(l[2]))),
+    ("origin_stand", lambda l: 0 if l[3] in ('', 'NA') else int(l[3])),
+    ("taxi_id", lambda l: taxi_id_normalize(int(l[4]))),
+    ("timestamp", lambda l: int(l[5])),
+    ("day_type", lambda l: ord(l[6]) - ord('A')),
+    ("missing_data", lambda l: l[7][0] == 'T'),
+    ("polyline", lambda l: [tuple(p) for p in ast.literal_eval(l[8])]),
+    ("longitude", lambda l: [p[0] for p in ast.literal_eval(l[8])]),
+    ("latitude", lambda l: [p[1] for p in ast.literal_eval(l[8])]),
+]
+
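+# The validation file carries three extra columns: destination coordinates and time.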
+taxi_columns_valid = taxi_columns + [
+    ("destination_longitude", lambda l: numpy.float32(float(l[9]))),
+    ("destination_latitude", lambda l: numpy.float32(float(l[10]))),
+    ("time", lambda l: int(l[11])),
+]
+
+train_file = "%s/train.csv" % data.path
+valid_file = "%s/valid2-cut.csv" % data.path
+test_file = "%s/test.csv" % data.path
+
+train_data = TaxiData(train_file, taxi_columns, has_header=True)
+valid_data = TaxiData(valid_file, taxi_columns_valid)
+test_data = TaxiData(test_file, taxi_columns, has_header=True)
+
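+# Trip ids of the validation set, one per line (trailing newlines are kept).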
+valid_trips = [l for l in open("%s/valid2-cut-ids.txt" % data.path)]
+
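+# Convenience helpers returning plain Fuel iterators over the train and validation streams.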
+def train_it():
+    return DataIterator(DataStream(train_data))
+
+def test_it():
+    return DataIterator(DataStream(valid_data))