diff options
Diffstat (limited to 'model/__init__.py')
-rw-r--r-- | model/__init__.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/model/__init__.py b/model/__init__.py index e69de29..5c051f4 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -0,0 +1,36 @@ +from blocks.bricks import application, Initializable +from blocks.bricks.lookup import LookupTable + + +class ContextEmbedder(Initializable): + def __init__(self, config, **kwargs): + super(ContextEmbedder, self).__init__(**kwargs) + self.dim_embeddings = config.dim_embeddings + self.embed_weights_init = config.embed_weights_init + + self.inputs = [ name for (name, _, _) in self.dim_embeddings ] + self.outputs = [ '%s_embedded' % name for name in self.inputs ] + + self.lookups = { name: LookupTable(name='%s_lookup' % name) for name in self.inputs } + self.children = self.lookups.values() + + def _push_allocation_config(self): + for (name, num, dim) in self.dim_embeddings: + self.lookups[name].length = num + self.lookups[name].dim = dim + + def _push_initialization_config(self): + for name in self.inputs: + self.lookups[name].weights_init = self.embed_weights_init + + @application + def apply(self, **kwargs): + return tuple(self.lookups[name].apply(kwargs[name]) for name in self.inputs) + + @apply.property('inputs') + def apply_inputs(self): + return self.inputs + + @apply.property('outputs') + def apply_outputs(self): + return self.outputs |