aboutsummaryrefslogtreecommitdiff
path: root/data.py
diff options
context:
space:
mode:
author Alex Auvolat <alex.auvolat@ens.fr> 2015-04-29 19:09:37 -0400
committer Alex Auvolat <alex.auvolat@ens.fr> 2015-04-29 19:09:37 -0400
commit fe704521e6cb4d7f32414b55044b6e2240524bf5 (patch)
tree c6b02de77fc274a02f135ad6a71b586d0cb591a0 /data.py
parent e28390de61b23882f6e1069d565a2137825c2662 (diff)
download taxi-fe704521e6cb4d7f32414b55044b6e2240524bf5.tar.gz
taxi-fe704521e6cb4d7f32414b55044b6e2240524bf5.zip
Fix CSV import (partially)
Diffstat (limited to 'data.py')
-rw-r--r-- data.py 35
1 files changed, 21 insertions, 14 deletions
diff --git a/data.py b/data.py
index d38df10..79b38c7 100644
--- a/data.py
+++ b/data.py
@@ -24,6 +24,11 @@ def get_client_id(n):
porto_center = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX)
data_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX))
+n_clients = 57124 #57105
+n_stands = 63
+
+dataset_size = 1710670
+
class CallType(Enum):
CENTRAL = 0
STAND = 1
@@ -126,26 +131,28 @@ class TaxiData(Dataset):
return self.get_data(state)
values = []
- for idx, (_, constructor) in enumerate(self.columns):
- values.append(constructor(line[idx]))
+ for _, constructor in self.columns:
+ values.append(constructor(line))
return tuple(values)
taxi_columns = [
- ("trip_id", lambda x: x),
- ("call_type", CallType.from_data),
- ("origin_call", lambda x: 0 if x == '' or x == 'NA' else get_client_id(int(x))),
- ("origin_stand", lambda x: 0 if x == '' or x == 'NA' else int(x)),
- ("taxi_id", int),
- ("timestamp", int),
- ("day_type", DayType.from_data),
- ("missing_data", lambda x: x[0] == 'T'),
- ("polyline", lambda x: map(tuple, ast.literal_eval(x))),
+ ("trip_id", lambda l: l[0]),
+ ("call_type", lambda l: CallType.from_data(l[1])),
+ ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else get_client_id(int(l[2]))),
+ ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])),
+ ("taxi_id", lambda l: int(l[4])),
+ ("timestamp", lambda l: int(l[5])),
+ ("day_type", lambda l: DayType.from_data(l[6])),
+ ("missing_data", lambda l: l[7][0] == 'T'),
+ ("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))),
+ ("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))),
+ ("latitude", lambda l: map(lambda p: p[1], ast.literal_eval(l[8]))),
]
taxi_columns_valid = taxi_columns + [
- ("destination_x", float),
- ("destination_y", float),
- ("time", int),
+ ("destination_longitude", lambda l: float(l[9])),
+ ("destination_latitude", lambda l: float(l[10])),
+ ("time", lambda l: int(l[11])),
]
train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]