aboutsummaryrefslogtreecommitdiff
path: root/model/memory_network.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/memory_network.py')
-rw-r--r--model/memory_network.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/model/memory_network.py b/model/memory_network.py
index 5acbfe3..1afc9cb 100644
--- a/model/memory_network.py
+++ b/model/memory_network.py
@@ -21,11 +21,11 @@ class Model(Initializable):
self.context_embedder = ContextEmbedder(config)
- self.prefix_encoder = MLP(activations=[Rectifier() for _ in config.prefix_encoder.dim_hidden],
- dims=[config.prefix_encoder.dim_input] + config.prefix_encoder.dim_hidden,
+ self.prefix_encoder = MLP(activations=[Rectifier() for _ in config.prefix_encoder.dim_hidden] + [config.representation_activation()],
+ dims=[config.prefix_encoder.dim_input] + config.prefix_encoder.dim_hidden + [config.representation_size],
name='prefix_encoder')
- self.candidate_encoder = MLP(activations=[Rectifier() for _ in config.candidate_encoder.dim_hidden],
- dims=[config.candidate_encoder.dim_input] + config.candidate_encoder.dim_hidden,
+ self.candidate_encoder = MLP(activations=[Rectifier() for _ in config.candidate_encoder.dim_hidden] + [config.representation_activation()],
+ dims=[config.candidate_encoder.dim_input] + config.candidate_encoder.dim_hidden + [config.representation_size],
name='candidate_encoder')
self.softmax = Softmax()
@@ -46,11 +46,15 @@ class Model(Initializable):
prefix_extremities = tuple((kwargs[k] - data.train_gps_mean[v]) / data.train_gps_std[v] for k, v in self.prefix_extremities.items())
prefix_inputs = tensor.concatenate(prefix_extremities + prefix_embeddings, axis=1)
prefix_representation = self.prefix_encoder.apply(prefix_inputs)
+ if self.config.normalize_representation:
+ prefix_representation = prefix_representation / tensor.sqrt((prefix_representation ** 2).sum(axis=1, keepdims=True))
candidate_embeddings = tuple(self.context_embedder.apply(**{k: kwargs['candidate_%s'%k] for k in self.context_embedder.inputs }))
candidate_extremities = tuple((kwargs[k] - data.train_gps_mean[v]) / data.train_gps_std[v] for k, v in self.candidate_extremities.items())
candidate_inputs = tensor.concatenate(candidate_extremities + candidate_embeddings, axis=1)
candidate_representation = self.candidate_encoder.apply(candidate_inputs)
+ if self.config.normalize_representation:
+ candidate_representation = candidate_representation / tensor.sqrt((candidate_representation ** 2).sum(axis=1, keepdims=True))
similarity_score = tensor.dot(prefix_representation, candidate_representation.T)
similarity = self.softmax.apply(similarity_score)