diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-07-23 12:43:55 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-07-23 12:43:55 -0400 |
commit | b61a411fbcb98b09ee83f8dd124113c6d7f47737 (patch) | |
tree | 662b83806fa2f3769a121d9b8c6b1f0b6297a3a2 | |
parent | 028402e4a2fafc39cc8fc0e036e79017b9f9c26a (diff) | |
download | taxi-b61a411fbcb98b09ee83f8dd124113c6d7f47737.tar.gz taxi-b61a411fbcb98b09ee83f8dd124113c6d7f47737.zip |
Fix rnn validation
-rw-r--r-- | model/rnn.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/model/rnn.py b/model/rnn.py index 0ee9586..d3f6616 100644 --- a/model/rnn.py +++ b/model/rnn.py @@ -134,7 +134,7 @@ class RNN(Initializable): @application(outputs=['cost']) def valid_cost(self, **kwargs): last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=1) - 1, dtype='int64') - return self.cost_matrix(**kwargs)[last_id, tensor.arange(kwargs['latitude_mask'].shape[1])].mean() + return self.cost_matrix(**kwargs)[last_id, tensor.arange(kwargs['latitude_mask'].shape[0])].mean() @valid_cost.property('inputs') def valid_cost_inputs(self): |