aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/train.py b/train.py
index 1b39671..dcf3fcd 100644
--- a/train.py
+++ b/train.py
@@ -38,7 +38,8 @@ if __name__ == "__main__":
if len(sys.argv) != 2:
print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
sys.exit(1)
- config = importlib.import_module(sys.argv[1])
+ model_name = sys.argv[1]
+ config = importlib.import_module(model_name)
def setup_train_stream():
@@ -107,8 +108,8 @@ def main():
every_n_batches=1000),
Printing(every_n_batches=1000),
# Checkpoint('model.pkl', every_n_batches=100),
- Dump('taxi_model', every_n_batches=1000),
- LoadFromDump('taxi_model'),
+ Dump('model_data/' + model_name, every_n_batches=1000),
+ LoadFromDump('model_data/' + model_name),
FinishAfter(after_epoch=5)
]