From 1e64a442e78b5e471b2f573295bd9a747b7c6c3f Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 24 Jul 2015 10:27:15 -0400 Subject: Memory net refactoring --- model/memory_network.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) (limited to 'model/memory_network.py') diff --git a/model/memory_network.py b/model/memory_network.py index e7ba51c..84a8edf 100644 --- a/model/memory_network.py +++ b/model/memory_network.py @@ -14,11 +14,58 @@ import error from model import ContextEmbedder class MemoryNetworkBase(Initializable): - def __init__(self, config, **kwargs): + def __init__(self, config, prefix_encoder, candidate_encoder, **kwargs): super(MemoryNetworkBase, self).__init__(**kwargs) + self.prefix_encoder = prefix_encoder + self.candidate_encoder = candidate_encoder self.config = config + self.softmax = Softmax() + self.children = [ self.softmax, prefix_encoder, candidate_encoder ] + + self.inputs = self.prefix_encoder.apply.inputs \ + + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs] + + def candidate_destination(**kwargs): + return tensor.concatenate( + (tensor.shape_padright(kwargs['candidate_last_k_latitude'][:,-1]), + tensor.shape_padright(kwargs['candidate_last_k_longitude'][:,-1])), + axis=1) + + @application(outputs=['cost']) + def cost(self, **kwargs): + y_hat = self.predict(**kwargs) + y = tensor.concatenate((kwargs['destination_latitude'][:, None], + kwargs['destination_longitude'][:, None]), axis=1) + + return error.erdist(y_hat, y).mean() + + @application(outputs=['destination']) + def predict(self, **kwargs): + prefix_representation = self.prefix_encoder.apply( + { x: kwargs[x] for x in self.prefix_encoder.apply.inputs }) + candidate_representatin = self.candidate_encoder.apply( + { x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs }) + + if self.config.normalize_representation: + prefix_representation = prefix_representation \ + / tensor.sqrt((prefix_representation ** 2).sum(axis=1, keepdims=True)) + candidate_representation = candidate_representation \ + / tensor.sqrt((candidate_representation ** 2).sum(axis=1, keepdims=True)) + + similarity_score = tensor.dot(prefix_representation, candidate_representation.T) + similarity = self.softmax.apply(similarity_score) + + return tensor.dot(similarity, self.candidate_destination(**kwargs)) + + @predict.property('inputs') + def predict_inputs(self): + return self.inputs + + @cost.property('inputs') + def cost_inputs(self): + return self.inputs + ['destination_latitude', 'destination_longitude'] class StreamBase(object): def __init__(self, config): -- cgit v1.2.3