From 1f2ff96e6480a62089fcac35154a956c218ed678 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Tue, 5 May 2015 21:55:13 -0400 Subject: Clean data module and generalize use of hdf5. --- data.py | 196 ---------------------------------------------------------------- 1 file changed, 196 deletions(-) delete mode 100644 data.py (limited to 'data.py') diff --git a/data.py b/data.py deleted file mode 100644 index 42ebe1c..0000000 --- a/data.py +++ /dev/null @@ -1,196 +0,0 @@ -import ast, csv -import socket -import fuel -import numpy -import h5py -from enum import Enum -from fuel.datasets import Dataset -from fuel.streams import DataStream -from fuel.iterator import DataIterator -import theano - -if socket.gethostname() == "adeb.laptop": - DATA_PATH = "/Users/adeb/data/taxi" -else: - DATA_PATH="/data/lisatmp3/auvolat/taxikaggle" - -H5DATA_PATH = '/data/lisatmp3/simonet/taxi/data.hdf5' - -porto_center = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX) -data_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX)) - -n_clients = 57124 -n_train_clients = 57105 -n_stands = 63 - -dataset_size = 1710670 - -# ---- Read client IDs and create reverse dictionnary - -def make_client_ids(): - f = h5py.File(H5DATA_PATH, "r") - l = f['unique_origin_call'] - r = {l[i]: i for i in range(l.shape[0])} - return r - -client_ids = make_client_ids() - -def get_client_id(n): - if n in client_ids and client_ids[n] <= n_train_clients: - return client_ids[n] - else: - return 0 - -# ---- Read taxi IDs and create reverse dictionnary - -def make_taxi_ids(): - f = h5py.File(H5DATA_PATH, "r") - l = f['unique_taxi_id'] - r = {l[i]: i for i in range(l.shape[0])} - return r - -taxi_ids = make_taxi_ids() - -# ---- Enum types - -class CallType(Enum): - CENTRAL = 0 - STAND = 1 - STREET = 2 - - @classmethod - def from_data(cls, val): - if val=='A': - return cls.CENTRAL - elif val=='B': - return cls.STAND - elif val=='C': - return cls.STREET - - @classmethod - def to_data(cls, val): - if val==cls.CENTRAL: - return 'A' - elif val==cls.STAND: - return 'B' - elif val==cls.STREET: - return 'C' - -class DayType(Enum): - NORMAL = 0 - HOLIDAY = 1 - HOLIDAY_EVE = 2 - - @classmethod - def from_data(cls, val): - if val=='A': - return cls.NORMAL - elif val=='B': - return cls.HOLIDAY - elif val=='C': - return cls.HOLIDAY_EVE - - @classmethod - def to_data(cls, val): - if val==cls.NORMAL: - return 'A' - elif val==cls.HOLIDAY: - return 'B' - elif val==cls.HOLIDAY_EVE: - return 'C' - -class TaxiData(Dataset): - example_iteration_scheme=None - - class State: - __slots__ = ('file', 'index', 'reader') - - def __init__(self, pathes, columns, has_header=False): - if not isinstance(pathes, list): - pathes=[pathes] - assert len(pathes)>0 - self.columns=columns - self.provides_sources = tuple(map(lambda x: x[0], columns)) - self.pathes=pathes - self.has_header=has_header - super(TaxiData, self).__init__() - - def open(self): - state=self.State() - state.file=open(self.pathes[0]) - state.index=0 - state.reader=csv.reader(state.file) - if self.has_header: - state.reader.next() - return state - - def close(self, state): - state.file.close() - - def reset(self, state): - if state.index==0: - state.file.seek(0) - else: - state.index=0 - state.file.close() - state.file=open(self.pathes[0]) - state.reader=csv.reader(state.file) - return state - - def get_data(self, state, request=None): - if request is not None: - raise ValueError - try: - line=state.reader.next() - except (ValueError, StopIteration): - # print state.index - state.file.close() - state.index+=1 - if state.index>=len(self.pathes): - raise StopIteration - state.file=open(self.pathes[state.index]) - state.reader=csv.reader(state.file) - if self.has_header: - state.reader.next() - return self.get_data(state) - - values = [] - for _, constructor in self.columns: - values.append(constructor(line)) - return tuple(values) - -taxi_columns = [ - ("trip_id", lambda l: l[0]), - ("call_type", lambda l: CallType.from_data(l[1])), - ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else get_client_id(int(l[2]))), - ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])), - ("taxi_id", lambda l: taxi_ids[int(l[4])]), - ("timestamp", lambda l: int(l[5])), - ("day_type", lambda l: ord(l[6])-ord('A')), - ("missing_data", lambda l: l[7][0] == 'T'), - ("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))), - ("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))), - ("latitude", lambda l: map(lambda p: p[1], ast.literal_eval(l[8]))), -] - -taxi_columns_valid = taxi_columns + [ - ("destination_longitude", lambda l: numpy.float32(float(l[9]))), - ("destination_latitude", lambda l: numpy.float32(float(l[10]))), - ("time", lambda l: int(l[11])), -] - -valid_files=["%s/valid2-cut.csv" % (DATA_PATH,)] -test_file="%s/test.csv" % (DATA_PATH,) - -valid_data = TaxiData(valid_files, taxi_columns_valid) -test_data = TaxiData(test_file, taxi_columns, has_header=True) - -valid_trips = [l for l in open(DATA_PATH + "/valid2-cut-ids.txt")] - -def train_it(): - return DataIterator(DataStream(train_data)) - -def test_it(): - return DataIterator(DataStream(valid_data)) - - -- cgit v1.2.3