From 448e848796757ad9f0a2f681886f868b8f22e81f Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 May 2015 15:51:02 -0400 Subject: Add ElementwiseRemoveNotFinite step rule. --- train.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 17fa612..9f636c0 100755 --- a/train.py +++ b/train.py @@ -7,8 +7,10 @@ import os import sys from functools import reduce +from theano import tensor + from blocks import roles -from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite +from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule from blocks.extensions import Printing, FinishAfter from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring from blocks.extensions.plot import Plot @@ -21,6 +23,37 @@ from blocks.model import Model logger = logging.getLogger(__name__) + +class ElementwiseRemoveNotFinite(StepRule): + """A step rule that replaces non-finite coefficients by zeros. + + Replaces non-finite elements (such as ``inf`` or ``NaN``) in a step + (the parameter update of a single shared variable) + with a scaled version of the parameters being updated instead. + + Parameters + ---------- + scaler : float, optional + The scaling applied to the parameter in case the step contains + non-finite elements. Defaults to 0.1. + + Notes + ----- + This trick was originally used in the GroundHog_ framework. + + .. _GroundHog: https://github.com/lisa-groundhog/GroundHog + + """ + def __init__(self, scaler=0.1): + self.scaler = scaler + + def compute_step(self, param, previous_step): + not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step) + step = tensor.switch(not_finite, self.scaler * param, previous_step) + + return step, [] + + if __name__ == "__main__": if len(sys.argv) != 2: print >> sys.stderr, 'Usage: %s config' % sys.argv[0] @@ -66,7 +99,7 @@ if __name__ == "__main__": algorithm = GradientDescent( cost=cost, step_rule=CompositeRule([ - RemoveNotFinite(), + ElementwiseRemoveNotFinite(), AdaDelta(), #Momentum(learning_rate=config.learning_rate, momentum=config.momentum), ]), -- cgit v1.2.3