diff options
author | Thomas Mesnard <thomas.mesnard@ens.fr> | 2015-12-28 20:35:38 +0100 |
---|---|---|
committer | Thomas Mesnard <thomas.mesnard@ens.fr> | 2015-12-28 20:35:38 +0100 |
commit | e8e37dee0c5c846b1aa2dd24dc99095191f72a9b (patch) | |
tree | d033f04eaca8178ada7ee966c4d8e56df45a6ace /ctc.py | |
parent | c9ba2abc7172b4657216e0fcc638098060d7f753 (diff) | |
download | pgm-ctc-e8e37dee0c5c846b1aa2dd24dc99095191f72a9b.tar.gz pgm-ctc-e8e37dee0c5c846b1aa2dd24dc99095191f72a9b.zip |
Kind of works
Diffstat (limited to 'ctc.py')
-rw-r--r-- | ctc.py | 23 |
1 files changed, 12 insertions, 11 deletions
@@ -28,14 +28,14 @@ class CTC(Brick): - Return the probability found at the end of that sequence """ T = probs.shape[0] + B = probs.shape[1] C = probs.shape[2]-1 L = l.shape[0] S = 2*L+1 - B = l.shape[1] # l_blk = l with interleaved blanks l_blk = C * tensor.ones((S, B), dtype='int32') - l_blk = tensor.set_subtensor(l_blk[1::2,:],l) + l_blk = tensor.set_subtensor(l_blk[1::2,:], l) l_blk = l_blk.T # now l_blk is B x S # dimension of alpha (corresponds to alpha hat in the paper) : @@ -43,13 +43,10 @@ class CTC(Brick): # dimension of c : # T x B # first value of alpha (size B x S) - alpha0 = tensor.concatenate([ - probs[0, :, C][:,None], - probs[0][tensor.arange(B), l[0]][:,None], - tensor.zeros((B, S-2)) + alpha0 = tensor.concatenate([ tensor.ones((B, 1)), + tensor.zeros((B, S-1)) ], axis=1) - c0 = alpha0.sum(axis=1) - alpha0 = alpha0 / c0[:,None] + c0 = tensor.ones((B,)) # recursion l_blk_2 = tensor.concatenate([-tensor.ones((B,2)), l_blk[:,:-2]], axis=1) @@ -76,8 +73,11 @@ class CTC(Brick): sequences=[probs, probs_mask], outputs_info=[alpha0, c0]) + prob = tensor.log(c).sum(axis=0) + tensor.log(alpha[-1][tensor.arange(B), 2*l_len.astype('int32')-1] + + alpha[-1][tensor.arange(B), 2*l_len.astype('int32')]) + # return the log probability of the labellings - return tensor.log(c).sum(axis=0) + return -prob def best_path_decoding(self, probs, probs_mask=None): @@ -89,7 +89,8 @@ class CTC(Brick): maxprob = probs.argmax(axis=2) is_double = tensor.eq(maxprob[:-1], maxprob[1:]) maxprob = tensor.switch(tensor.concatenate([tensor.zeros((1,B)), is_double]), - maxprob, C*tensor.ones_like(maxprob)) + C*tensor.ones_like(maxprob), maxprob) + # maxprob = theano.printing.Print('maxprob')(maxprob.T).T # returns two values : # label : (T x) T x B @@ -105,7 +106,7 @@ class CTC(Brick): [label_length, label], _ = scan(fn=recursion, sequences=[maxprob, probs_mask], - outputs_info=[tensor.zeros((B,),dtype='int32'),tensor.zeros((T,B))]) + outputs_info=[tensor.zeros((B,),dtype='int32'),-tensor.ones((T,B))]) return label[-1], label_length[-1] |