aboutsummaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-04-29 19:16:22 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-04-29 19:16:22 -0400
commit8b27690c8d77585f173412e5719787c48272674e (patch)
treec2211d7fc60f9b05e6f1b257dd6093fcf792af29 /model.py
parentfe704521e6cb4d7f32414b55044b6e2240524bf5 (diff)
downloadtaxi-8b27690c8d77585f173412e5719787c48272674e.tar.gz
taxi-8b27690c8d77585f173412e5719787c48272674e.zip
Fix CSV id normalization
Diffstat (limited to 'model.py')
-rw-r--r--model.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/model.py b/model.py
index 065db7e..a44a6cf 100644
--- a/model.py
+++ b/model.py
@@ -46,7 +46,7 @@ if __name__ == "__main__":
def setup_stream():
# Load the training and test data
- train = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5',
+ train = H5PYDataset(H5DATA_PATH,
which_set='train',
subset=slice(0, data.dataset_size - config.n_valid),
load_in_memory=True)
@@ -59,7 +59,7 @@ def setup_stream():
'destination_latitude', 'destination_longitude'))
train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
- valid = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5',
+ valid = H5PYDataset(H5DATA_PATH,
which_set='train',
subset=slice(data.dataset_size - config.n_valid, data.dataset_size),
load_in_memory=True)
@@ -106,7 +106,7 @@ def main():
# x_lastk_longitude = theano.printing.Print("x_lastk_longitude")(x_lastk_longitude)
# Define the model
- client_embed_table = LookupTable(length=data.n_clients+1, dim=config.dim_embed, name='client_lookup')
+ client_embed_table = LookupTable(length=data.n_train_clients+1, dim=config.dim_embed, name='client_lookup')
stand_embed_table = LookupTable(length=data.n_stands+1, dim=config.dim_embed, name='stand_lookup')
mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()],
dims=[config.dim_input] + config.dim_hidden + [config.dim_output])