diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 17:13:08 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 17:13:08 -0400 |
commit | 5f42c01231ccec377196472b6f4682b6afeb878d (patch) | |
tree | 8e0212399a951a57738574084234e7c75b4fe590 /train.py | |
parent | c912ef9424be973b11b4c7b7dbb2d32a8f3a9ab9 (diff) | |
download | taxi-5f42c01231ccec377196472b6f4682b6afeb878d.tar.gz taxi-5f42c01231ccec377196472b6f4682b6afeb878d.zip |
Add model with predefined target classes
Diffstat (limited to 'train.py')
-rw-r--r-- | train.py | 7 |
1 files changed, 4 insertions, 3 deletions
@@ -38,7 +38,8 @@ if __name__ == "__main__": if len(sys.argv) != 2: print >> sys.stderr, 'Usage: %s config' % sys.argv[0] sys.exit(1) - config = importlib.import_module(sys.argv[1]) + model_name = sys.argv[1] + config = importlib.import_module(model_name) def setup_train_stream(): @@ -107,8 +108,8 @@ def main(): every_n_batches=1000), Printing(every_n_batches=1000), # Checkpoint('model.pkl', every_n_batches=100), - Dump('taxi_model', every_n_batches=1000), - LoadFromDump('taxi_model'), + Dump('model_data/' + model_name, every_n_batches=1000), + LoadFromDump('model_data/' + model_name), FinishAfter(after_epoch=5) ] |