diff options
-rw-r--r-- | ctc.py | 30 |
1 files changed, 23 insertions, 7 deletions
@@ -7,7 +7,7 @@ from blocks.bricks import Brick
 # L: OUTPUT_SEQUENCE_LENGTH
 # C: NUM_CLASSES
 class CTC(Brick):
-    def apply(l, probs, l_mask=None, probs_mask=None):
+    def apply(l, probs, l_len=None, probs_mask=None):
         """
         Numeration:
             Characters 0 to C-1 are true characters
@@ -15,7 +15,7 @@ class CTC(Brick):
         Inputs:
             l : L x B : the sequence labelling
             probs : T x B x C+1 : the probabilities output by the RNN
-            l_mask : L x B
+            l_len : B : the length of each labelling sequence
             probs_mask : T x B
         Output: the B probabilities of the labelling sequences
         Steps:
@@ -31,8 +31,9 @@ class CTC(Brick):
         B = l.shape[1]
 
         # l_blk = l with interleaved blanks
-        l_blk = tensor.zeros((S, B))
+        l_blk = C * tensor.ones((S, B))
         l_blk = tensor.set_subtensor(l_blk[1::2,:],l)
+        l_blk = l_blk.T # now l_blk is B x S
 
         # dimension of alpha (corresponds to alpha hat in the paper) :
         #   T x B x S
@@ -48,17 +49,32 @@ class CTC(Brick):
         alpha0 = alpha0 / c0[:,None]
 
         # recursion
+        l_blk_2 = tensor.concatenate([-tensor.ones((B,2)), l_blk[:,:-2]], axis=1)
+        l_case2 = tensor.ne(l_blk, numpy.float32(C)) * tensor.ne(l_blk, l_blk_2)
+        # l_case2 is B x S: nonzero where l_blk is not the blank (C) and differs from l_blk two positions earlier
+
         def recursion(p, p_mask, prev_alpha, prev_c):
-            # TODO
-            return prev_alpha[-1], prev_c[-1]
+            prev_alpha = prev_alpha[-1]
+            # p is B x C+1
+            # prev_alpha is B x S
+            prev_alpha_1 = tensor.concatenate([tensor.zeros((B,1)),prev_alpha[:,:-1]], axis=1)  # prev_alpha shifted by 1 along S, zero-filled
+            prev_alpha_2 = tensor.concatenate([tensor.zeros((B,2)),prev_alpha[:,:-2]], axis=1)  # prev_alpha shifted by 2 along S, zero-filled
+
+            alphabar = prev_alpha + prev_alpha_1
+            alphabar = tensor.switch(l_case2, alphabar + prev_alpha_2, alphabar)  # add the two-step term only where l_case2 is nonzero
+            next_alpha = alphabar * p[tensor.arange(B)[:,None].repeat(S,axis=1).flatten(), l_blk.flatten()].reshape((B,S))
+            next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha)  # where p_mask is zero, carry prev_alpha forward unchanged
+            next_c = next_alpha.sum(axis=1)
+
+            return next_alpha / next_c[:, None], next_c
 
         # apply the recursion with scan
         alpha, c = tensor.scan(fn=recursion,
                                sequences=[probs, probs_mask],
                                outputs_info=[alpha0, c0])
 
-        # return the probability of the labellings
-
+        # return the log probability of the labellings
+        return tensor.log(c).sum(axis=0)
 
 
     def best_path_decoding(y_hat, y_hat_mask=None):