aboutsummaryrefslogtreecommitdiff
path: root/ctc.py
diff options
context:
space:
mode:
Diffstat (limited to 'ctc.py')
-rw-r--r--ctc.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/ctc.py b/ctc.py
index cf629c1..a0dab80 100644
--- a/ctc.py
+++ b/ctc.py
@@ -34,7 +34,7 @@ class CTC(Brick):
l_blk = tensor.zeros((S, B))
l_blk = tensor.set_subtensor(l_blk[1::2,:],l)
- # dimension of alpha :
+ # dimension of alpha (corresponds to alpha hat in the paper) :
# T x B x S
# dimension of c :
# T x B
@@ -44,7 +44,8 @@ class CTC(Brick):
probs[0][tensor.arange(B), l[0]],
tensor.zeros((B, S-2))
], axis=1)
- c0 = alpha0.sum(axis=2)
+ c0 = alpha0.sum(axis=1)
+ alpha0 = alpha0 / c0[:,None]
# recursion
def recursion(p, p_mask, prev_alpha, prev_c):