aboutsummaryrefslogtreecommitdiff
path: root/model/memory_network_bidir.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/memory_network_bidir.py')
-rw-r--r--model/memory_network_bidir.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/model/memory_network_bidir.py b/model/memory_network_bidir.py
index cc99312..81e6440 100644
--- a/model/memory_network_bidir.py
+++ b/model/memory_network_bidir.py
@@ -72,22 +72,25 @@ class RecurrentEncoder(Initializable):
return outputs
+ @apply.property('inputs')
+ def apply_inputs(self):
+ return self.inputs
+
class Model(MemoryNetworkBase):
def __init__(self, config, **kwargs):
# Build prefix encoder : recurrent then MLP
- prefix_encoder = RecurrentEncoder(self.config.prefix_encoder,
- self.config.representation_size,
- self.config.representation_activation(),
+ prefix_encoder = RecurrentEncoder(config.prefix_encoder,
+ config.representation_size,
+ config.representation_activation(),
name='prefix_encoder')
# Build candidate encoder
- candidate_encoder = RecurrentEncoder(self.config.candidate_encoder,
- self.config.representation_size,
- self.config.representation_activation(),
+ candidate_encoder = RecurrentEncoder(config.candidate_encoder,
+ config.representation_size,
+ config.representation_activation(),
name='candidate_encoder')
# And... that's it!
super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs)
-