aboutsummaryrefslogtreecommitdiff
path: root/main.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2015-12-23 18:35:56 +0100
committerAlex Auvolat <alex@adnab.me>2015-12-23 18:35:56 +0100
commita1394aad6fca2dd560eb45a9b2e4cbc7be4c2bf7 (patch)
tree99d7bdb78b052867589699b031ec50cdf932a142 /main.py
parent08b63add743087ac0e2bbb7f739605642b3edc7b (diff)
downloadpgm-ctc-a1394aad6fca2dd560eb45a9b2e4cbc7be4c2bf7.tar.gz
pgm-ctc-a1394aad6fca2dd560eb45a9b2e4cbc7be4c2bf7.zip
stuff
Diffstat (limited to 'main.py')
-rw-r--r--main.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/main.py b/main.py
index d384edb..b71d339 100644
--- a/main.py
+++ b/main.py
@@ -92,7 +92,7 @@ y_hat.name = 'y_hat'
y_hat_mask = x_mask
# Cost
-cost = CTC().apply(y, y_hat, y_mask, y_hat_mask)
+cost = CTC().apply(y, y_hat, y_mask.sum(axis=1), y_hat_mask).mean()
cost.name = 'CTC'
# Initialization