diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 10:27:15 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 10:27:15 -0400 |
commit | 1e64a442e78b5e471b2f573295bd9a747b7c6c3f (patch) | |
tree | 836d804e16bb19cb08eab7bce4c7c2b5a5e7489a /model/memory_network_bidir.py | |
parent | 389bfd3637dfb523a3e4194c7281a0c538166546 (diff) | |
download | taxi-1e64a442e78b5e471b2f573295bd9a747b7c6c3f.tar.gz taxi-1e64a442e78b5e471b2f573295bd9a747b7c6c3f.zip |
Memory net refactoring
Diffstat (limited to 'model/memory_network_bidir.py')
-rw-r--r-- | model/memory_network_bidir.py | 70 |
1 files changed, 10 insertions, 60 deletions
diff --git a/model/memory_network_bidir.py b/model/memory_network_bidir.py index 9dad091..cc99312 100644 --- a/model/memory_network_bidir.py +++ b/model/memory_network_bidir.py @@ -75,69 +75,19 @@ class RecurrentEncoder(Initializable): class Model(MemoryNetworkBase): def __init__(self, config, **kwargs): - super(Model, self).__init__(config, **kwargs) # Build prefix encoder : recurrent then MLP - self.prefix_encoder = RecurrentEncoder(self.config.prefix_encoder, - self.config.representation_size, - self.config.representation_activation(), - name='prefix_encoder') + prefix_encoder = RecurrentEncoder(self.config.prefix_encoder, + self.config.representation_size, + self.config.representation_activation(), + name='prefix_encoder') # Build candidate encoder - self.candidate_encoder = RecurrentEncoder(self.config.candidate_encoder, - self.config.representation_size, - self.config.representation_activation(), - name='candidate_encoder') + candidate_encoder = RecurrentEncoder(self.config.candidate_encoder, + self.config.representation_size, + self.config.representation_activation(), + name='candidate_encoder') - # Rest of the stuff - self.softmax = Softmax() + # And... that's it! + super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs) - self.inputs = self.prefix_encoder.inputs \ - + ['candidate_'+k for k in self.candidate_encoder.inputs] - - self.children = [ self.prefix_encoder, - self.candidate_encoder, - self.softmax ] - - - @application(outputs=['destination']) - def predict(self, **kwargs): - prefix_representation = self.prefix_encoder.apply( - **{ name: kwargs[name] for name in self.prefix_encoder.inputs }) - - candidate_representation = self.prefix_encoder.apply( - **{ name: kwargs['candidate_'+name] for name in self.candidate_encoder.inputs }) - - if self.config.normalize_representation: - candidate_representation = candidate_representation \ - / tensor.sqrt((candidate_representation ** 2).sum(axis=1, keepdims=True)) - - similarity_score = tensor.dot(prefix_representation, candidate_representation.T) - similarity = self.softmax.apply(similarity_score) - - candidate_mask = kwargs['candidate_latitude_mask'] - candidate_last = tensor.cast(candidate_mask.sum(axis=1) - 1, 'int64') - candidate_destination = tensor.concatenate( - (kwargs['candidate_latitude'][tensor.arange(candidate_mask.shape[0]), candidate_last] - [:, None], - kwargs['candidate_longitude'][tensor.arange(candidate_mask.shape[0]), candidate_last] - [:, None]), - axis=1) - - return tensor.dot(similarity, candidate_destination) - - @predict.property('inputs') - def predict_inputs(self): - return self.inputs - - @application(outputs=['cost']) - def cost(self, **kwargs): - y_hat = self.predict(**kwargs) - y = tensor.concatenate((kwargs['destination_latitude'][:, None], - kwargs['destination_longitude'][:, None]), axis=1) - - return error.erdist(y_hat, y).mean() - - @cost.property('inputs') - def cost_inputs(self): - return self.inputs + ['destination_latitude', 'destination_longitude'] |