aboutsummaryrefslogtreecommitdiff
path: root/ctc.py
diff options
context:
space:
mode:
Diffstat (limited to 'ctc.py')
-rw-r--r--ctc.py23
1 files changed, 12 insertions, 11 deletions
diff --git a/ctc.py b/ctc.py
index 5f34b2c..734e354 100644
--- a/ctc.py
+++ b/ctc.py
@@ -28,14 +28,14 @@ class CTC(Brick):
- Return the probability found at the end of that sequence
"""
T = probs.shape[0]
+ B = probs.shape[1]
C = probs.shape[2]-1
L = l.shape[0]
S = 2*L+1
- B = l.shape[1]
# l_blk = l with interleaved blanks
l_blk = C * tensor.ones((S, B), dtype='int32')
- l_blk = tensor.set_subtensor(l_blk[1::2,:],l)
+ l_blk = tensor.set_subtensor(l_blk[1::2,:], l)
l_blk = l_blk.T # now l_blk is B x S
# dimension of alpha (corresponds to alpha hat in the paper) :
@@ -43,13 +43,10 @@ class CTC(Brick):
# dimension of c :
# T x B
# first value of alpha (size B x S)
- alpha0 = tensor.concatenate([
- probs[0, :, C][:,None],
- probs[0][tensor.arange(B), l[0]][:,None],
- tensor.zeros((B, S-2))
+ alpha0 = tensor.concatenate([ tensor.ones((B, 1)),
+ tensor.zeros((B, S-1))
], axis=1)
- c0 = alpha0.sum(axis=1)
- alpha0 = alpha0 / c0[:,None]
+ c0 = tensor.ones((B,))
# recursion
l_blk_2 = tensor.concatenate([-tensor.ones((B,2)), l_blk[:,:-2]], axis=1)
@@ -76,8 +73,11 @@ class CTC(Brick):
sequences=[probs, probs_mask],
outputs_info=[alpha0, c0])
+ prob = tensor.log(c).sum(axis=0) + tensor.log(alpha[-1][tensor.arange(B), 2*l_len.astype('int32')-1]
+ + alpha[-1][tensor.arange(B), 2*l_len.astype('int32')])
+
# return the log probability of the labellings
- return tensor.log(c).sum(axis=0)
+ return -prob
def best_path_decoding(self, probs, probs_mask=None):
@@ -89,7 +89,8 @@ class CTC(Brick):
maxprob = probs.argmax(axis=2)
is_double = tensor.eq(maxprob[:-1], maxprob[1:])
maxprob = tensor.switch(tensor.concatenate([tensor.zeros((1,B)), is_double]),
- maxprob, C*tensor.ones_like(maxprob))
+ C*tensor.ones_like(maxprob), maxprob)
+ # maxprob = theano.printing.Print('maxprob')(maxprob.T).T
# returns two values :
# label : (T x) T x B
@@ -105,7 +106,7 @@ class CTC(Brick):
[label_length, label], _ = scan(fn=recursion,
sequences=[maxprob, probs_mask],
- outputs_info=[tensor.zeros((B,),dtype='int32'),tensor.zeros((T,B))])
+ outputs_info=[tensor.zeros((B,),dtype='int32'),-tensor.ones((T,B))])
return label[-1], label_length[-1]