aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--model/rnn.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/model/rnn.py b/model/rnn.py
index 0ee9586..d3f6616 100644
--- a/model/rnn.py
+++ b/model/rnn.py
@@ -134,7 +134,7 @@ class RNN(Initializable):
@application(outputs=['cost'])
def valid_cost(self, **kwargs):
last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=1) - 1, dtype='int64')
- return self.cost_matrix(**kwargs)[last_id, tensor.arange(kwargs['latitude_mask'].shape[1])].mean()
+ return self.cost_matrix(**kwargs)[last_id, tensor.arange(kwargs['latitude_mask'].shape[0])].mean()
@valid_cost.property('inputs')
def valid_cost_inputs(self):