aboutsummaryrefslogtreecommitdiff
path: root/ctc.py
diff options
context:
space:
mode:
Diffstat (limited to 'ctc.py')
-rw-r--r--ctc.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/ctc.py b/ctc.py
index f03313b..5f34b2c 100644
--- a/ctc.py
+++ b/ctc.py
@@ -1,5 +1,6 @@
import numpy
+import theano
from theano import tensor, scan
from blocks.bricks import Brick
@@ -65,6 +66,7 @@ class CTC(Brick):
alpha_bar = tensor.switch(l_case2, alpha_bar + prev_alpha_2, alpha_bar)
next_alpha = alpha_bar * p[tensor.arange(B)[:,None].repeat(S,axis=1).flatten(), l_blk.flatten()].reshape((B,S))
next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha)
+ next_alpha = next_alpha * tensor.lt(tensor.arange(S)[None,:], (2*l_len+1)[:, None])
next_c = next_alpha.sum(axis=1)
return next_alpha / next_c[:, None], next_c
@@ -85,15 +87,15 @@ class CTC(Brick):
C = probs.shape[2]-1
maxprob = probs.argmax(axis=2)
+ is_double = tensor.eq(maxprob[:-1], maxprob[1:])
+ maxprob = tensor.switch(tensor.concatenate([tensor.zeros((1,B)), is_double]),
+ maxprob, C*tensor.ones_like(maxprob))
# returns two values :
# label : (T x) T x B
# label_length : (T x) B
def recursion(maxp, p_mask, label_length, label):
- label_length = label_length[-1]
- label = label[-1]
-
- nonzero = p_mask * tensor.ne(maxp, C)
+ nonzero = p_mask * tensor.neq(maxp, C)
nonzero_id = nonzero.nonzero()[0]
new_label = tensor.set_subtensor(label[label_length[nonzero_id], nonzero_id], maxp[nonzero_id])
@@ -101,9 +103,9 @@ class CTC(Brick):
return new_label_length, new_label
- label_length, label = tensor.scan(fn=recursion,
- sequences=[maxprob, probs_mask],
- outputs_info=[tensor.zeros((B),dtype='int32'),tensor.zeros((T,B))])
+ [label_length, label], _ = scan(fn=recursion,
+ sequences=[maxprob, probs_mask],
+ outputs_info=[tensor.zeros((B,),dtype='int32'),tensor.zeros((T,B))])
return label[-1], label_length[-1]