From fe704521e6cb4d7f32414b55044b6e2240524bf5 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 29 Apr 2015 19:09:37 -0400 Subject: Fix CSV import (partially) --- data.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) (limited to 'data.py') diff --git a/data.py b/data.py index d38df10..79b38c7 100644 --- a/data.py +++ b/data.py @@ -24,6 +24,11 @@ def get_client_id(n): porto_center = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX) data_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX)) +n_clients = 57124 #57105 +n_stands = 63 + +dataset_size = 1710670 + class CallType(Enum): CENTRAL = 0 STAND = 1 @@ -126,26 +131,28 @@ class TaxiData(Dataset): return self.get_data(state) values = [] - for idx, (_, constructor) in enumerate(self.columns): - values.append(constructor(line[idx])) + for _, constructor in self.columns: + values.append(constructor(line)) return tuple(values) taxi_columns = [ - ("trip_id", lambda x: x), - ("call_type", CallType.from_data), - ("origin_call", lambda x: 0 if x == '' or x == 'NA' else get_client_id(int(x))), - ("origin_stand", lambda x: 0 if x == '' or x == 'NA' else int(x)), - ("taxi_id", int), - ("timestamp", int), - ("day_type", DayType.from_data), - ("missing_data", lambda x: x[0] == 'T'), - ("polyline", lambda x: map(tuple, ast.literal_eval(x))), + ("trip_id", lambda l: l[0]), + ("call_type", lambda l: CallType.from_data(l[1])), + ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else get_client_id(int(l[2]))), + ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])), + ("taxi_id", lambda l: int(l[4])), + ("timestamp", lambda l: int(l[5])), + ("day_type", lambda l: DayType.from_data(l[6])), + ("missing_data", lambda l: l[7][0] == 'T'), + ("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))), + ("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))), + ("latitude", lambda l: map(lambda p: p[1], ast.literal_eval(l[8]))), ] taxi_columns_valid = taxi_columns + [ - ("destination_x", float), - ("destination_y", float), - ("time", int), + ("destination_longitude", lambda l: float(l[9])), + ("destination_latitude", lambda l: float(l[10])), + ("time", lambda l: int(l[11])), ] train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)] -- cgit v1.2.3