aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ctc.py33
1 files changed, 29 insertions, 4 deletions
diff --git a/ctc.py b/ctc.py
index 0ac620b..c546cb8 100644
--- a/ctc.py
+++ b/ctc.py
@@ -77,11 +77,36 @@ class CTC(Brick):
return tensor.log(c).sum(axis=0)
- def best_path_decoding(y_hat, y_hat_mask=None):
- # Easy one !
- pass
+ def best_path_decoding(probs, probs_mask=None):
+ # probs is T x B x C+1
+ T = probs.shape[0]
+ B = probs.shape[1]
+ C = probs.shape[2]-1
+
+ maxprob = probs.argmax(axis=2)
+
+ # returns two values :
+ # label : (T x) T x B
+ # label_length : (T x) B
+ def recursion(maxp, p_mask, label_length, label):
+ label_length = label_length[-1]
+ label = label[-1]
+
+ nonzero = p_mask * tensor.ne(maxp, C)
+ nonzero_id = nonzero.nonzero()[0]
+
+ new_label = tensor.set_subtensor(label[label_length[nonzero_id], nonzero_id], maxp[nonzero_id])
+ new_label_length = tensor.switch(nonzero, label_length + numpy.int32(1), label_length)
+
+ return new_label_length, new_label
+
+ label_length, label = tensor.scan(fn=recursion,
+ sequences=[maxprob, probs_mask],
+ outputs_info=[tensor.zeros((B),dtype='int32'),tensor.zeros((T,B))])
+
+ return label[-1], label_length[-1]
- def prefix_search(y_hat, y_hat_mask=None):
+ def prefix_search(probs, probs_mask=None):
# Hard one...
pass