aboutsummaryrefslogtreecommitdiff
path: root/model/rnn_tgtcls.py
blob: 0c0faf2e99e0a56790564261dfc91721ec9b8874 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import numpy
import theano
from theano import tensor
from blocks.bricks.base import lazy
from blocks.bricks import Softmax

from model.rnn import RNN, Stream


class Model(RNN):
    """RNN whose prediction is a softmax-weighted average of fixed points.

    The underlying recurrent network (``model.rnn.RNN``) is built with one
    output unit per row of ``config.tgtcls``.  ``process_rto`` turns those
    raw outputs into softmax weights and uses them to average the rows of
    ``config.tgtcls``.  The result is therefore always a convex combination
    of the configured target points.

    NOTE(review): ``config.tgtcls`` is presumably a 2-D array of candidate
    target points (e.g. cluster centres), one per row — confirm against the
    config that builds it.
    """

    @lazy()
    def __init__(self, config, **kwargs):
        # The RNN emits one score per target point (row of config.tgtcls).
        n_targets = config.tgtcls.shape[0]
        super(Model, self).__init__(config, output_dim=n_targets, **kwargs)

        # Keep the target points in the symbolic graph as a floatX shared
        # variable so they can be combined with the softmax weights below.
        target_points = numpy.array(config.tgtcls, dtype=theano.config.floatX)
        self.classes = theano.shared(target_points, name='classes')

        # Attach the softmax as a child brick of this model.
        softmax = Softmax()
        self.softmax = softmax
        self.children.append(softmax)

    def process_rto(self, rto):
        """Convert raw RNN outputs into a weighted average of target points.

        :param rto: symbolic RNN output; assumed to be a matrix of shape
            (batch, n_targets) — TODO confirm against ``RNN``.
        :returns: ``softmax(rto)`` dotted with ``self.classes``.  Under the
            assumed input shape this is (batch, tgtcls.shape[1]).
        """
        weights = self.softmax.apply(rto)
        return tensor.dot(weights, self.classes)