diff options
Diffstat (limited to 'ctc.py')
-rw-r--r-- | ctc.py | 8 |
1 files changed, 4 insertions, 4 deletions
@@ -7,7 +7,7 @@ from blocks.bricks import Brick # L: OUTPUT_SEQUENCE_LENGTH # C: NUM_CLASSES class CTC(Brick): - def apply(l, probs, l_len=None, probs_mask=None): + def apply(self, l, probs, l_len=None, probs_mask=None): """ Numeration: Characters 0 to C-1 are true characters @@ -63,7 +63,7 @@ class CTC(Brick): alphabar = prev_alpha + prev_alpha1 alphabar = tensor.switch(l_case2, alphabar + prev_alpha2, alphabar) next_alpha = alpha_bar * p[tensor.arange(B)[:,None].repeat(S,axis=1).flatten(), l_blk.flatten()].reshape((B,S)) - next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha] + next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha) next_c = next_alpha.sum(axis=1) return next_alpha / next_c[:, None], next_c @@ -77,7 +77,7 @@ class CTC(Brick): return tensor.log(c).sum(axis=0) - def best_path_decoding(probs, probs_mask=None): + def best_path_decoding(self, probs, probs_mask=None): # probs is T x B x C+1 T = probs.shape[0] B = probs.shape[1] @@ -106,7 +106,7 @@ class CTC(Brick): return label[-1], label_length[-1] - def prefix_search(probs, probs_mask=None): + def prefix_search(self, probs, probs_mask=None): # Hard one... pass |