aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-13 16:29:25 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-13 16:30:12 -0400
commit9ff3d163609707c0138c0de731eec40449bd1815 (patch)
tree4fcce51242237448e3a798053cdcbf2faf2ea568
parent8470f64c9373308d7f85de0f7de3bdcbaf46ca0a (diff)
downloadtaxi-9ff3d163609707c0138c0de731eec40449bd1815.tar.gz
taxi-9ff3d163609707c0138c0de731eec40449bd1815.zip
Add support for dropout in joint model
-rw-r--r--model/joint_simple_mlp_tgtcls.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/model/joint_simple_mlp_tgtcls.py b/model/joint_simple_mlp_tgtcls.py
index dd1242e..834afbf 100644
--- a/model/joint_simple_mlp_tgtcls.py
+++ b/model/joint_simple_mlp_tgtcls.py
@@ -1,6 +1,9 @@
from blocks.bricks import MLP, Rectifier, Linear, Sigmoid, Identity, Softmax
from blocks.bricks.lookup import LookupTable
+from blocks.filter import VariableFilter
+from blocks.graph import ComputationGraph, apply_dropout
+
import numpy
import theano
from theano import tensor
@@ -74,6 +77,18 @@ class Model(object):
time_scost.name = 'time_scost'
cost = dest_cost + time_scost
+
+ if hasattr(config, 'dropout_p'):
+ cg = ComputationGraph(cost)
+ dropout_inputs = VariableFilter(
+ bricks=[b for b in list(common_mlp.children) +
+ list(dest_mlp.children) +
+ list(time_mlp.children)
+ if isinstance(b, Rectifier)],
+ name='output')(cg)
+ cg = apply_dropout(cg, dropout_inputs, config.dropout_p)
+ cost = cg.outputs[0]
+
cost.name = 'cost'
# Initialization