about summary refs log tree commit diff
path: root/data.py
diff options
context:
space:
mode:
author Alex Auvolat <alex.auvolat@ens.fr> 2015-04-28 16:41:46 -0400
committer Alex Auvolat <alex.auvolat@ens.fr> 2015-04-28 16:41:46 -0400
commit c195fd437b76d00ee780cef49903266165f001a7 (patch)
tree 6010785da83baa49f7f89dc230e4ef0b0f1994f3 /data.py
parent d58b121de641c0122652bc3d6096a9d0e1048391 (diff)
download taxi-c195fd437b76d00ee780cef49903266165f001a7.tar.gz
taxi-c195fd437b76d00ee780cef49903266165f001a7.zip
Support polylines with <5 points
Diffstat (limited to 'data.py')
-rw-r--r-- data.py 14
1 files changed, 9 insertions, 5 deletions
diff --git a/data.py b/data.py
index f1236a5..351c90a 100644
--- a/data.py
+++ b/data.py
@@ -15,6 +15,12 @@ else:
client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))}
+def get_client_id(n):
+ if n in client_ids:
+ return client_ids[n]
+ else:
+ return 0
+
porto_center = numpy.array([[ -8.61612, 41.1573]], dtype=theano.config.floatX)
data_std = numpy.sqrt(numpy.array([[ 0.00333233, 0.00549598]], dtype=theano.config.floatX))
@@ -127,7 +133,7 @@ class TaxiData(Dataset):
taxi_columns = [
("trip_id", lambda x: x),
("call_type", CallType.from_data),
- ("origin_call", lambda x: 0 if x == '' or x == 'NA' else client_ids[int(x)]),
+ ("origin_call", lambda x: 0 if x == '' or x == 'NA' else get_client_id(int(x))),
("origin_stand", lambda x: 0 if x == '' or x == 'NA' else int(x)),
("taxi_id", int),
("timestamp", int),
@@ -144,13 +150,11 @@ taxi_columns_valid = taxi_columns + [
train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
+test_file="%s/test.csv" % (DATA_PATH,)
train_data=TaxiData(train_files, taxi_columns)
-
valid_data = TaxiData(valid_files, taxi_columns_valid)
-
-# for the moment - will be changed later
-test_data = valid_data
+test_data = TaxiData(test_file, taxi_columns, has_header=True)
def train_it():
return DataIterator(DataStream(train_data))