From 6d946f29f7548c75e97f30c4356dbac200ee6cce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Mon, 18 May 2015 16:22:00 -0400 Subject: Refactor models, clean the code and separate training from testing. --- model/__init__.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) (limited to 'model/__init__.py') diff --git a/model/__init__.py b/model/__init__.py index e69de29..5c051f4 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -0,0 +1,36 @@ +from blocks.bricks import application, Initializable +from blocks.bricks.lookup import LookupTable + + +class ContextEmbedder(Initializable): + def __init__(self, config, **kwargs): + super(ContextEmbedder, self).__init__(**kwargs) + self.dim_embeddings = config.dim_embeddings + self.embed_weights_init = config.embed_weights_init + + self.inputs = [ name for (name, _, _) in self.dim_embeddings ] + self.outputs = [ '%s_embedded' % name for name in self.inputs ] + + self.lookups = { name: LookupTable(name='%s_lookup' % name) for name in self.inputs } + self.children = self.lookups.values() + + def _push_allocation_config(self): + for (name, num, dim) in self.dim_embeddings: + self.lookups[name].length = num + self.lookups[name].dim = dim + + def _push_initialization_config(self): + for name in self.inputs: + self.lookups[name].weights_init = self.embed_weights_init + + @application + def apply(self, **kwargs): + return tuple(self.lookups[name].apply(kwargs[name]) for name in self.inputs) + + @apply.property('inputs') + def apply_inputs(self): + return self.inputs + + @apply.property('outputs') + def apply_outputs(self): + return self.outputs -- cgit v1.2.3