From 7bf692d9ae344ccef044923f131f5ce8de85b0b4 Mon Sep 17 00:00:00 2001
From: Alex Auvolat
Date: Fri, 19 Jun 2015 15:59:09 -0400
Subject: Something that does not really work

---
 train.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

(limited to 'train.py')

diff --git a/train.py b/train.py
index 3ac24e7..61f6663 100755
--- a/train.py
+++ b/train.py
@@ -19,7 +19,7 @@ from blocks.extensions.saveload import Checkpoint, Load
 from blocks.graph import ComputationGraph
 from blocks.main_loop import MainLoop
 from blocks.model import Model
-from blocks.algorithms import GradientDescent
+from blocks.algorithms import GradientDescent, StepRule, CompositeRule
 
 import datastream
 from paramsaveload import SaveLoadParams
@@ -38,6 +38,17 @@ if __name__ == "__main__":
     model_name = sys.argv[1]
     config = importlib.import_module('%s' % model_name)
 
+
+class ElementwiseRemoveNotFinite(StepRule):
+    def __init__(self, scaler=0.1):
+        self.scaler = scaler
+
+    def compute_step(self, param, previous_step):
+        not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
+        step = tensor.switch(not_finite, self.scaler * param, previous_step)
+
+        return step, []
+
 class ResetStates(SimpleExtension):
     def __init__(self, state_vars, **kwargs):
         super(ResetStates, self).__init__(**kwargs)
@@ -56,7 +67,9 @@ def train_model(m, train_stream, dump_path=None):
 
     cg = ComputationGraph(m.cost_reg)
     algorithm = GradientDescent(cost=m.cost_reg,
-                                step_rule=config.step_rule,
+                                step_rule=CompositeRule([
+                                    ElementwiseRemoveNotFinite(),
+                                    config.step_rule]),
                                 params=cg.parameters)
 
     algorithm.add_updates(m.states)
-- 
cgit v1.2.3