author     Alex Auvolat <alex.auvolat@ens.fr>   2015-07-24 10:27:15 -0400
committer  Alex Auvolat <alex.auvolat@ens.fr>   2015-07-24 10:27:15 -0400
commit     1e64a442e78b5e471b2f573295bd9a747b7c6c3f (patch)
tree       836d804e16bb19cb08eab7bce4c7c2b5a5e7489a /model/memory_network_bidir.py
parent     389bfd3637dfb523a3e4194c7281a0c538166546 (diff)
Memory net refactoring
Diffstat (limited to 'model/memory_network_bidir.py')
-rw-r--r--   model/memory_network_bidir.py   70
1 file changed, 10 insertions(+), 60 deletions(-)
diff --git a/model/memory_network_bidir.py b/model/memory_network_bidir.py
index 9dad091..cc99312 100644
--- a/model/memory_network_bidir.py
+++ b/model/memory_network_bidir.py
@@ -75,69 +75,19 @@ class RecurrentEncoder(Initializable):
 
 class Model(MemoryNetworkBase):
     def __init__(self, config, **kwargs):
-        super(Model, self).__init__(config, **kwargs)
 
         # Build prefix encoder : recurrent then MLP
-        self.prefix_encoder = RecurrentEncoder(self.config.prefix_encoder,
-                                               self.config.representation_size,
-                                               self.config.representation_activation(),
-                                               name='prefix_encoder')
+        prefix_encoder = RecurrentEncoder(self.config.prefix_encoder,
+                                          self.config.representation_size,
+                                          self.config.representation_activation(),
+                                          name='prefix_encoder')
 
         # Build candidate encoder
-        self.candidate_encoder = RecurrentEncoder(self.config.candidate_encoder,
-                                                  self.config.representation_size,
-                                                  self.config.representation_activation(),
-                                                  name='candidate_encoder')
+        candidate_encoder = RecurrentEncoder(self.config.candidate_encoder,
+                                             self.config.representation_size,
+                                             self.config.representation_activation(),
+                                             name='candidate_encoder')
 
-        # Rest of the stuff
-        self.softmax = Softmax()
+        # And... that's it!
+        super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs)
 
-        self.inputs = self.prefix_encoder.inputs \
-                      + ['candidate_'+k for k in self.candidate_encoder.inputs]
-
-        self.children = [ self.prefix_encoder,
-                          self.candidate_encoder,
-                          self.softmax ]
-
-
-    @application(outputs=['destination'])
-    def predict(self, **kwargs):
-        prefix_representation = self.prefix_encoder.apply(
-                **{ name: kwargs[name] for name in self.prefix_encoder.inputs })
-
-        candidate_representation = self.prefix_encoder.apply(
-                **{ name: kwargs['candidate_'+name] for name in self.candidate_encoder.inputs })
-
-        if self.config.normalize_representation:
-            candidate_representation = candidate_representation \
-                    / tensor.sqrt((candidate_representation ** 2).sum(axis=1, keepdims=True))
-
-        similarity_score = tensor.dot(prefix_representation, candidate_representation.T)
-        similarity = self.softmax.apply(similarity_score)
-
-        candidate_mask = kwargs['candidate_latitude_mask']
-        candidate_last = tensor.cast(candidate_mask.sum(axis=1) - 1, 'int64')
-        candidate_destination = tensor.concatenate(
-                (kwargs['candidate_latitude'][tensor.arange(candidate_mask.shape[0]), candidate_last]
-                     [:, None],
-                 kwargs['candidate_longitude'][tensor.arange(candidate_mask.shape[0]), candidate_last]
-                     [:, None]),
-                axis=1)
-
-        return tensor.dot(similarity, candidate_destination)
-
-    @predict.property('inputs')
-    def predict_inputs(self):
-        return self.inputs
-
-    @application(outputs=['cost'])
-    def cost(self, **kwargs):
-        y_hat = self.predict(**kwargs)
-        y = tensor.concatenate((kwargs['destination_latitude'][:, None],
-                                kwargs['destination_longitude'][:, None]), axis=1)
-
-        return error.erdist(y_hat, y).mean()
-
-    @cost.property('inputs')
-    def cost_inputs(self):
-        return self.inputs + ['destination_latitude', 'destination_longitude']
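For readers of the diff: the removed predict() is the heart of the memory network. It encodes the trip prefix and a set of candidate trips, softmaxes the dot-product similarities, and outputs a similarity-weighted average of the candidates' final GPS points; cost() then takes the mean equirectangular distance (error.erdist) to the true destination. Since the constructor now passes both encoders to MemoryNetworkBase, that shared logic presumably lives in the base class after this refactoring (not shown in this diff). The sketch below restates the computation in plain NumPy so it can be run standalone; the function name predict_destination, the softmax helper, and the array shapes are illustrative assumptions, not the project's actual API.

# Illustrative NumPy restatement of what the deleted predict() computes.
# The representations would really come from the two RecurrentEncoder bricks;
# here they are plain arrays, and all names/shapes are assumptions.
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def predict_destination(prefix_repr, candidate_repr,
                        candidate_latitude, candidate_longitude,
                        candidate_latitude_mask, normalize_representation=True):
    # prefix_repr: (batch, d) encodings of the partial trajectories.
    # candidate_repr: (n_candidates, d) encodings of the memory trajectories.
    # candidate_latitude/longitude: (n_candidates, T) padded GPS coordinates.
    # candidate_latitude_mask: (n_candidates, T), 1 for valid points, 0 for padding.
    if normalize_representation:
        # Mirrors config.normalize_representation in the removed code.
        candidate_repr = candidate_repr / np.sqrt(
            (candidate_repr ** 2).sum(axis=1, keepdims=True))

    # Dot-product similarity of every prefix against every candidate, then softmax.
    similarity = softmax(prefix_repr @ candidate_repr.T, axis=1)    # (batch, n_candidates)

    # Each candidate's destination is its last valid (lat, lon) point.
    last = candidate_latitude_mask.sum(axis=1).astype(np.int64) - 1
    rows = np.arange(candidate_latitude_mask.shape[0])
    candidate_destination = np.stack(
        (candidate_latitude[rows, last], candidate_longitude[rows, last]),
        axis=1)                                                      # (n_candidates, 2)

    # Predicted destination: similarity-weighted average of candidate destinations.
    return similarity @ candidate_destination                        # (batch, 2)

The removed cost() simply compared this output to the true (destination_latitude, destination_longitude) pair with error.erdist and averaged over the batch.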