aboutsummaryrefslogtreecommitdiff
path: root/model/memory_network.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-24 10:27:15 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-24 10:27:15 -0400
commit1e64a442e78b5e471b2f573295bd9a747b7c6c3f (patch)
tree836d804e16bb19cb08eab7bce4c7c2b5a5e7489a /model/memory_network.py
parent389bfd3637dfb523a3e4194c7281a0c538166546 (diff)
downloadtaxi-1e64a442e78b5e471b2f573295bd9a747b7c6c3f.tar.gz
taxi-1e64a442e78b5e471b2f573295bd9a747b7c6c3f.zip
Memory net refactoring
Diffstat (limited to 'model/memory_network.py')
-rw-r--r--model/memory_network.py49
1 files changed, 48 insertions, 1 deletions
diff --git a/model/memory_network.py b/model/memory_network.py
index e7ba51c..84a8edf 100644
--- a/model/memory_network.py
+++ b/model/memory_network.py
@@ -14,11 +14,58 @@ import error
from model import ContextEmbedder
class MemoryNetworkBase(Initializable):
- def __init__(self, config, **kwargs):
+ def __init__(self, config, prefix_encoder, candidate_encoder, **kwargs):
super(MemoryNetworkBase, self).__init__(**kwargs)
+ self.prefix_encoder = prefix_encoder
+ self.candidate_encoder = candidate_encoder
self.config = config
+ self.softmax = Softmax()
+ self.children = [ self.softmax, prefix_encoder, candidate_encoder ]
+
+ self.inputs = self.prefix_encoder.apply.inputs \
+ + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs]
+
+ def candidate_destination(**kwargs):
+ return tensor.concatenate(
+ (tensor.shape_padright(kwargs['candidate_last_k_latitude'][:,-1]),
+ tensor.shape_padright(kwargs['candidate_last_k_longitude'][:,-1])),
+ axis=1)
+
+ @application(outputs=['cost'])
+ def cost(self, **kwargs):
+ y_hat = self.predict(**kwargs)
+ y = tensor.concatenate((kwargs['destination_latitude'][:, None],
+ kwargs['destination_longitude'][:, None]), axis=1)
+
+ return error.erdist(y_hat, y).mean()
+
+ @application(outputs=['destination'])
+ def predict(self, **kwargs):
+ prefix_representation = self.prefix_encoder.apply(
+ { x: kwargs[x] for x in self.prefix_encoder.apply.inputs })
+ candidate_representatin = self.candidate_encoder.apply(
+ { x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs })
+
+ if self.config.normalize_representation:
+ prefix_representation = prefix_representation \
+ / tensor.sqrt((prefix_representation ** 2).sum(axis=1, keepdims=True))
+ candidate_representation = candidate_representation \
+ / tensor.sqrt((candidate_representation ** 2).sum(axis=1, keepdims=True))
+
+ similarity_score = tensor.dot(prefix_representation, candidate_representation.T)
+ similarity = self.softmax.apply(similarity_score)
+
+ return tensor.dot(similarity, self.candidate_destination(**kwargs))
+
+ @predict.property('inputs')
+ def predict_inputs(self):
+ return self.inputs
+
+ @cost.property('inputs')
+ def cost_inputs(self):
+ return self.inputs + ['destination_latitude', 'destination_longitude']
class StreamBase(object):
def __init__(self, config):