diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 19:09:37 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 19:09:37 -0400 |
commit | fe704521e6cb4d7f32414b55044b6e2240524bf5 (patch) | |
tree | c6b02de77fc274a02f135ad6a71b586d0cb591a0 /data.py | |
parent | e28390de61b23882f6e1069d565a2137825c2662 (diff) | |
download | taxi-fe704521e6cb4d7f32414b55044b6e2240524bf5.tar.gz taxi-fe704521e6cb4d7f32414b55044b6e2240524bf5.zip |
Fix CSV import (partially)
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 35 |
1 file changed, 21 insertions, 14 deletions
```diff
@@ -24,6 +24,11 @@ def get_client_id(n):
 porto_center = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX)
 data_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX))
 
+n_clients = 57124 #57105
+n_stands = 63
+
+dataset_size = 1710670
+
 class CallType(Enum):
     CENTRAL = 0
     STAND = 1
@@ -126,26 +131,28 @@ class TaxiData(Dataset):
             return self.get_data(state)
 
         values = []
-        for idx, (_, constructor) in enumerate(self.columns):
-            values.append(constructor(line[idx]))
+        for _, constructor in self.columns:
+            values.append(constructor(line))
         return tuple(values)
 
 taxi_columns = [
-    ("trip_id", lambda x: x),
-    ("call_type", CallType.from_data),
-    ("origin_call", lambda x: 0 if x == '' or x == 'NA' else get_client_id(int(x))),
-    ("origin_stand", lambda x: 0 if x == '' or x == 'NA' else int(x)),
-    ("taxi_id", int),
-    ("timestamp", int),
-    ("day_type", DayType.from_data),
-    ("missing_data", lambda x: x[0] == 'T'),
-    ("polyline", lambda x: map(tuple, ast.literal_eval(x))),
+    ("trip_id", lambda l: l[0]),
+    ("call_type", lambda l: CallType.from_data(l[1])),
+    ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else get_client_id(int(l[2]))),
+    ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])),
+    ("taxi_id", lambda l: int(l[4])),
+    ("timestamp", lambda l: int(l[5])),
+    ("day_type", lambda l: DayType.from_data(l[6])),
+    ("missing_data", lambda l: l[7][0] == 'T'),
+    ("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))),
+    ("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))),
+    ("latitude", lambda l: map(lambda p: p[1], ast.literal_eval(l[8]))),
 ]
 
 taxi_columns_valid = taxi_columns + [
-    ("destination_x", float),
-    ("destination_y", float),
-    ("time", int),
+    ("destination_longitude", lambda l: float(l[9])),
+    ("destination_latitude", lambda l: float(l[10])),
+    ("time", lambda l: int(l[11])),
 ]
 
 train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
```