aboutsummaryrefslogtreecommitdiff
path: root/ext_param_info.py
diff options
context:
space:
mode:
Diffstat (limited to 'ext_param_info.py')
-rw-r--r--ext_param_info.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/ext_param_info.py b/ext_param_info.py
new file mode 100644
index 0000000..a60f9d7
--- /dev/null
+++ b/ext_param_info.py
@@ -0,0 +1,24 @@
+import logging
+
+import numpy
+
+import cPickle
+
+from blocks.extensions import SimpleExtension
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger('extensions.ParamInfo')
+
+class ParamInfo(SimpleExtension):
+ def __init__(self, model, **kwargs):
+ super(ParamInfo, self).__init__(**kwargs)
+
+ self.model = model
+
+ def do(self, which_callback, *args):
+ print("---- PARAMETER INFO ----")
+ print("\tmin\tmax\tmean\tvar\tdim\t\tname")
+ for k, v in self.model.get_parameter_values().iteritems():
+ print("\t%.4f\t%.4f\t%.4f\t%.4f\t%13s\t%s"%
+ (v.min(), v.max(), v.mean(), ((v-v.mean())**2).mean(), 'x'.join([repr(x) for x in v.shape]), k))
+