aboutsummaryrefslogtreecommitdiff
path: root/model/rnn.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-07-23 12:43:55 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-07-23 12:43:55 -0400
commitb61a411fbcb98b09ee83f8dd124113c6d7f47737 (patch)
tree662b83806fa2f3769a121d9b8c6b1f0b6297a3a2 /model/rnn.py
parent028402e4a2fafc39cc8fc0e036e79017b9f9c26a (diff)
downloadtaxi-b61a411fbcb98b09ee83f8dd124113c6d7f47737.tar.gz
taxi-b61a411fbcb98b09ee83f8dd124113c6d7f47737.zip
Fix rnn validation
Diffstat (limited to 'model/rnn.py')
-rw-r--r--model/rnn.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/model/rnn.py b/model/rnn.py
index 0ee9586..d3f6616 100644
--- a/model/rnn.py
+++ b/model/rnn.py
@@ -134,7 +134,7 @@ class RNN(Initializable):
@application(outputs=['cost'])
def valid_cost(self, **kwargs):
last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=1) - 1, dtype='int64')
- return self.cost_matrix(**kwargs)[last_id, tensor.arange(kwargs['latitude_mask'].shape[1])].mean()
+ return self.cost_matrix(**kwargs)[last_id, tensor.arange(kwargs['latitude_mask'].shape[0])].mean()
@valid_cost.property('inputs')
def valid_cost_inputs(self):