From c9ba2abc7172b4657216e0fcc638098060d7f753 Mon Sep 17 00:00:00 2001 From: Thomas Mesnard Date: Wed, 23 Dec 2015 20:27:49 +0100 Subject: At least it compiles --- ctc.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) (limited to 'ctc.py') diff --git a/ctc.py b/ctc.py index f03313b..5f34b2c 100644 --- a/ctc.py +++ b/ctc.py @@ -1,5 +1,6 @@ import numpy +import theano from theano import tensor, scan from blocks.bricks import Brick @@ -65,6 +66,7 @@ class CTC(Brick): alpha_bar = tensor.switch(l_case2, alpha_bar + prev_alpha_2, alpha_bar) next_alpha = alpha_bar * p[tensor.arange(B)[:,None].repeat(S,axis=1).flatten(), l_blk.flatten()].reshape((B,S)) next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha) + next_alpha = next_alpha * tensor.lt(tensor.arange(S)[None,:], (2*l_len+1)[:, None]) next_c = next_alpha.sum(axis=1) return next_alpha / next_c[:, None], next_c @@ -85,15 +87,15 @@ class CTC(Brick): C = probs.shape[2]-1 maxprob = probs.argmax(axis=2) + is_double = tensor.eq(maxprob[:-1], maxprob[1:]) + maxprob = tensor.switch(tensor.concatenate([tensor.zeros((1,B)), is_double]), + maxprob, C*tensor.ones_like(maxprob)) # returns two values : # label : (T x) T x B # label_length : (T x) B def recursion(maxp, p_mask, label_length, label): - label_length = label_length[-1] - label = label[-1] - - nonzero = p_mask * tensor.ne(maxp, C) + nonzero = p_mask * tensor.neq(maxp, C) nonzero_id = nonzero.nonzero()[0] new_label = tensor.set_subtensor(label[label_length[nonzero_id], nonzero_id], maxp[nonzero_id]) @@ -101,9 +103,9 @@ class CTC(Brick): return new_label_length, new_label - label_length, label = tensor.scan(fn=recursion, - sequences=[maxprob, probs_mask], - outputs_info=[tensor.zeros((B),dtype='int32'),tensor.zeros((T,B))]) + [label_length, label], _ = scan(fn=recursion, + sequences=[maxprob, probs_mask], + outputs_info=[tensor.zeros((B,),dtype='int32'),tensor.zeros((T,B))]) return label[-1], label_length[-1] -- cgit v1.2.3