about | summary | refs | log | tree | commit | diff
path: root/data.py
diff options
context:
space:
mode:
Diffstat (limited to 'data.py')
-rw-r--r-- data.py 14
1 files changed, 9 insertions, 5 deletions
diff --git a/data.py b/data.py
index f1236a5..351c90a 100644
--- a/data.py
+++ b/data.py
@@ -15,6 +15,12 @@ else:
client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))}
+def get_client_id(n):
+    """Return the index client_ids holds for n, or 0 when n is not a key."""
+    # 0 is also what the origin_call column yields for '' / 'NA', so an
+    # unknown id lands in that same bucket instead of raising KeyError.
+    return client_ids.get(n, 0)
+
porto_center = numpy.array([[ -8.61612, 41.1573]], dtype=theano.config.floatX)
data_std = numpy.sqrt(numpy.array([[ 0.00333233, 0.00549598]], dtype=theano.config.floatX))
@@ -127,7 +133,7 @@ class TaxiData(Dataset):
taxi_columns = [
("trip_id", lambda x: x),
("call_type", CallType.from_data),
- ("origin_call", lambda x: 0 if x == '' or x == 'NA' else client_ids[int(x)]),
+ ("origin_call", lambda x: 0 if x == '' or x == 'NA' else get_client_id(int(x))),
("origin_stand", lambda x: 0 if x == '' or x == 'NA' else int(x)),
("taxi_id", int),
("timestamp", int),
@@ -144,13 +150,11 @@ taxi_columns_valid = taxi_columns + [
train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
+test_file="%s/test.csv" % (DATA_PATH,)
train_data=TaxiData(train_files, taxi_columns)
-
valid_data = TaxiData(valid_files, taxi_columns_valid)
-
-# for the moment - will be changed later
-test_data = valid_data
+test_data = TaxiData([test_file], taxi_columns, has_header=True)  # list of paths, like train_files / valid_files
def train_it():
return DataIterator(DataStream(train_data))