aboutsummaryrefslogblamecommitdiff
path: root/ext_param_info.py
blob: a60f9d74a5056b4f55328284d1ad0e688b7e7498 (plain) (tree)























                                                                                                                                        
import logging

import numpy

import cPickle

from blocks.extensions import SimpleExtension

# Configure the root logger at import time so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
# Named logger for this module.
logger = logging.getLogger(name='extensions.ParamInfo')

class ParamInfo(SimpleExtension):
	"""Print summary statistics for every parameter of a model.

	Each time the extension fires, a header and then one line per
	parameter are printed to stdout. Each line holds the minimum, maximum,
	mean, population variance, shape (dims joined with 'x') and name of
	the parameter's current value.

	Parameters
	----------
	model : object
		Must expose ``get_parameter_values()`` returning a mapping from
		parameter name to an array-like value (presumably a ``blocks``
		``Model`` -- confirm against callers).
	**kwargs
		Forwarded unchanged to ``SimpleExtension.__init__``.
	"""

	def __init__(self, model, **kwargs):
		super(ParamInfo, self).__init__(**kwargs)

		self.model = model

	@staticmethod
	def _format_row(name, value):
		"""Return one tab-separated statistics line for a parameter value."""
		value = numpy.asarray(value)
		# str() rather than repr(): on Python 2, shape entries can be
		# ``long`` (e.g. on Windows), whose repr carries an ``L`` suffix.
		dims = 'x'.join(str(d) for d in value.shape)
		if value.size == 0:
			# min()/max() raise ValueError on zero-size arrays; report NaN
			# rather than raising out of the callback.
			stats = (float('nan'),) * 4
		else:
			# var() is the population variance, i.e. mean((v - mean(v))**2),
			# computed without re-evaluating the mean several times.
			stats = (value.min(), value.max(), value.mean(), value.var())
		return "\t%.4f\t%.4f\t%.4f\t%.4f\t%13s\t%s" % (stats + (dims, name))

	def do(self, which_callback, *args):
		"""Print the parameter statistics table.

		``which_callback`` and ``args`` are accepted to match the
		``SimpleExtension.do`` signature; they are not used here.
		"""
		print("---- PARAMETER INFO ----")
		print("\tmin\tmax\tmean\tvar\tdim\t\tname")
		# items() (not iteritems()) works on both Python 2 and Python 3.
		for name, value in self.model.get_parameter_values().items():
			print(self._format_row(name, value))