about summary refs log tree commit diff
path: root/edit_distance.py
diff options
context:
space:
mode:
Diffstat (limited to 'edit_distance.py')
-rw-r--r-- edit_distance.py 24
1 files changed, 24 insertions, 0 deletions
diff --git a/edit_distance.py b/edit_distance.py
new file mode 100644
index 0000000..d76cc00
--- /dev/null
+++ b/edit_distance.py
@@ -0,0 +1,24 @@
+import numpy
+import theano
+from theano import tensor
+
+@theano.compile.ops.as_op(
+    itypes=[tensor.imatrix, tensor.ivector, tensor.imatrix, tensor.ivector],
+    otypes=[tensor.ivector])
+def batch_edit_distance(a, a_len, b, b_len):
+    """Compute the edit distance between matching rows of two batches.
+
+    Row ``n`` of ``a`` is compared with row ``n`` of ``b``. Insertions and
+    deletions cost 1; a substitution costs 1 when the two symbols differ and
+    0 when they are equal.
+
+    Args:
+        a: int32 matrix, one sequence per row.
+        a_len: int32 vector; ``a_len[n]`` is the first-sequence index used
+            to read the result for row ``n`` out of the distance table.
+        b: int32 matrix with the same number of rows as ``a``.
+        b_len: int32 vector, the counterpart of ``a_len`` for ``b``.
+
+    Returns:
+        int32 vector whose entry ``n`` is the table value at
+        ``(a_len[n], b_len[n])``, i.e. the distance between the first
+        ``a_len[n]`` symbols of ``a[n]`` and the first ``b_len[n]`` symbols
+        of ``b[n]``.
+    """
+    n_rows = a.shape[0]
+    assert b.shape[0] == n_rows
+    a_width = a.shape[1]
+    b_width = b.shape[1]
+
+    # No pair of prefixes can be further apart than the longer row, so that
+    # width works as the "not yet computed" starting value of every cell.
+    table = numpy.full((n_rows, a_width + 1, b_width + 1),
+                       max(a_width, b_width), dtype='int32')
+    table[:, 0, 0] = 0
+
+    # table[n, i, j] ends up as the distance between a[n, :i] and b[n, :j].
+    # Each step relaxes one (i, j) cell for the whole batch at once.
+    for i in range(a_width + 1):
+        for j in range(b_width + 1):
+            cell = table[:, i, j]  # a view: writing into it updates `table`
+            if i > 0:
+                # Extend the first prefix by one symbol at cost 1.
+                numpy.minimum(cell, table[:, i - 1, j] + 1, out=cell)
+            if j > 0:
+                # Extend the second prefix by one symbol at cost 1.
+                numpy.minimum(cell, table[:, i, j - 1] + 1, out=cell)
+            if i > 0 and j > 0:
+                # Extend both prefixes; free when the new symbols match.
+                mismatch = numpy.not_equal(a[:, i - 1], b[:, j - 1])
+                numpy.minimum(cell, table[:, i - 1, j - 1] + mismatch,
+                              out=cell)
+
+    return table[numpy.arange(n_rows), a_len, b_len]
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :