From 694964422eef7b835c1bfa3643fdee3bc1cffdd7 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 21 Dec 2015 11:56:51 +0100 Subject: Best path decoding fist implementation (not verified) --- ctc.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) (limited to 'ctc.py') diff --git a/ctc.py b/ctc.py index 0ac620b..c546cb8 100644 --- a/ctc.py +++ b/ctc.py @@ -77,11 +77,36 @@ class CTC(Brick): return tensor.log(c).sum(axis=0) - def best_path_decoding(y_hat, y_hat_mask=None): - # Easy one ! - pass + def best_path_decoding(probs, probs_mask=None): + # probs is T x B x C+1 + T = probs.shape[0] + B = probs.shape[1] + C = probs.shape[2]-1 + + maxprob = probs.argmax(axis=2) + + # returns two values : + # label : (T x) T x B + # label_length : (T x) B + def recursion(maxp, p_mask, label_length, label): + label_length = label_length[-1] + label = label[-1] + + nonzero = p_mask * tensor.ne(maxp, C) + nonzero_id = nonzero.nonzero()[0] + + new_label = tensor.set_subtensor(label[label_length[nonzero_id], nonzero_id], maxp[nonzero_id]) + new_label_length = tensor.switch(nonzero, label_length + numpy.int32(1), label_length) + + return new_label_length, new_label + + label_length, label = tensor.scan(fn=recursion, + sequences=[maxprob, probs_mask], + outputs_info=[tensor.zeros((B),dtype='int32'),tensor.zeros((T,B))]) + + return label[-1], label_length[-1] - def prefix_search(y_hat, y_hat_mask=None): + def prefix_search(probs, probs_mask=None): # Hard one... pass -- cgit v1.2.3