aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/train.py b/train.py
index 17d5789..0d40f84 100755
--- a/train.py
+++ b/train.py
@@ -70,6 +70,7 @@ class SaveLoadParams(SimpleExtension):
with open(self.path, 'w') as f:
logger.info('Saving parameters to %s...'%self.path)
cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+ logger.info('Done saving.')
def do_load(self):
try:
@@ -153,8 +154,10 @@ if __name__ == "__main__":
Printing(every_n_batches=1000),
SaveLoadParams(dump_path, cg,
- before_training=config.load_model, # before training -> load params
+ before_training=True, # before training -> load params
every_n_batches=1000, # every N batches -> save params
+ after_epoch=True, # after epoch -> save params
+ after_training=True, # after training -> save params
),
]