diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-19 15:59:09 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-19 15:59:09 -0400 |
commit | 7bf692d9ae344ccef044923f131f5ce8de85b0b4 (patch) | |
tree | 227af7599d2ef24bbc4bd6c4de5be384969fa139 /train.py | |
parent | de89d218940295f834523dfcfd6840965a63dda5 (diff) | |
download | text-rnn-7bf692d9ae344ccef044923f131f5ce8de85b0b4.tar.gz text-rnn-7bf692d9ae344ccef044923f131f5ce8de85b0b4.zip |
Something that does not really work
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 17 |
1 file changed, 15 insertions(+), 2 deletions(-)
@@ -19,7 +19,7 @@
 from blocks.extensions.saveload import Checkpoint, Load
 from blocks.graph import ComputationGraph
 from blocks.main_loop import MainLoop
 from blocks.model import Model
-from blocks.algorithms import GradientDescent
+from blocks.algorithms import GradientDescent, StepRule, CompositeRule
 import datastream
 from paramsaveload import SaveLoadParams
@@ -38,6 +38,17 @@ if __name__ == "__main__":
     model_name = sys.argv[1]
     config = importlib.import_module('%s' % model_name)
 
+
+class ElementwiseRemoveNotFinite(StepRule):
+    def __init__(self, scaler=0.1):
+        self.scaler = scaler
+
+    def compute_step(self, param, previous_step):
+        not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
+        step = tensor.switch(not_finite, self.scaler * param, previous_step)
+
+        return step, []
+
 class ResetStates(SimpleExtension):
     def __init__(self, state_vars, **kwargs):
         super(ResetStates, self).__init__(**kwargs)
@@ -56,7 +67,9 @@ def train_model(m, train_stream, dump_path=None):
     cg = ComputationGraph(m.cost_reg)
 
     algorithm = GradientDescent(cost=m.cost_reg,
-                                step_rule=config.step_rule,
+                                step_rule=CompositeRule([
+                                    ElementwiseRemoveNotFinite(),
+                                    config.step_rule]),
                                 params=cg.parameters)
     algorithm.add_updates(m.states)
 