From 08b63add743087ac0e2bbb7f739605642b3edc7b Mon Sep 17 00:00:00 2001 From: Thomas Mesnard Date: Wed, 23 Dec 2015 11:14:57 +0100 Subject: Add main.py --- ctc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'ctc.py') diff --git a/ctc.py b/ctc.py index c546cb8..57b2d36 100644 --- a/ctc.py +++ b/ctc.py @@ -7,7 +7,7 @@ from blocks.bricks import Brick # L: OUTPUT_SEQUENCE_LENGTH # C: NUM_CLASSES class CTC(Brick): - def apply(l, probs, l_len=None, probs_mask=None): + def apply(self, l, probs, l_len=None, probs_mask=None): """ Numeration: Characters 0 to C-1 are true characters @@ -63,7 +63,7 @@ class CTC(Brick): alphabar = prev_alpha + prev_alpha1 alphabar = tensor.switch(l_case2, alphabar + prev_alpha2, alphabar) next_alpha = alpha_bar * p[tensor.arange(B)[:,None].repeat(S,axis=1).flatten(), l_blk.flatten()].reshape((B,S)) - next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha] + next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha) next_c = next_alpha.sum(axis=1) return next_alpha / next_c[:, None], next_c @@ -77,7 +77,7 @@ class CTC(Brick): return tensor.log(c).sum(axis=0) - def best_path_decoding(probs, probs_mask=None): + def best_path_decoding(self, probs, probs_mask=None): # probs is T x B x C+1 T = probs.shape[0] B = probs.shape[1] @@ -106,7 +106,7 @@ class CTC(Brick): return label[-1], label_length[-1] - def prefix_search(probs, probs_mask=None): + def prefix_search(self, probs, probs_mask=None): # Hard one... pass -- cgit v1.2.3