aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-04 17:13:08 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-04 17:13:08 -0400
commit5f42c01231ccec377196472b6f4682b6afeb878d (patch)
tree8e0212399a951a57738574084234e7c75b4fe590 /train.py
parentc912ef9424be973b11b4c7b7dbb2d32a8f3a9ab9 (diff)
downloadtaxi-5f42c01231ccec377196472b6f4682b6afeb878d.tar.gz
taxi-5f42c01231ccec377196472b6f4682b6afeb878d.zip
Add model with predefined target classes
Diffstat (limited to 'train.py')
-rw-r--r--train.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/train.py b/train.py
index 1b39671..dcf3fcd 100644
--- a/train.py
+++ b/train.py
@@ -38,7 +38,8 @@ if __name__ == "__main__":
if len(sys.argv) != 2:
print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
sys.exit(1)
- config = importlib.import_module(sys.argv[1])
+ model_name = sys.argv[1]
+ config = importlib.import_module(model_name)
def setup_train_stream():
@@ -107,8 +108,8 @@ def main():
every_n_batches=1000),
Printing(every_n_batches=1000),
# Checkpoint('model.pkl', every_n_batches=100),
- Dump('taxi_model', every_n_batches=1000),
- LoadFromDump('taxi_model'),
+ Dump('model_data/' + model_name, every_n_batches=1000),
+ LoadFromDump('model_data/' + model_name),
FinishAfter(after_epoch=5)
]