summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py17
1 files changed, 15 insertions, 2 deletions
diff --git a/train.py b/train.py
index 3ac24e7..61f6663 100755
--- a/train.py
+++ b/train.py
@@ -19,7 +19,7 @@ from blocks.extensions.saveload import Checkpoint, Load
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
-from blocks.algorithms import GradientDescent
+from blocks.algorithms import GradientDescent, StepRule, CompositeRule
import datastream
from paramsaveload import SaveLoadParams
@@ -38,6 +38,17 @@ if __name__ == "__main__":
model_name = sys.argv[1]
config = importlib.import_module('%s' % model_name)
+
+class ElementwiseRemoveNotFinite(StepRule):
+ def __init__(self, scaler=0.1):
+ self.scaler = scaler
+
+ def compute_step(self, param, previous_step):
+ not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
+ step = tensor.switch(not_finite, self.scaler * param, previous_step)
+
+ return step, []
+
class ResetStates(SimpleExtension):
def __init__(self, state_vars, **kwargs):
super(ResetStates, self).__init__(**kwargs)
@@ -56,7 +67,9 @@ def train_model(m, train_stream, dump_path=None):
cg = ComputationGraph(m.cost_reg)
algorithm = GradientDescent(cost=m.cost_reg,
- step_rule=config.step_rule,
+ step_rule=CompositeRule([
+ ElementwiseRemoveNotFinite(),
+ config.step_rule]),
params=cg.parameters)
algorithm.add_updates(m.states)