blob: 5c051f46a186c13993ae3ad0455c46e5e2ff05f1 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
from blocks.bricks import application, Initializable
from blocks.bricks.lookup import LookupTable
class ContextEmbedder(Initializable):
    """Brick that embeds several named inputs, one lookup table per input.

    Parameters
    ----------
    config : object
        Must expose two attributes:

        * ``dim_embeddings`` -- an iterable of ``(name, num, dim)``
          triples. For each triple a ``LookupTable`` called
          ``'<name>_lookup'`` is created, and ``num`` / ``dim`` are later
          pushed to it as its ``length`` / ``dim``.
        * ``embed_weights_init`` -- the value pushed to every lookup
          table as its ``weights_init``.
    **kwargs
        Forwarded unchanged to the parent brick constructor.
    """

    def __init__(self, config, **kwargs):
        super(ContextEmbedder, self).__init__(**kwargs)
        self.dim_embeddings = config.dim_embeddings
        self.embed_weights_init = config.embed_weights_init

        # Input names follow the order of ``dim_embeddings``; each output
        # name is the matching input name with an ``_embedded`` suffix.
        self.inputs = [name for (name, _, _) in self.dim_embeddings]
        self.outputs = ['%s_embedded' % name for name in self.inputs]

        self.lookups = {name: LookupTable(name='%s_lookup' % name)
                        for name in self.inputs}
        # Materialise a real list: on Python 3 ``dict.values()`` is a view
        # object, which is neither indexable/appendable nor picklable.
        self.children = list(self.lookups.values())

    def _push_allocation_config(self):
        """Copy each ``(num, dim)`` pair onto the matching lookup table."""
        for (name, num, dim) in self.dim_embeddings:
            self.lookups[name].length = num
            self.lookups[name].dim = dim

    def _push_initialization_config(self):
        """Give every lookup table the shared ``embed_weights_init``."""
        for name in self.inputs:
            self.lookups[name].weights_init = self.embed_weights_init

    @application
    def apply(self, **kwargs):
        """Apply each lookup table to the keyword argument of the same name.

        Every name in ``self.inputs`` must be present in ``kwargs``
        (a missing one raises ``KeyError``). Returns a tuple of the lookup
        results, ordered like ``self.inputs``.
        """
        return tuple(self.lookups[name].apply(kwargs[name])
                     for name in self.inputs)

    @apply.property('inputs')
    def apply_inputs(self):
        """Names of the inputs accepted by ``apply``."""
        return self.inputs

    @apply.property('outputs')
    def apply_outputs(self):
        """Names of the outputs produced by ``apply``."""
        return self.outputs
|