diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-13 16:29:25 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-13 16:30:12 -0400 |
commit | 9ff3d163609707c0138c0de731eec40449bd1815 (patch) | |
tree | 4fcce51242237448e3a798053cdcbf2faf2ea568 /model/joint_simple_mlp_tgtcls.py | |
parent | 8470f64c9373308d7f85de0f7de3bdcbaf46ca0a (diff) | |
download | taxi-9ff3d163609707c0138c0de731eec40449bd1815.tar.gz taxi-9ff3d163609707c0138c0de731eec40449bd1815.zip |
Add support for dropout in joint model
Diffstat (limited to 'model/joint_simple_mlp_tgtcls.py')
-rw-r--r-- | model/joint_simple_mlp_tgtcls.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/model/joint_simple_mlp_tgtcls.py b/model/joint_simple_mlp_tgtcls.py index dd1242e..834afbf 100644 --- a/model/joint_simple_mlp_tgtcls.py +++ b/model/joint_simple_mlp_tgtcls.py @@ -1,6 +1,9 @@ from blocks.bricks import MLP, Rectifier, Linear, Sigmoid, Identity, Softmax from blocks.bricks.lookup import LookupTable +from blocks.filter import VariableFilter +from blocks.graph import ComputationGraph, apply_dropout + import numpy import theano from theano import tensor @@ -74,6 +77,18 @@ class Model(object): time_scost.name = 'time_scost' cost = dest_cost + time_scost + + if hasattr(config, 'dropout_p'): + cg = ComputationGraph(cost) + dropout_inputs = VariableFilter( + bricks=[b for b in list(common_mlp.children) + + list(dest_mlp.children) + + list(time_mlp.children) + if isinstance(b, Rectifier)], + name='output')(cg) + cg = apply_dropout(cg, dropout_inputs, config.dropout_p) + cost = cg.outputs[0] + cost.name = 'cost' # Initialization |