aboutsummaryrefslogtreecommitdiff
path: root/ctc.py
diff options
context:
space:
mode:
authorThomas Mesnard <thomas.mesnard@ens.fr>2015-12-23 11:14:57 +0100
committerThomas Mesnard <thomas.mesnard@ens.fr>2015-12-23 11:14:57 +0100
commit08b63add743087ac0e2bbb7f739605642b3edc7b (patch)
treeb253731214ab1c12fd7c8579fbdaad6aa99feebf /ctc.py
parent694964422eef7b835c1bfa3643fdee3bc1cffdd7 (diff)
downloadpgm-ctc-08b63add743087ac0e2bbb7f739605642b3edc7b.tar.gz
pgm-ctc-08b63add743087ac0e2bbb7f739605642b3edc7b.zip
Add main.py
Diffstat (limited to 'ctc.py')
-rw-r--r--ctc.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/ctc.py b/ctc.py
index c546cb8..57b2d36 100644
--- a/ctc.py
+++ b/ctc.py
@@ -7,7 +7,7 @@ from blocks.bricks import Brick
# L: OUTPUT_SEQUENCE_LENGTH
# C: NUM_CLASSES
class CTC(Brick):
- def apply(l, probs, l_len=None, probs_mask=None):
+ def apply(self, l, probs, l_len=None, probs_mask=None):
"""
Numeration:
Characters 0 to C-1 are true characters
@@ -63,7 +63,7 @@ class CTC(Brick):
alphabar = prev_alpha + prev_alpha1
alphabar = tensor.switch(l_case2, alphabar + prev_alpha2, alphabar)
next_alpha = alpha_bar * p[tensor.arange(B)[:,None].repeat(S,axis=1).flatten(), l_blk.flatten()].reshape((B,S))
- next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha]
+ next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha)
next_c = next_alpha.sum(axis=1)
return next_alpha / next_c[:, None], next_c
@@ -77,7 +77,7 @@ class CTC(Brick):
return tensor.log(c).sum(axis=0)
- def best_path_decoding(probs, probs_mask=None):
+ def best_path_decoding(self, probs, probs_mask=None):
# probs is T x B x C+1
T = probs.shape[0]
B = probs.shape[1]
@@ -106,7 +106,7 @@ class CTC(Brick):
return label[-1], label_length[-1]
- def prefix_search(probs, probs_mask=None):
+ def prefix_search(self, probs, probs_mask=None):
# Hard one...
pass