diff options
Diffstat (limited to 'train.py')
-rw-r--r-- | train.py | 7 |
1 files changed, 4 insertions, 3 deletions
@@ -38,7 +38,8 @@ if __name__ == "__main__": if len(sys.argv) != 2: print >> sys.stderr, 'Usage: %s config' % sys.argv[0] sys.exit(1) - config = importlib.import_module(sys.argv[1]) + model_name = sys.argv[1] + config = importlib.import_module(model_name) def setup_train_stream(): @@ -107,8 +108,8 @@ def main(): every_n_batches=1000), Printing(every_n_batches=1000), # Checkpoint('model.pkl', every_n_batches=100), - Dump('taxi_model', every_n_batches=1000), - LoadFromDump('taxi_model'), + Dump('model_data/' + model_name, every_n_batches=1000), + LoadFromDump('model_data/' + model_name), FinishAfter(after_epoch=5) ] |