aboutsummaryrefslogtreecommitdiff
path: root/data.py
diff options
context:
space:
mode:
Diffstat (limited to 'data.py')
-rw-r--r--data.py20
1 files changed, 15 insertions, 5 deletions
diff --git a/data.py b/data.py
index 730a9ab..39603fc 100644
--- a/data.py
+++ b/data.py
@@ -30,9 +30,7 @@ dataset_size = 1710670
def make_client_ids():
f = h5py.File(H5DATA_PATH, "r")
l = f['unique_origin_call']
- r = {}
- for i in range(l.shape[0]):
- r[l[i]] = i
+ r = {l[i]: i for i in range(l.shape[0])}
return r
client_ids = make_client_ids()
@@ -43,6 +41,18 @@ def get_client_id(n):
else:
return 0
+# ---- Read taxi IDs and create reverse dictionary
+
+def make_taxi_ids():
+ f = h5py.File(H5DATA_PATH, "r")
+ l = f['unique_taxi_id']
+ r = {l[i]: i for i in range(l.shape[0])}
+ return r
+
+taxi_ids = make_taxi_ids()
+
+# ---- Enum types
+
class CallType(Enum):
CENTRAL = 0
STAND = 1
@@ -154,9 +164,9 @@ taxi_columns = [
("call_type", lambda l: CallType.from_data(l[1])),
("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else get_client_id(int(l[2]))),
("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])),
- ("taxi_id", lambda l: int(l[4])),
+ ("taxi_id", lambda l: taxi_ids[int(l[4])]),
("timestamp", lambda l: int(l[5])),
- ("day_type", lambda l: DayType.from_data(l[6])),
+ ("day_type", lambda l: ord(l[6])-ord('A')),
("missing_data", lambda l: l[7][0] == 'T'),
("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))),
("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))),