aboutsummaryrefslogtreecommitdiff
path: root/model/rnn.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-23 10:07:38 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-23 10:07:38 -0400
commitdd8ae5ea8ed0c7cb1a7880b1e1887c6e23cdf910 (patch)
tree2662e4cd27f2e2131e28462fd0e1579e84658be2 /model/rnn.py
parent9013799ebca1c426c3c3e9019eb71018b253b025 (diff)
downloadtaxi-dd8ae5ea8ed0c7cb1a7880b1e1887c6e23cdf910.tar.gz
taxi-dd8ae5ea8ed0c7cb1a7880b1e1887c6e23cdf910.zip
Fix RNN prediction function
Diffstat (limited to 'model/rnn.py')
-rw-r--r--model/rnn.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/model/rnn.py b/model/rnn.py
index bfc3122..0ee9586 100644
--- a/model/rnn.py
+++ b/model/rnn.py
@@ -99,7 +99,7 @@ class RNN(Initializable):
res = self.predict_all(**kwargs)[0]
last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=0) - 1, dtype='int64')
- return res[last_id]
+ return res[last_id, tensor.arange(kwargs['latitude_mask'].shape[1])]
@predict.property('inputs')
def predict_inputs(self):
@@ -134,7 +134,7 @@ class RNN(Initializable):
@application(outputs=['cost'])
def valid_cost(self, **kwargs):
last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=1) - 1, dtype='int64')
- return self.cost_matrix(**kwargs)[last_id].mean()
+ return self.cost_matrix(**kwargs)[last_id, tensor.arange(kwargs['latitude_mask'].shape[1])].mean()
@valid_cost.property('inputs')
def valid_cost_inputs(self):