diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 19:16:22 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 19:16:22 -0400 |
commit | 8b27690c8d77585f173412e5719787c48272674e (patch) | |
tree | c2211d7fc60f9b05e6f1b257dd6093fcf792af29 /model.py | |
parent | fe704521e6cb4d7f32414b55044b6e2240524bf5 (diff) | |
download | taxi-8b27690c8d77585f173412e5719787c48272674e.tar.gz taxi-8b27690c8d77585f173412e5719787c48272674e.zip |
Fix CSV id normalization
Diffstat (limited to 'model.py')
-rw-r--r-- | model.py | 6 |
1 files changed, 3 insertions, 3 deletions
@@ -46,7 +46,7 @@ if __name__ == "__main__": def setup_stream(): # Load the training and test data - train = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5', + train = H5PYDataset(H5DATA_PATH, which_set='train', subset=slice(0, data.dataset_size - config.n_valid), load_in_memory=True) @@ -59,7 +59,7 @@ def setup_stream(): 'destination_latitude', 'destination_longitude')) train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size)) - valid = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5', + valid = H5PYDataset(H5DATA_PATH, which_set='train', subset=slice(data.dataset_size - config.n_valid, data.dataset_size), load_in_memory=True) @@ -106,7 +106,7 @@ def main(): # x_lastk_longitude = theano.printing.Print("x_lastk_longitude")(x_lastk_longitude) # Define the model - client_embed_table = LookupTable(length=data.n_clients+1, dim=config.dim_embed, name='client_lookup') + client_embed_table = LookupTable(length=data.n_train_clients+1, dim=config.dim_embed, name='client_lookup') stand_embed_table = LookupTable(length=data.n_stands+1, dim=config.dim_embed, name='stand_lookup') mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()], dims=[config.dim_input] + config.dim_hidden + [config.dim_output]) |