aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--model/memory_network_bidir.py27
1 files changed, 17 insertions, 10 deletions
diff --git a/model/memory_network_bidir.py b/model/memory_network_bidir.py
index 81e6440..dd447bf 100644
--- a/model/memory_network_bidir.py
+++ b/model/memory_network_bidir.py
@@ -13,6 +13,8 @@ from model import ContextEmbedder
from memory_network import StreamRecurrent as Stream
from memory_network import MemoryNetworkBase
+from bidirectional import SegregatedBidirectional
+
class RecurrentEncoder(Initializable):
def __init__(self, config, output_dim, activation, **kwargs):
@@ -21,11 +23,12 @@ class RecurrentEncoder(Initializable):
self.config = config
self.context_embedder = ContextEmbedder(config)
- self.rec = Bidirectional(LSTM(dim=config.rec_state_dim, name='encoder_recurrent'))
- self.fork = Fork(
- [name for name in self.rec.prototype.apply.sequences
- if name != 'mask'],
- prototype=Linear())
+ self.rec = SegregatedBidirectional(LSTM(dim=config.rec_state_dim, name='encoder_recurrent'))
+
+ self.fwd_fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'],
+ prototype=Linear(), name='fwd_fork')
+ self.bkwd_fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'],
+ prototype=Linear(), name='bkwd_fork')
rto_in = config.rec_state_dim * 2 + sum(x[2] for x in config.dim_embeddings)
self.rec_to_output = MLP(
@@ -33,15 +36,16 @@ class RecurrentEncoder(Initializable):
dims=[rto_in] + config.dim_hidden + [output_dim],
name='encoder_rto')
- self.children = [self.context_embedder, self.rec, self.fork, self.rec_to_output]
+ self.children = [self.context_embedder, self.rec, self.fwd_fork, self.bkwd_fork, self.rec_to_output]
self.rec_inputs = ['latitude', 'longitude', 'latitude_mask']
self.inputs = self.context_embedder.inputs + self.rec_inputs
def _push_allocation_config(self):
- self.fork.input_dim = 2
- self.fork.output_dims = [ self.rec.children[0].get_dim(name)
- for name in self.fork.output_names ]
+ for i, fork in enumerate([self.fwd_fork, self.bkwd_fork]):
+ fork.input_dim = 2
+ fork.output_dims = [ self.rec.children[i].get_dim(name)
+ for name in fork.output_names ]
def _push_initialization_config(self):
for brick in self.children:
@@ -56,7 +60,10 @@ class RecurrentEncoder(Initializable):
rec_in = tensor.concatenate((latitude[:, :, None], longitude[:, :, None]),
axis=2)
- path = self.rec.apply(self.fork.apply(rec_in), mask=latitude_mask)[0]
+ path = self.rec.apply(merge(self.fwd_fork.apply(rec_in, as_dict=True),
+ {'mask': latitude_mask}),
+ merge(self.bkwd_fork.apply(rec_in, as_dict=True),
+ {'mask': latitude_mask}))[0]
last_id = tensor.cast(latitude_mask.sum(axis=0) - 1, dtype='int64')