diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 12:49:03 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 12:49:18 -0400 |
commit | a4b190516d00428b1d8a81686a3291e5fa5f9865 (patch) | |
tree | 230f04cbb664d4f7138ca4f22839e6bf501b32be /ext_saveload.py | |
parent | 859bee7196c78e9828d9182b5fea2ad2eab01f1d (diff) | |
download | taxi-a4b190516d00428b1d8a81686a3291e5fa5f9865.tar.gz taxi-a4b190516d00428b1d8a81686a3291e5fa5f9865.zip |
Make the testing into an extension run at each validation
Diffstat (limited to 'ext_saveload.py')
-rw-r--r-- | ext_saveload.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/ext_saveload.py b/ext_saveload.py new file mode 100644 index 0000000..cc7c47a --- /dev/null +++ b/ext_saveload.py @@ -0,0 +1,33 @@ +import cPickle +import logging + +from blocks.extensions import SimpleExtension + +logger = logging.getLogger(__name__) + +class SaveLoadParams(SimpleExtension): + def __init__(self, path, model, **kwargs): + super(SaveLoadParams, self).__init__(**kwargs) + + self.path = path + self.model = model + + def do_save(self): + with open(self.path, 'w') as f: + logger.info('Saving parameters to %s...'%self.path) + cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) + logger.info('Done saving.') + + def do_load(self): + try: + with open(self.path, 'r') as f: + logger.info('Loading parameters from %s...'%self.path) + self.model.set_param_values(cPickle.load(f)) + except IOError: + pass + + def do(self, which_callback, *args): + if which_callback == 'before_training': + self.do_load() + else: + self.do_save() |