aboutsummaryrefslogtreecommitdiff
path: root/model/bidirectional.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-07-23 15:13:15 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-07-23 15:13:15 -0400
commit8792dd327159b6eebbca167c325629b42e54766f (patch)
tree8dda287f62b5ba1a255d139bb5e8c036d8aab835 /model/bidirectional.py
parent6d5dd44c5b6c6e0f7eecc5997b484b40f35c0074 (diff)
downloadtaxi-8792dd327159b6eebbca167c325629b42e54766f.tar.gz
taxi-8792dd327159b6eebbca167c325629b42e54766f.zip
Fix bidirectional
Diffstat (limited to 'model/bidirectional.py')
-rw-r--r--model/bidirectional.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/model/bidirectional.py b/model/bidirectional.py
index 149d1fd..af3891d 100644
--- a/model/bidirectional.py
+++ b/model/bidirectional.py
@@ -59,9 +59,10 @@ class BidiRNN(Initializable):
longitude = tensor.shape_padright(longitude)
rec_in = tensor.concatenate((latitude, longitude), axis=2)
+ last_id = tensor.cast(latitude_mask.sum(axis=0) - 1, dtype='int64')
path = self.rec.apply(self.fork.apply(rec_in), mask=latitude_mask)[0]
path_representation = (path[0][:, -self.config.hidden_state_dim:],
- path[-1][:, :self.config.hidden_state_dim])
+ path[last_id - 1, tensor.arange(latitude_mask.shape[1])][:, :self.config.hidden_state_dim])
embeddings = tuple(self.context_embedder.apply(**{k: kwargs[k] for k in self.context_embedder.inputs }))