aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAdeB <adbrebs@gmail.com>2015-06-24 15:12:15 -0400
committerAdeB <adbrebs@gmail.com>2015-06-24 15:12:15 -0400
commit88cdc3f8047a05bc5971eaa915ca6626f89a3e78 (patch)
treeaf9bc201cf442588492316b2360bd0bd16c8b843 /train.py
parentbd08e452093bba68fe2d79b1e9da76488b203720 (diff)
downloadtaxi-88cdc3f8047a05bc5971eaa915ca6626f89a3e78.tar.gz
taxi-88cdc3f8047a05bc5971eaa915ca6626f89a3e78.zip
New configs. training step rule out of train.py
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/train.py b/train.py
index 1b75138..83317c9 100755
--- a/train.py
+++ b/train.py
@@ -11,7 +11,7 @@ from functools import reduce
from theano import tensor
from blocks import roles
-from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule
+from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum
from blocks.extensions import Printing, FinishAfter, SimpleExtension
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
@@ -136,8 +136,7 @@ if __name__ == "__main__":
cost=cost,
step_rule=CompositeRule([
ElementwiseRemoveNotFinite(),
- AdaDelta(),
- #Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
+ config.step_rule,
]),
params=params)