aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-22 15:51:02 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-22 15:51:47 -0400
commit448e848796757ad9f0a2f681886f868b8f22e81f (patch)
tree89125e3e56de2147f2477732fdc52ea19f2470e3
parentc2c88c48a0404de0eb834df71fa53ae63fdfd1c7 (diff)
downloadtaxi-448e848796757ad9f0a2f681886f868b8f22e81f.tar.gz
taxi-448e848796757ad9f0a2f681886f868b8f22e81f.zip
Add ElementwiseRemoveNotFinite step rule.
-rwxr-xr-xtrain.py37
1 files changed, 35 insertions, 2 deletions
diff --git a/train.py b/train.py
index 17fa612..9f636c0 100755
--- a/train.py
+++ b/train.py
@@ -7,8 +7,10 @@ import os
import sys
from functools import reduce
+from theano import tensor
+
from blocks import roles
-from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite
+from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule
from blocks.extensions import Printing, FinishAfter
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
from blocks.extensions.plot import Plot
@@ -21,6 +23,37 @@ from blocks.model import Model
logger = logging.getLogger(__name__)
+
+class ElementwiseRemoveNotFinite(StepRule):
+ """A step rule that replaces non-finite coefficients by zeros.
+
+ Replaces non-finite elements (such as ``inf`` or ``NaN``) in a step
+ (the parameter update of a single shared variable)
+ with a scaled version of the parameters being updated instead.
+
+ Parameters
+ ----------
+ scaler : float, optional
+ The scaling applied to the parameter in case the step contains
+ non-finite elements. Defaults to 0.1.
+
+ Notes
+ -----
+ This trick was originally used in the GroundHog_ framework.
+
+ .. _GroundHog: https://github.com/lisa-groundhog/GroundHog
+
+ """
+ def __init__(self, scaler=0.1):
+ self.scaler = scaler
+
+ def compute_step(self, param, previous_step):
+ not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
+ step = tensor.switch(not_finite, self.scaler * param, previous_step)
+
+ return step, []
+
+
if __name__ == "__main__":
if len(sys.argv) != 2:
print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
@@ -66,7 +99,7 @@ if __name__ == "__main__":
algorithm = GradientDescent(
cost=cost,
step_rule=CompositeRule([
- RemoveNotFinite(),
+ ElementwiseRemoveNotFinite(),
AdaDelta(),
#Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
]),