aboutsummaryrefslogtreecommitdiff
path: root/model/bidirectional_tgtcls.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-24 11:30:27 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-24 11:30:40 -0400
commitff49937eef024916ac4560ce0134d94006e9e2e5 (patch)
tree5b8faba873efc03151ee457c459510c8f64b65f3 /model/bidirectional_tgtcls.py
parentac49e3cb892e0278ea1d52afdc314322000fae27 (diff)
downloadtaxi-ff49937eef024916ac4560ce0134d94006e9e2e5.tar.gz
taxi-ff49937eef024916ac4560ce0134d94006e9e2e5.zip
RNN & Bidir RNN refactoring (& fixes, maybe)
Diffstat (limited to 'model/bidirectional_tgtcls.py')
-rw-r--r--model/bidirectional_tgtcls.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/model/bidirectional_tgtcls.py b/model/bidirectional_tgtcls.py
index 36120f7..4dfbad5 100644
--- a/model/bidirectional_tgtcls.py
+++ b/model/bidirectional_tgtcls.py
@@ -11,9 +11,12 @@ class Model(BidiRNN):
@lazy()
def __init__(self, config, **kwargs):
super(Model, self).__init__(config, output_dim=config.tgtcls.shape[0], **kwargs)
- self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX), name='classes')
+
+ self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX),
+ name='classes')
self.softmax = Softmax()
self.children.append(self.softmax)
def process_outputs(self, outputs):
return tensor.dot(self.softmax.apply(outputs), self.classes)
+