path: root/train.py
author    Alex Auvolat <alex.auvolat@ens.fr>  2015-06-19 15:59:09 -0400
committer Alex Auvolat <alex.auvolat@ens.fr>  2015-06-19 15:59:09 -0400
commit    7bf692d9ae344ccef044923f131f5ce8de85b0b4 (patch)
tree      227af7599d2ef24bbc4bd6c4de5be384969fa139 /train.py
parent    de89d218940295f834523dfcfd6840965a63dda5 (diff)
download  text-rnn-7bf692d9ae344ccef044923f131f5ce8de85b0b4.tar.gz
          text-rnn-7bf692d9ae344ccef044923f131f5ce8de85b0b4.zip
Something that does not really work
Diffstat (limited to 'train.py')
-rwxr-xr-x  train.py  17
1 file changed, 15 insertions(+), 2 deletions(-)
diff --git a/train.py b/train.py
index 3ac24e7..61f6663 100755
--- a/train.py
+++ b/train.py
@@ -19,7 +19,7 @@ from blocks.extensions.saveload import Checkpoint, Load
 from blocks.graph import ComputationGraph
 from blocks.main_loop import MainLoop
 from blocks.model import Model
-from blocks.algorithms import GradientDescent
+from blocks.algorithms import GradientDescent, StepRule, CompositeRule
 
 import datastream
 from paramsaveload import SaveLoadParams
@@ -38,6 +38,17 @@ if __name__ == "__main__":
     model_name = sys.argv[1]
     config = importlib.import_module('%s' % model_name)
 
+
+class ElementwiseRemoveNotFinite(StepRule):
+    def __init__(self, scaler=0.1):
+        self.scaler = scaler
+
+    def compute_step(self, param, previous_step):
+        not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
+        step = tensor.switch(not_finite, self.scaler * param, previous_step)
+
+        return step, []
+
 class ResetStates(SimpleExtension):
     def __init__(self, state_vars, **kwargs):
         super(ResetStates, self).__init__(**kwargs)
@@ -56,7 +67,9 @@ def train_model(m, train_stream, dump_path=None):
 
     cg = ComputationGraph(m.cost_reg)
     algorithm = GradientDescent(cost=m.cost_reg,
-                                step_rule=config.step_rule,
+                                step_rule=CompositeRule([
+                                    ElementwiseRemoveNotFinite(),
+                                    config.step_rule]),
                                 params=cg.parameters)
     algorithm.add_updates(m.states)
 
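The new ElementwiseRemoveNotFinite rule works element by element: wherever the
proposed step contains NaN or Inf, it substitutes scaler * param for that
entry, so one bad gradient component only decays its own parameter entry
toward zero instead of poisoning the whole update. A minimal sketch of that
switch logic, assuming Theano is installed (the variable names below are
illustrative, not part of the patch):

    import numpy
    import theano
    from theano import tensor

    scaler = 0.1  # the patch's default

    param = tensor.vector('param')
    previous_step = tensor.vector('previous_step')

    # Non-zero wherever the step is NaN or Inf; '+' acts as a logical OR here.
    not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
    # Keep finite entries unchanged; replace bad ones by scaler * param, which
    # (since GradientDescent subtracts the step) decays that entry toward zero.
    step = tensor.switch(not_finite, scaler * param, previous_step)

    f = theano.function([param, previous_step], step)
    print(f(numpy.array([1., 2., 3.], dtype=theano.config.floatX),
            numpy.array([0.5, numpy.nan, numpy.inf], dtype=theano.config.floatX)))
    # -> approximately [ 0.5  0.2  0.3 ]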
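In train_model the rule is chained in front of the model's configured rule
with CompositeRule, which applies its sub-rules in the order given, so steps
are sanitized before the usual rule ever sees them. A hedged usage sketch,
with blocks' Scale standing in for config.step_rule (which is model-specific
and not shown in this diff):

    from blocks.algorithms import CompositeRule, Scale

    # ElementwiseRemoveNotFinite is the class added by this patch, above.
    step_rule = CompositeRule([
        ElementwiseRemoveNotFinite(),  # first, strip NaN/Inf from the raw steps
        Scale(learning_rate=0.01),     # then the usual learning-rate scaling
    ])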