aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--data.py29
-rw-r--r--model.py6
2 files changed, 24 insertions, 11 deletions
diff --git a/data.py b/data.py
index 79b38c7..92efefc 100644
--- a/data.py
+++ b/data.py
@@ -13,22 +13,35 @@ if socket.gethostname() == "adeb.laptop":
else:
DATA_PATH="/data/lisatmp3/auvolat/taxikaggle"
-client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))}
-
-def get_client_id(n):
- if n in client_ids:
- return client_ids[n]
- else:
- return 0
+H5DATA_PATH = '/data/lisatmp3/simonet/taxi/data.hdf5'
porto_center = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX)
data_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX))
-n_clients = 57124 #57105
+n_clients = 57124
+n_train_clients = 57105
n_stands = 63
dataset_size = 1710670
+# ---- Read client IDs and create reverse dictionnary
+
+def make_client_ids():
+ f = h5py.File(H5DATA_PATH, "r")
+ l = f['uniq_origin_call']
+ r = {}
+ for i in range(l.shape[0]):
+ r[l[i]] = i
+ return r
+
+client_ids = make_client_ids()
+
+def get_client_id(n):
+ if n in client_ids:
+ return client_ids[n]
+ else:
+ return 0
+
class CallType(Enum):
CENTRAL = 0
STAND = 1
diff --git a/model.py b/model.py
index 065db7e..a44a6cf 100644
--- a/model.py
+++ b/model.py
@@ -46,7 +46,7 @@ if __name__ == "__main__":
def setup_stream():
# Load the training and test data
- train = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5',
+ train = H5PYDataset(H5DATA_PATH,
which_set='train',
subset=slice(0, data.dataset_size - config.n_valid),
load_in_memory=True)
@@ -59,7 +59,7 @@ def setup_stream():
'destination_latitude', 'destination_longitude'))
train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
- valid = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5',
+ valid = H5PYDataset(H5DATA_PATH,
which_set='train',
subset=slice(data.dataset_size - config.n_valid, data.dataset_size),
load_in_memory=True)
@@ -106,7 +106,7 @@ def main():
# x_lastk_longitude = theano.printing.Print("x_lastk_longitude")(x_lastk_longitude)
# Define the model
- client_embed_table = LookupTable(length=data.n_clients+1, dim=config.dim_embed, name='client_lookup')
+ client_embed_table = LookupTable(length=data.n_train_clients+1, dim=config.dim_embed, name='client_lookup')
stand_embed_table = LookupTable(length=data.n_stands+1, dim=config.dim_embed, name='stand_lookup')
mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()],
dims=[config.dim_input] + config.dim_hidden + [config.dim_output])