aboutsummaryrefslogtreecommitdiff
path: root/ext_param_info.py
blob: a60f9d74a5056b4f55328284d1ad0e688b7e7498 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import logging

import numpy

import cPickle

from blocks.extensions import SimpleExtension

# Give the root logger an INFO threshold as soon as this module is imported.
# NOTE(review): configuring logging at import time is a global side effect;
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
# Named logger for this module (not referenced by the code visible here).
logger = logging.getLogger('extensions.ParamInfo')

class ParamInfo(SimpleExtension):
    """Print summary statistics for every parameter of a model.

    Each time the extension runs, it prints a header followed by one row per
    parameter to stdout. Each row holds the parameter's minimum, maximum,
    mean and population variance, its shape rendered as ``AxBx...``, and its
    name.

    Parameters
    ----------
    model
        Object exposing ``get_parameter_values()``, which must return a
        mapping from parameter name to an array-like value (presumably a
        ``blocks.model.Model`` -- confirm against callers).
    **kwargs
        Forwarded unchanged to ``SimpleExtension.__init__`` (presumably the
        callback-trigger options; see the Blocks documentation).
    """

    def __init__(self, model, **kwargs):
        super(ParamInfo, self).__init__(**kwargs)
        self.model = model

    @staticmethod
    def _summarize(value):
        """Return ``(min, max, mean, var, dims)`` for one parameter value.

        ``dims`` is the shape joined with ``'x'`` (empty for a 0-d value).
        A zero-size array has no min/max, so its four statistics are
        reported as NaN instead of letting numpy raise ``ValueError``.
        """
        arr = numpy.asarray(value)
        # str() renders plain digits whatever the integer type of each entry.
        dims = 'x'.join(str(d) for d in arr.shape)
        if arr.size == 0:
            nan = float('nan')
            return nan, nan, nan, nan, dims
        # ndarray.var() is the population variance, i.e. the same quantity
        # as ((arr - arr.mean()) ** 2).mean().
        return arr.min(), arr.max(), arr.mean(), arr.var(), dims

    def do(self, which_callback, *args):
        """Print the parameter table; the callback arguments are unused."""
        print("---- PARAMETER INFO ----")
        print("\tmin\tmax\tmean\tvar\tdim\t\tname")
        # items() instead of the Python 2-only iteritems() so this loop also
        # works with Python 3 dicts.
        for name, value in self.model.get_parameter_values().items():
            print("\t%.4f\t%.4f\t%.4f\t%.4f\t%13s\t%s"
                  % (self._summarize(value) + (name,)))