diff options
Diffstat (limited to 'data/rfc4180.py')
-rw-r--r-- | data/rfc4180.py | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/data/rfc4180.py b/data/rfc4180.py new file mode 100644 index 0000000..b6fe5b1 --- /dev/null +++ b/data/rfc4180.py @@ -0,0 +1,107 @@ +import ast +import csv +import numpy + +from fuel.datasets import Dataset +from fuel.streams import DataStream +from fuel.iterator import DataIterator + +import data +from data.hdf5 import origin_call_normalize, taxi_id_normalize + + +class TaxiData(Dataset): + example_iteration_scheme=None + + class State: + __slots__ = ('file', 'index', 'reader') + + def __init__(self, pathes, columns, has_header=False): + if not isinstance(pathes, list): + pathes=[pathes] + assert len(pathes)>0 + self.columns=columns + self.provides_sources = tuple(map(lambda x: x[0], columns)) + self.pathes=pathes + self.has_header=has_header + super(TaxiData, self).__init__() + + def open(self): + state=self.State() + state.file=open(self.pathes[0]) + state.index=0 + state.reader=csv.reader(state.file) + if self.has_header: + state.reader.next() + return state + + def close(self, state): + state.file.close() + + def reset(self, state): + if state.index==0: + state.file.seek(0) + else: + state.index=0 + state.file.close() + state.file=open(self.pathes[0]) + state.reader=csv.reader(state.file) + return state + + def get_data(self, state, request=None): + if request is not None: + raise ValueError + try: + line=state.reader.next() + except (ValueError, StopIteration): + # print state.index + state.file.close() + state.index+=1 + if state.index>=len(self.pathes): + raise StopIteration + state.file=open(self.pathes[state.index]) + state.reader=csv.reader(state.file) + if self.has_header: + state.reader.next() + return self.get_data(state) + + values = [] + for _, constructor in self.columns: + values.append(constructor(line)) + return tuple(values) + +taxi_columns = [ + ("trip_id", lambda l: l[0]), + ("call_type", lambda l: ord(l[1])-ord('A')), + ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else origin_call_normalize(int(l[2]))), + ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])), + ("taxi_id", lambda l: taxi_id_normalize(int(l[4]))), + ("timestamp", lambda l: int(l[5])), + ("day_type", lambda l: ord(l[6])-ord('A')), + ("missing_data", lambda l: l[7][0] == 'T'), + ("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))), + ("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))), + ("latitude", lambda l: map(lambda p: p[1], ast.literal_eval(l[8]))), +] + +taxi_columns_valid = taxi_columns + [ + ("destination_longitude", lambda l: numpy.float32(float(l[9]))), + ("destination_latitude", lambda l: numpy.float32(float(l[10]))), + ("time", lambda l: int(l[11])), +] + +train_file="%s/train.csv" % data.path +valid_file="%s/valid2-cut.csv" % data.path +test_file="%s/test.csv" % data.path + +train_data=TaxiData(train_file, taxi_columns, has_header=True) +valid_data = TaxiData(valid_file, taxi_columns_valid) +test_data = TaxiData(test_file, taxi_columns, has_header=True) + +valid_trips = [l for l in open("%s/valid2-cut-ids.txt" % data.path)] + +def train_it(): + return DataIterator(DataStream(train_data)) + +def test_it(): + return DataIterator(DataStream(valid_data)) |