aboutsummaryrefslogtreecommitdiff
path: root/model/__init__.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-05-18 16:22:00 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-05-18 16:22:00 -0400
commit6d946f29f7548c75e97f30c4356dbac200ee6cce (patch)
tree387e586c7ad0c1a0167d21451c9a8c877cf3ef0e /model/__init__.py
parent1e6d08b0c9ac5983691b182631c71e9d46ee71cc (diff)
downloadtaxi-6d946f29f7548c75e97f30c4356dbac200ee6cce.tar.gz
taxi-6d946f29f7548c75e97f30c4356dbac200ee6cce.zip
Refactor models, clean the code and separate training from testing.
Diffstat (limited to 'model/__init__.py')
-rw-r--r--model/__init__.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/model/__init__.py b/model/__init__.py
index e69de29..5c051f4 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -0,0 +1,36 @@
+from blocks.bricks import application, Initializable
+from blocks.bricks.lookup import LookupTable
+
+
+class ContextEmbedder(Initializable):
+ def __init__(self, config, **kwargs):
+ super(ContextEmbedder, self).__init__(**kwargs)
+ self.dim_embeddings = config.dim_embeddings
+ self.embed_weights_init = config.embed_weights_init
+
+ self.inputs = [ name for (name, _, _) in self.dim_embeddings ]
+ self.outputs = [ '%s_embedded' % name for name in self.inputs ]
+
+ self.lookups = { name: LookupTable(name='%s_lookup' % name) for name in self.inputs }
+ self.children = self.lookups.values()
+
+ def _push_allocation_config(self):
+ for (name, num, dim) in self.dim_embeddings:
+ self.lookups[name].length = num
+ self.lookups[name].dim = dim
+
+ def _push_initialization_config(self):
+ for name in self.inputs:
+ self.lookups[name].weights_init = self.embed_weights_init
+
+ @application
+ def apply(self, **kwargs):
+ return tuple(self.lookups[name].apply(kwargs[name]) for name in self.inputs)
+
+ @apply.property('inputs')
+ def apply_inputs(self):
+ return self.inputs
+
+ @apply.property('outputs')
+ def apply_outputs(self):
+ return self.outputs