diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 19:16:22 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 19:16:22 -0400 |
commit | 8b27690c8d77585f173412e5719787c48272674e (patch) | |
tree | c2211d7fc60f9b05e6f1b257dd6093fcf792af29 /data.py | |
parent | fe704521e6cb4d7f32414b55044b6e2240524bf5 (diff) | |
download | taxi-8b27690c8d77585f173412e5719787c48272674e.tar.gz taxi-8b27690c8d77585f173412e5719787c48272674e.zip |
Fix CSV id normalization
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 29 |
1 files changed, 21 insertions, 8 deletions
@@ -13,22 +13,35 @@ if socket.gethostname() == "adeb.laptop":
 else:
     DATA_PATH="/data/lisatmp3/auvolat/taxikaggle"
 
-client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))}
-
-def get_client_id(n):
-    if n in client_ids:
-        return client_ids[n]
-    else:
-        return 0
+H5DATA_PATH = '/data/lisatmp3/simonet/taxi/data.hdf5'
 
 porto_center = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX)
 data_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX))
 
-n_clients = 57124 #57105
+n_clients = 57124
+n_train_clients = 57105
 n_stands = 63
 
 dataset_size = 1710670
 
+# ---- Read client IDs and create reverse dictionnary
+
+def make_client_ids():
+    f = h5py.File(H5DATA_PATH, "r")
+    l = f['uniq_origin_call']
+    r = {}
+    for i in range(l.shape[0]):
+        r[l[i]] = i
+    return r
+
+client_ids = make_client_ids()
+
+def get_client_id(n):
+    if n in client_ids:
+        return client_ids[n]
+    else:
+        return 0
+
 class CallType(Enum):
     CENTRAL = 0
     STAND = 1 |