diff options
author | Thomas Mesnard <thomas.mesnard@ens.fr> | 2015-12-28 20:35:38 +0100 |
---|---|---|
committer | Thomas Mesnard <thomas.mesnard@ens.fr> | 2015-12-28 20:35:38 +0100 |
commit | e8e37dee0c5c846b1aa2dd24dc99095191f72a9b (patch) | |
tree | d033f04eaca8178ada7ee966c4d8e56df45a6ace /edit_distance.py | |
parent | c9ba2abc7172b4657216e0fcc638098060d7f753 (diff) | |
download | pgm-ctc-e8e37dee0c5c846b1aa2dd24dc99095191f72a9b.tar.gz pgm-ctc-e8e37dee0c5c846b1aa2dd24dc99095191f72a9b.zip |
Kind of works
Diffstat (limited to 'edit_distance.py')
-rw-r--r-- | edit_distance.py | 24 |
1 files changed, 24 insertions, 0 deletions
import numpy
import theano
from theano import tensor


@theano.compile.ops.as_op(itypes=[tensor.imatrix, tensor.ivector,
                                  tensor.imatrix, tensor.ivector],
                          otypes=[tensor.ivector])
def batch_edit_distance(a, a_len, b, b_len):
    """Return the edit (Levenshtein) distance of each row pair of a and b.

    Row ``n`` of the result is the minimum number of single-symbol
    insertions, deletions and substitutions turning ``a[n, :a_len[n]]``
    into ``b[n, :b_len[n]]``.

    Parameters
    ----------
    a, b : int32 matrices, one (padded) sequence per row.  Both must have
        the same number of rows; their widths may differ.
    a_len, b_len : int32 vectors giving the effective length of each row
        of ``a`` and ``b``.  They are used directly as indices into the
        distance table, so a length larger than the matrix width raises
        ``IndexError`` and a negative length wraps around as usual for
        NumPy indexing.

    Returns
    -------
    int32 vector with one distance per row.

    Raises
    ------
    ValueError
        If ``a`` and ``b`` do not have the same number of rows.
    """
    B = a.shape[0]
    # Explicit check instead of ``assert``: assertions vanish under
    # ``python -O``, and a batch of size 1 would then be silently broadcast
    # against the other one, yielding wrong distances.
    if b.shape[0] != B:
        raise ValueError("batch size mismatch: a has %d rows, b has %d rows"
                         % (B, b.shape[0]))

    rows = a.shape[1] + 1
    cols = b.shape[1] + 1

    # q[n, i, j] = distance between a[n, :i] and b[n, :j].
    q = numpy.empty((B, rows, cols), dtype='int32')
    # Against an empty prefix the distance is simply the other prefix length.
    q[:, :, 0] = numpy.arange(rows, dtype='int32')
    q[:, 0, :] = numpy.arange(cols, dtype='int32')

    # Fill the table cell by cell; each step is vectorised over the batch.
    for i in range(1, rows):
        for j in range(1, cols):
            deletion = q[:, i - 1, j] + 1
            insertion = q[:, i, j - 1] + 1
            # Substitution is free when the two symbols already match.
            substitution = (q[:, i - 1, j - 1]
                            + numpy.not_equal(a[:, i - 1], b[:, j - 1]))
            q[:, i, j] = numpy.minimum(numpy.minimum(deletion, insertion),
                                       substitution)

    # Read off, for every batch entry, the cell at its own true lengths.
    return q[numpy.arange(B), a_len, b_len]

# vim: set sts=4 ts=4 sw=4 tw=0 et :