summaryrefslogtreecommitdiff
path: root/paramsaveload.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2015-06-17 14:58:38 -0400
committerAlex Auvolat <alex@adnab.me>2015-06-17 14:58:52 -0400
commit12304944033d20bbc5c1b3f5cb90cf8dedebcdff (patch)
tree1097585f948d7040416eef58344b4bd194f10b9f /paramsaveload.py
parent701c407b8c87a9270b31d34ac54e683341be661e (diff)
downloadtext-rnn-12304944033d20bbc5c1b3f5cb90cf8dedebcdff.tar.gz
text-rnn-12304944033d20bbc5c1b3f5cb90cf8dedebcdff.zip
paramsaveload
Diffstat (limited to 'paramsaveload.py')
-rw-r--r--paramsaveload.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/paramsaveload.py b/paramsaveload.py
new file mode 100644
index 0000000..7181e9a
--- /dev/null
+++ b/paramsaveload.py
@@ -0,0 +1,37 @@
+import logging
+
+import numpy
+
+import cPickle
+
+from blocks.extensions import SimpleExtension
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger('extensions.SaveLoadParams')
+
+class SaveLoadParams(SimpleExtension):
+ def __init__(self, path, model, **kwargs):
+ super(SaveLoadParams, self).__init__(**kwargs)
+
+ self.path = path
+ self.model = model
+
+ def do_save(self):
+ with open(self.path, 'w') as f:
+ logger.info('Saving parameters to %s...'%self.path)
+ cPickle.dump(model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+
+ def do_load(self):
+ try:
+ with open(self.path, 'r') as f:
+ logger.info('Loading parameters from %s...'%self.path)
+ model.set_parma_values(cPickle.load(f))
+ except IOError:
+ pass
+
+ def do(self, which_callback, *args):
+ if which_callback == 'before_training':
+ self.do_load()
+ else:
+ self.do_save()
+