diff options
author | Alex Auvolat <alex@adnab.me> | 2015-06-17 14:58:38 -0400 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2015-06-17 14:58:52 -0400 |
commit | 12304944033d20bbc5c1b3f5cb90cf8dedebcdff (patch) | |
tree | 1097585f948d7040416eef58344b4bd194f10b9f /paramsaveload.py | |
parent | 701c407b8c87a9270b31d34ac54e683341be661e (diff) | |
download | text-rnn-12304944033d20bbc5c1b3f5cb90cf8dedebcdff.tar.gz text-rnn-12304944033d20bbc5c1b3f5cb90cf8dedebcdff.zip |
paramsaveload
Diffstat (limited to 'paramsaveload.py')
-rw-r--r-- | paramsaveload.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/paramsaveload.py b/paramsaveload.py new file mode 100644 index 0000000..7181e9a --- /dev/null +++ b/paramsaveload.py @@ -0,0 +1,37 @@ +import logging + +import numpy + +import cPickle + +from blocks.extensions import SimpleExtension + +logging.basicConfig(level='INFO') +logger = logging.getLogger('extensions.SaveLoadParams') + +class SaveLoadParams(SimpleExtension): + def __init__(self, path, model, **kwargs): + super(SaveLoadParams, self).__init__(**kwargs) + + self.path = path + self.model = model + + def do_save(self): + with open(self.path, 'w') as f: + logger.info('Saving parameters to %s...'%self.path) + cPickle.dump(model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) + + def do_load(self): + try: + with open(self.path, 'r') as f: + logger.info('Loading parameters from %s...'%self.path) + model.set_parma_values(cPickle.load(f)) + except IOError: + pass + + def do(self, which_callback, *args): + if which_callback == 'before_training': + self.do_load() + else: + self.do_save() + |