aboutsummaryrefslogtreecommitdiff
path: root/model/memory_network_bidir.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-07-24 16:09:48 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-07-24 16:09:48 -0400
commit7dab7e47ce0e8c5ae996821794450a9ad3186cd3 (patch)
treee0babcc305696a6e6a67a52acecd300bfdf22cf0 /model/memory_network_bidir.py
parent60e6bc64d8e3c6679a6e2a960513c656d481f0ed (diff)
downloadtaxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.tar.gz
taxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.zip
Fix memory network
Diffstat (limited to 'model/memory_network_bidir.py')
-rw-r--r--model/memory_network_bidir.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/model/memory_network_bidir.py b/model/memory_network_bidir.py
index cc99312..81e6440 100644
--- a/model/memory_network_bidir.py
+++ b/model/memory_network_bidir.py
@@ -72,22 +72,25 @@ class RecurrentEncoder(Initializable):
return outputs
+ @apply.property('inputs')
+ def apply_inputs(self):
+ return self.inputs
+
class Model(MemoryNetworkBase):
def __init__(self, config, **kwargs):
# Build prefix encoder : recurrent then MLP
- prefix_encoder = RecurrentEncoder(self.config.prefix_encoder,
- self.config.representation_size,
- self.config.representation_activation(),
+ prefix_encoder = RecurrentEncoder(config.prefix_encoder,
+ config.representation_size,
+ config.representation_activation(),
name='prefix_encoder')
# Build candidate encoder
- candidate_encoder = RecurrentEncoder(self.config.candidate_encoder,
- self.config.representation_size,
- self.config.representation_activation(),
+ candidate_encoder = RecurrentEncoder(config.candidate_encoder,
+ config.representation_size,
+ config.representation_activation(),
name='candidate_encoder')
# And... that's it!
super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs)
-