diff options
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 20 |
1 files changed, 15 insertions, 5 deletions
@@ -30,9 +30,7 @@ dataset_size = 1710670 def make_client_ids(): f = h5py.File(H5DATA_PATH, "r") l = f['unique_origin_call'] - r = {} - for i in range(l.shape[0]): - r[l[i]] = i + r = {l[i]: i for i in range(l.shape[0])} return r client_ids = make_client_ids() @@ -43,6 +41,18 @@ def get_client_id(n): else: return 0 +# ---- Read taxi IDs and create reverse dictionnary + +def make_taxi_ids(): + f = h5py.File(H5DATA_PATH, "r") + l = f['unique_taxi_id'] + r = {l[i]: i for i in range(l.shape[0])} + return r + +taxi_ids = make_taxi_ids() + +# ---- Enum types + class CallType(Enum): CENTRAL = 0 STAND = 1 @@ -154,9 +164,9 @@ taxi_columns = [ ("call_type", lambda l: CallType.from_data(l[1])), ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else get_client_id(int(l[2]))), ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])), - ("taxi_id", lambda l: int(l[4])), + ("taxi_id", lambda l: taxi_ids[int(l[4])]), ("timestamp", lambda l: int(l[5])), - ("day_type", lambda l: DayType.from_data(l[6])), + ("day_type", lambda l: ord(l[6])-ord('A')), ("missing_data", lambda l: l[7][0] == 'T'), ("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))), ("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))), |