blob: 345cfbb1e42b9fe5f2331786a8ba2f856a20d5ba (
plain) (
tree)
|
|
import logging
import numpy
import cPickle
from blocks.extensions import SimpleExtension
# Configure root logging at import time so the INFO-level save/load notices
# below are printed without extra setup (basicConfig is a no-op if the root
# logger already has handlers).
logging.basicConfig(level=logging.INFO)
# Logger shared by the SaveLoadParams extension.
logger = logging.getLogger('extensions.SaveLoadParams')
class SaveLoadParams(SimpleExtension):
    """Extension that checkpoints a model's parameter values to a pickle file.

    On the ``'before_training'`` callback it tries to restore previously
    saved parameter values from ``path``. On every other callback it fires
    for, it pickles the model's current parameter values to ``path``.

    Parameters
    ----------
    path : str
        File the parameter values are pickled to and unpickled from.
    model
        Object exposing ``get_parameter_values()`` and
        ``set_parameter_values(values)`` -- presumably a ``blocks`` Model;
        confirm against callers.
    **kwargs
        Forwarded unchanged to ``SimpleExtension.__init__``; presumably the
        callback-selection flags (e.g. ``before_training=True``,
        ``after_epoch=True``) that decide when ``do`` is invoked.
    """

    def __init__(self, path, model, **kwargs):
        super(SaveLoadParams, self).__init__(**kwargs)
        self.path = path
        self.model = model

    def do_save(self):
        """Pickle the model's current parameter values to ``self.path``."""
        # Binary mode is required: HIGHEST_PROTOCOL produces a binary pickle,
        # which text mode would corrupt on newline-translating platforms.
        with open(self.path, 'wb') as f:
            logger.info('Saving parameters to %s...', self.path)
            cPickle.dump(self.model.get_parameter_values(), f,
                         protocol=cPickle.HIGHEST_PROTOCOL)

    def do_load(self):
        """Restore parameter values from ``self.path`` if it can be opened.

        If the file cannot be opened (e.g. no checkpoint exists yet), a
        warning is logged and the model keeps its current parameters.
        Errors raised while reading or applying an existing checkpoint are
        NOT swallowed, so a damaged checkpoint is not mistaken for a missing
        one.

        NOTE(review): ``cPickle.load`` can execute arbitrary code; only load
        checkpoint files from trusted sources.
        """
        # Only the open() is guarded: that is the "no previous checkpoint"
        # case this method deliberately tolerates.
        try:
            f = open(self.path, 'rb')
        except IOError:
            logger.warning('No previous parameters found at %s...', self.path)
            return
        with f:
            logger.info('Loading parameters from %s...', self.path)
            self.model.set_parameter_values(cPickle.load(f))

    def do(self, which_callback, *args):
        """Load parameters before training; save them on any other trigger."""
        if which_callback == 'before_training':
            self.do_load()
        else:
            self.do_save()
|