diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-07-24 16:09:48 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-07-24 16:09:48 -0400 |
commit | 7dab7e47ce0e8c5ae996821794450a9ad3186cd3 (patch) | |
tree | e0babcc305696a6e6a67a52acecd300bfdf22cf0 /model/memory_network_bidir.py | |
parent | 60e6bc64d8e3c6679a6e2a960513c656d481f0ed (diff) | |
download | taxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.tar.gz taxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.zip |
Fix memory network
Diffstat (limited to 'model/memory_network_bidir.py')
-rw-r--r-- | model/memory_network_bidir.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/model/memory_network_bidir.py b/model/memory_network_bidir.py index cc99312..81e6440 100644 --- a/model/memory_network_bidir.py +++ b/model/memory_network_bidir.py @@ -72,22 +72,25 @@ class RecurrentEncoder(Initializable): return outputs + @apply.property('inputs') + def apply_inputs(self): + return self.inputs + class Model(MemoryNetworkBase): def __init__(self, config, **kwargs): # Build prefix encoder : recurrent then MLP - prefix_encoder = RecurrentEncoder(self.config.prefix_encoder, - self.config.representation_size, - self.config.representation_activation(), + prefix_encoder = RecurrentEncoder(config.prefix_encoder, + config.representation_size, + config.representation_activation(), name='prefix_encoder') # Build candidate encoder - candidate_encoder = RecurrentEncoder(self.config.candidate_encoder, - self.config.representation_size, - self.config.representation_activation(), + candidate_encoder = RecurrentEncoder(config.candidate_encoder, + config.representation_size, + config.representation_activation(), name='candidate_encoder') # And... that's it! super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs) - |