summaryrefslogblamecommitdiff
path: root/paramsaveload.py
blob: 345cfbb1e42b9fe5f2331786a8ba2f856a20d5ba (plain) (tree)




















                                                                           
                                                                                                             




                                                                                      
                                                                                
                               
                                                                                      






                                                       
import logging

import numpy

import cPickle

from blocks.extensions import SimpleExtension

# NOTE: basicConfig is called at import time, so importing this module
# configures the root logger (level INFO) if it has no handlers yet.
logging.basicConfig(level='INFO')
# Module-level logger used by SaveLoadParams for its save/load messages.
logger = logging.getLogger('extensions.SaveLoadParams')

class SaveLoadParams(SimpleExtension):
	"""Extension that pickles a model's parameter values to/from a file.

	On the 'before_training' callback the parameters stored at ``path``
	(if any) are loaded into ``model``; on every other callback this
	extension is triggered for, the current parameter values are written
	back to ``path``.

	Parameters
	----------
	path : str
		Location of the pickle file holding the parameter values.
	model : object
		Object exposing ``get_parameter_values()`` and
		``set_parameter_values(values)``.
	**kwargs
		Forwarded unchanged to ``SimpleExtension.__init__``.
	"""

	def __init__(self, path, model, **kwargs):
		super(SaveLoadParams, self).__init__(**kwargs)

		self.path = path
		self.model = model

	def do_save(self):
		"""Dump the model's current parameter values to ``self.path``."""
		# Binary mode is required: HIGHEST_PROTOCOL is a binary pickle format,
		# and text mode can corrupt it through newline translation.
		with open(self.path, 'wb') as f:
			logger.info('Saving parameters to %s...', self.path)
			cPickle.dump(self.model.get_parameter_values(), f,
			             protocol=cPickle.HIGHEST_PROTOCOL)

	def do_load(self):
		"""Load parameter values from ``self.path`` into the model.

		An IOError while opening or reading the file (typically: the file
		does not exist yet) is logged as a warning and otherwise ignored.
		"""
		try:
			# Must match the binary mode used by do_save.
			with open(self.path, 'rb') as f:
				logger.info('Loading parameters from %s...', self.path)
				# SECURITY: unpickling executes arbitrary code; only point
				# ``path`` at files written by a trusted source.
				self.model.set_parameter_values(cPickle.load(f))
		except IOError:
			logger.warning('No previous parameters found at %s...', self.path)

	def do(self, which_callback, *args):
		"""Load before training starts; save on any other callback.

		Parameters
		----------
		which_callback : str
			Name of the callback that triggered this extension.
		*args
			Extra callback arguments; unused.
		"""
		if which_callback == 'before_training':
			self.do_load()
		else:
			self.do_save()