diff options
author | Thomas Mesnard <thomas.mesnard@ens.fr> | 2015-12-23 11:14:57 +0100 |
---|---|---|
committer | Thomas Mesnard <thomas.mesnard@ens.fr> | 2015-12-23 11:14:57 +0100 |
commit | 08b63add743087ac0e2bbb7f739605642b3edc7b (patch) | |
tree | b253731214ab1c12fd7c8579fbdaad6aa99feebf /ctc.py | |
parent | 694964422eef7b835c1bfa3643fdee3bc1cffdd7 (diff) | |
download | pgm-ctc-08b63add743087ac0e2bbb7f739605642b3edc7b.tar.gz pgm-ctc-08b63add743087ac0e2bbb7f739605642b3edc7b.zip |
Add main.py
Diffstat (limited to 'ctc.py')
-rw-r--r-- | ctc.py | 8 |
1 files changed, 4 insertions, 4 deletions
@@ -7,7 +7,7 @@ from blocks.bricks import Brick # L: OUTPUT_SEQUENCE_LENGTH # C: NUM_CLASSES class CTC(Brick): - def apply(l, probs, l_len=None, probs_mask=None): + def apply(self, l, probs, l_len=None, probs_mask=None): """ Numeration: Characters 0 to C-1 are true characters @@ -63,7 +63,7 @@ class CTC(Brick): alphabar = prev_alpha + prev_alpha1 alphabar = tensor.switch(l_case2, alphabar + prev_alpha2, alphabar) next_alpha = alpha_bar * p[tensor.arange(B)[:,None].repeat(S,axis=1).flatten(), l_blk.flatten()].reshape((B,S)) - next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha] + next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha) next_c = next_alpha.sum(axis=1) return next_alpha / next_c[:, None], next_c @@ -77,7 +77,7 @@ class CTC(Brick): return tensor.log(c).sum(axis=0) - def best_path_decoding(probs, probs_mask=None): + def best_path_decoding(self, probs, probs_mask=None): # probs is T x B x C+1 T = probs.shape[0] B = probs.shape[1] @@ -106,7 +106,7 @@ class CTC(Brick): return label[-1], label_length[-1] - def prefix_search(probs, probs_mask=None): + def prefix_search(self, probs, probs_mask=None): # Hard one... pass |