diff options
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 14 |
1 files changed, 9 insertions, 5 deletions
@@ -15,6 +15,12 @@ else:
 
 client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))}
 
+def get_client_id(n):
+    if n in client_ids:
+        return client_ids[n]
+    else:
+        return 0
+
 porto_center = numpy.array([[ -8.61612, 41.1573]], dtype=theano.config.floatX)
 data_std = numpy.sqrt(numpy.array([[ 0.00333233, 0.00549598]], dtype=theano.config.floatX))
 
@@ -127,7 +133,7 @@ class TaxiData(Dataset):
 taxi_columns = [
     ("trip_id", lambda x: x),
     ("call_type", CallType.from_data),
-    ("origin_call", lambda x: 0 if x == '' or x == 'NA' else client_ids[int(x)]),
+    ("origin_call", lambda x: 0 if x == '' or x == 'NA' else get_client_id(int(x))),
     ("origin_stand", lambda x: 0 if x == '' or x == 'NA' else int(x)),
     ("taxi_id", int),
     ("timestamp", int),
@@ -144,13 +150,11 @@ taxi_columns_valid = taxi_columns + [
 
 train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
 valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
+test_file="%s/test.csv" % (DATA_PATH,)
 
 train_data=TaxiData(train_files, taxi_columns)
-
 valid_data = TaxiData(valid_files, taxi_columns_valid)
-
-# for the moment - will be changed later
-test_data = valid_data
+test_data = TaxiData(test_file, taxi_columns, has_header=True)
 
 def train_it():
     return DataIterator(DataStream(train_data))