diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 10:27:15 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 10:27:15 -0400 |
commit | 1e64a442e78b5e471b2f573295bd9a747b7c6c3f (patch) | |
tree | 836d804e16bb19cb08eab7bce4c7c2b5a5e7489a /model/memory_network.py | |
parent | 389bfd3637dfb523a3e4194c7281a0c538166546 (diff) | |
download | taxi-1e64a442e78b5e471b2f573295bd9a747b7c6c3f.tar.gz taxi-1e64a442e78b5e471b2f573295bd9a747b7c6c3f.zip |
Memory net refactoring
Diffstat (limited to 'model/memory_network.py')
-rw-r--r-- | model/memory_network.py | 49 |
1 files changed, 48 insertions, 1 deletions
diff --git a/model/memory_network.py b/model/memory_network.py index e7ba51c..84a8edf 100644 --- a/model/memory_network.py +++ b/model/memory_network.py @@ -14,11 +14,58 @@ import error from model import ContextEmbedder class MemoryNetworkBase(Initializable): - def __init__(self, config, **kwargs): + def __init__(self, config, prefix_encoder, candidate_encoder, **kwargs): super(MemoryNetworkBase, self).__init__(**kwargs) + self.prefix_encoder = prefix_encoder + self.candidate_encoder = candidate_encoder self.config = config + self.softmax = Softmax() + self.children = [ self.softmax, prefix_encoder, candidate_encoder ] + + self.inputs = self.prefix_encoder.apply.inputs \ + + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs] + + def candidate_destination(**kwargs): + return tensor.concatenate( + (tensor.shape_padright(kwargs['candidate_last_k_latitude'][:,-1]), + tensor.shape_padright(kwargs['candidate_last_k_longitude'][:,-1])), + axis=1) + + @application(outputs=['cost']) + def cost(self, **kwargs): + y_hat = self.predict(**kwargs) + y = tensor.concatenate((kwargs['destination_latitude'][:, None], + kwargs['destination_longitude'][:, None]), axis=1) + + return error.erdist(y_hat, y).mean() + + @application(outputs=['destination']) + def predict(self, **kwargs): + prefix_representation = self.prefix_encoder.apply( + { x: kwargs[x] for x in self.prefix_encoder.apply.inputs }) + candidate_representatin = self.candidate_encoder.apply( + { x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs }) + + if self.config.normalize_representation: + prefix_representation = prefix_representation \ + / tensor.sqrt((prefix_representation ** 2).sum(axis=1, keepdims=True)) + candidate_representation = candidate_representation \ + / tensor.sqrt((candidate_representation ** 2).sum(axis=1, keepdims=True)) + + similarity_score = tensor.dot(prefix_representation, candidate_representation.T) + similarity = self.softmax.apply(similarity_score) + + return tensor.dot(similarity, self.candidate_destination(**kwargs)) + + @predict.property('inputs') + def predict_inputs(self): + return self.inputs + + @cost.property('inputs') + def cost_inputs(self): + return self.inputs + ['destination_latitude', 'destination_longitude'] class StreamBase(object): def __init__(self, config): |