diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-05-18 16:22:00 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-05-18 16:22:00 -0400 |
commit | 6d946f29f7548c75e97f30c4356dbac200ee6cce (patch) | |
tree | 387e586c7ad0c1a0167d21451c9a8c877cf3ef0e /model/__init__.py | |
parent | 1e6d08b0c9ac5983691b182631c71e9d46ee71cc (diff) | |
download | taxi-6d946f29f7548c75e97f30c4356dbac200ee6cce.tar.gz taxi-6d946f29f7548c75e97f30c4356dbac200ee6cce.zip |
Refactor models, clean the code and separate training from testing.
Diffstat (limited to 'model/__init__.py')
-rw-r--r-- | model/__init__.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/model/__init__.py b/model/__init__.py index e69de29..5c051f4 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -0,0 +1,36 @@ +from blocks.bricks import application, Initializable +from blocks.bricks.lookup import LookupTable + + +class ContextEmbedder(Initializable): + def __init__(self, config, **kwargs): + super(ContextEmbedder, self).__init__(**kwargs) + self.dim_embeddings = config.dim_embeddings + self.embed_weights_init = config.embed_weights_init + + self.inputs = [ name for (name, _, _) in self.dim_embeddings ] + self.outputs = [ '%s_embedded' % name for name in self.inputs ] + + self.lookups = { name: LookupTable(name='%s_lookup' % name) for name in self.inputs } + self.children = self.lookups.values() + + def _push_allocation_config(self): + for (name, num, dim) in self.dim_embeddings: + self.lookups[name].length = num + self.lookups[name].dim = dim + + def _push_initialization_config(self): + for name in self.inputs: + self.lookups[name].weights_init = self.embed_weights_init + + @application + def apply(self, **kwargs): + return tuple(self.lookups[name].apply(kwargs[name]) for name in self.inputs) + + @apply.property('inputs') + def apply_inputs(self): + return self.inputs + + @apply.property('outputs') + def apply_outputs(self): + return self.outputs |