aboutsummaryrefslogtreecommitdiff
path: root/ext_saveload.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-02 12:49:03 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-02 12:49:18 -0400
commita4b190516d00428b1d8a81686a3291e5fa5f9865 (patch)
tree230f04cbb664d4f7138ca4f22839e6bf501b32be /ext_saveload.py
parent859bee7196c78e9828d9182b5fea2ad2eab01f1d (diff)
downloadtaxi-a4b190516d00428b1d8a81686a3291e5fa5f9865.tar.gz
taxi-a4b190516d00428b1d8a81686a3291e5fa5f9865.zip
Make the testing into an extension run at each validation
Diffstat (limited to 'ext_saveload.py')
-rw-r--r--ext_saveload.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/ext_saveload.py b/ext_saveload.py
new file mode 100644
index 0000000..cc7c47a
--- /dev/null
+++ b/ext_saveload.py
@@ -0,0 +1,33 @@
+import cPickle
+import logging
+
+from blocks.extensions import SimpleExtension
+
+logger = logging.getLogger(__name__)
+
+class SaveLoadParams(SimpleExtension):
+ def __init__(self, path, model, **kwargs):
+ super(SaveLoadParams, self).__init__(**kwargs)
+
+ self.path = path
+ self.model = model
+
+ def do_save(self):
+ with open(self.path, 'w') as f:
+ logger.info('Saving parameters to %s...'%self.path)
+ cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+ logger.info('Done saving.')
+
+ def do_load(self):
+ try:
+ with open(self.path, 'r') as f:
+ logger.info('Loading parameters from %s...'%self.path)
+ self.model.set_param_values(cPickle.load(f))
+ except IOError:
+ pass
+
+ def do(self, which_callback, *args):
+ if which_callback == 'before_training':
+ self.do_load()
+ else:
+ self.do_save()