From 8b27690c8d77585f173412e5719787c48272674e Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 29 Apr 2015 19:16:22 -0400 Subject: Fix CSV id normalization --- data.py | 29 +++++++++++++++++++++-------- model.py | 6 +++--- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/data.py b/data.py index 79b38c7..92efefc 100644 --- a/data.py +++ b/data.py @@ -13,22 +13,35 @@ if socket.gethostname() == "adeb.laptop": else: DATA_PATH="/data/lisatmp3/auvolat/taxikaggle" -client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))} - -def get_client_id(n): - if n in client_ids: - return client_ids[n] - else: - return 0 +H5DATA_PATH = '/data/lisatmp3/simonet/taxi/data.hdf5' porto_center = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX) data_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX)) -n_clients = 57124 #57105 +n_clients = 57124 +n_train_clients = 57105 n_stands = 63 dataset_size = 1710670 +# ---- Read client IDs and create reverse dictionary + +def make_client_ids(): + f = h5py.File(H5DATA_PATH, "r") + l = f['uniq_origin_call'] + r = {} + for i in range(l.shape[0]): + r[l[i]] = i + return r + +client_ids = make_client_ids() + +def get_client_id(n): + if n in client_ids: + return client_ids[n] + else: + return 0 + class CallType(Enum): CENTRAL = 0 STAND = 1 diff --git a/model.py b/model.py index 065db7e..a44a6cf 100644 --- a/model.py +++ b/model.py @@ -46,7 +46,7 @@ if __name__ == "__main__": def setup_stream(): # Load the training and test data - train = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5', + train = H5PYDataset(H5DATA_PATH, which_set='train', subset=slice(0, data.dataset_size - config.n_valid), load_in_memory=True) @@ -59,7 +59,7 @@ def setup_stream(): 'destination_latitude', 'destination_longitude')) train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size)) - valid = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5', + valid = 
H5PYDataset(H5DATA_PATH, which_set='train', subset=slice(data.dataset_size - config.n_valid, data.dataset_size), load_in_memory=True) @@ -106,7 +106,7 @@ def main(): # x_lastk_longitude = theano.printing.Print("x_lastk_longitude")(x_lastk_longitude) # Define the model - client_embed_table = LookupTable(length=data.n_clients+1, dim=config.dim_embed, name='client_lookup') + client_embed_table = LookupTable(length=data.n_train_clients+1, dim=config.dim_embed, name='client_lookup') stand_embed_table = LookupTable(length=data.n_stands+1, dim=config.dim_embed, name='stand_lookup') mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()], dims=[config.dim_input] + config.dim_hidden + [config.dim_output]) -- cgit v1.2.3