aboutsummaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-04-29 19:41:36 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-04-29 19:41:36 -0400
commit43e106e6630030dd34813295fe1d07bb86025402 (patch)
treec3a604d8d023e35532522a18da06e8c25dc251c6 /model.py
parent8b27690c8d77585f173412e5719787c48272674e (diff)
downloadtaxi-43e106e6630030dd34813295fe1d07bb86025402.tar.gz
taxi-43e106e6630030dd34813295fe1d07bb86025402.zip
Fix
Diffstat (limited to 'model.py')
-rw-r--r--model.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/model.py b/model.py
index a44a6cf..aff9fd7 100644
--- a/model.py
+++ b/model.py
@@ -46,7 +46,7 @@ if __name__ == "__main__":
def setup_stream():
# Load the training and test data
- train = H5PYDataset(H5DATA_PATH,
+ train = H5PYDataset(data.H5DATA_PATH,
which_set='train',
subset=slice(0, data.dataset_size - config.n_valid),
load_in_memory=True)
@@ -59,7 +59,7 @@ def setup_stream():
'destination_latitude', 'destination_longitude'))
train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
- valid = H5PYDataset(H5DATA_PATH,
+ valid = H5PYDataset(data.H5DATA_PATH,
which_set='train',
subset=slice(data.dataset_size - config.n_valid, data.dataset_size),
load_in_memory=True)
@@ -164,7 +164,7 @@ def main():
# Checkpoint('model.pkl', every_n_batches=100),
Dump('taxi_model', every_n_batches=1000),
LoadFromDump('taxi_model'),
- FinishAfter(after_epoch=1)
+ FinishAfter(after_epoch=5)
]
main_loop = MainLoop(