diff options
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 17 |
1 files changed, 15 insertions, 2 deletions
@@ -19,7 +19,7 @@ from blocks.extensions.saveload import Checkpoint, Load from blocks.graph import ComputationGraph from blocks.main_loop import MainLoop from blocks.model import Model -from blocks.algorithms import GradientDescent +from blocks.algorithms import GradientDescent, StepRule, CompositeRule import datastream from paramsaveload import SaveLoadParams @@ -38,6 +38,17 @@ if __name__ == "__main__": model_name = sys.argv[1] config = importlib.import_module('%s' % model_name) + +class ElementwiseRemoveNotFinite(StepRule): + def __init__(self, scaler=0.1): + self.scaler = scaler + + def compute_step(self, param, previous_step): + not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step) + step = tensor.switch(not_finite, self.scaler * param, previous_step) + + return step, [] + class ResetStates(SimpleExtension): def __init__(self, state_vars, **kwargs): super(ResetStates, self).__init__(**kwargs) @@ -56,7 +67,9 @@ def train_model(m, train_stream, dump_path=None): cg = ComputationGraph(m.cost_reg) algorithm = GradientDescent(cost=m.cost_reg, - step_rule=config.step_rule, + step_rule=CompositeRule([ + ElementwiseRemoveNotFinite(), + config.step_rule]), params=cg.parameters) algorithm.add_updates(m.states) |