author    | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-22 15:51:02 -0400
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-22 15:51:47 -0400
commit    | 448e848796757ad9f0a2f681886f868b8f22e81f
tree      | 89125e3e56de2147f2477732fdc52ea19f2470e3
parent    | c2c88c48a0404de0eb834df71fa53ae63fdfd1c7
Add ElementwiseRemoveNotFinite step rule.
-rwxr-xr-x | train.py | 37
1 file changed, 35 insertions, 2 deletions
@@ -7,8 +7,10 @@ import os
 import sys
 from functools import reduce
 
+from theano import tensor
+
 from blocks import roles
-from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite
+from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule
 from blocks.extensions import Printing, FinishAfter
 from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
 from blocks.extensions.plot import Plot
@@ -21,6 +23,37 @@ from blocks.model import Model
 
 logger = logging.getLogger(__name__)
 
+
+class ElementwiseRemoveNotFinite(StepRule):
+    """A step rule that replaces non-finite coefficients by scaled parameters.
+
+    Replaces non-finite elements (such as ``inf`` or ``NaN``) in a step
+    (the parameter update of a single shared variable) with a scaled
+    version of the parameter being updated instead.
+
+    Parameters
+    ----------
+    scaler : float, optional
+        The scaling applied to the parameter in case the step contains
+        non-finite elements. Defaults to 0.1.
+
+    Notes
+    -----
+    This trick was originally used in the GroundHog_ framework.
+
+    .. _GroundHog: https://github.com/lisa-groundhog/GroundHog
+
+    """
+    def __init__(self, scaler=0.1):
+        self.scaler = scaler
+
+    def compute_step(self, param, previous_step):
+        not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
+        step = tensor.switch(not_finite, self.scaler * param, previous_step)
+
+        return step, []
+
+
 if __name__ == "__main__":
     if len(sys.argv) != 2:
         print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
@@ -66,7 +99,7 @@ if __name__ == "__main__":
     algorithm = GradientDescent(
         cost=cost,
         step_rule=CompositeRule([
-            RemoveNotFinite(),
+            ElementwiseRemoveNotFinite(),
             AdaDelta(),
             #Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
         ]),
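For readers skimming the diff, here is a minimal standalone sketch (not part of the commit) of the element-wise replacement the new step rule performs: wherever the proposed update contains ``NaN`` or ``Inf``, it falls back to ``scaler * param`` for those entries only, rather than discarding the whole step. The variable names and test values below are illustrative assumptions, and it presumes a working Theano install.

```python
import numpy as np
import theano
from theano import function, tensor

# Symbolic inputs: `param` is the current parameter value and
# `previous_step` is the update proposed by the preceding step rules.
param = tensor.vector('param')
previous_step = tensor.vector('previous_step')
scaler = 0.1  # same default as ElementwiseRemoveNotFinite

# Non-zero exactly where the proposed step is NaN or infinite.
not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
# Keep finite entries of the step; elsewhere substitute scaler * param.
step = tensor.switch(not_finite, scaler * param, previous_step)

f = function([param, previous_step], step)
floatX = theano.config.floatX
print(f(np.array([1.0, 2.0, 3.0], dtype=floatX),
        np.array([0.5, np.nan, np.inf], dtype=floatX)))
# -> [ 0.5  0.2  0.3 ]
```

In Blocks' ``StepRule`` API, ``compute_step`` returns a ``(step, updates)`` pair; the empty list returned by the class above simply means the rule needs no additional shared-variable updates.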