diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 19:41:36 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 19:41:36 -0400 |
commit | 43e106e6630030dd34813295fe1d07bb86025402 (patch) | |
tree | c3a604d8d023e35532522a18da06e8c25dc251c6 /model.py | |
parent | 8b27690c8d77585f173412e5719787c48272674e (diff) | |
download | taxi-43e106e6630030dd34813295fe1d07bb86025402.tar.gz taxi-43e106e6630030dd34813295fe1d07bb86025402.zip |
Fix
Diffstat (limited to 'model.py')
-rw-r--r-- | model.py | 6 |
1 files changed, 3 insertions, 3 deletions
@@ -46,7 +46,7 @@ if __name__ == "__main__": def setup_stream(): # Load the training and test data - train = H5PYDataset(H5DATA_PATH, + train = H5PYDataset(data.H5DATA_PATH, which_set='train', subset=slice(0, data.dataset_size - config.n_valid), load_in_memory=True) @@ -59,7 +59,7 @@ def setup_stream(): 'destination_latitude', 'destination_longitude')) train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size)) - valid = H5PYDataset(H5DATA_PATH, + valid = H5PYDataset(data.H5DATA_PATH, which_set='train', subset=slice(data.dataset_size - config.n_valid, data.dataset_size), load_in_memory=True) @@ -164,7 +164,7 @@ def main(): # Checkpoint('model.pkl', every_n_batches=100), Dump('taxi_model', every_n_batches=1000), LoadFromDump('taxi_model'), - FinishAfter(after_epoch=1) + FinishAfter(after_epoch=5) ] main_loop = MainLoop( |