aboutsummaryrefslogtreecommitdiff
path: root/apply_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'apply_model.py')
-rw-r--r--apply_model.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/apply_model.py b/apply_model.py
new file mode 100644
index 0000000..f0156fa
--- /dev/null
+++ b/apply_model.py
@@ -0,0 +1,43 @@
+import theano
+
+from blocks.graph import ComputationGraph
+
+class Apply(object):
+ def __init__(self, outputs, return_vars, stream):
+ if not isinstance(outputs, list):
+ outputs = [outputs]
+ if not isinstance(return_vars, list):
+ return_vars = [return_vars]
+
+ self.outputs = outputs
+ self.return_vars = return_vars
+ self.stream = stream
+
+ cg = ComputationGraph(self.outputs)
+ self.input_names = [i.name for i in cg.inputs]
+ self.f = theano.function(inputs=cg.inputs, outputs=self.outputs)
+
+ def __iter__(self):
+ self.iterator = self.stream.get_epoch_iterator(as_dict=True)
+ while True:
+ try:
+ batch = next(self.iterator)
+ except StopIteration:
+ return
+
+ inputs = [batch[n] for n in self.input_names]
+ outputs = self.f(*inputs)
+
+ def find_retvar(name):
+ for idx, ov in enumerate(self.outputs):
+ if ov.name == name:
+ return outputs[idx]
+
+ if name in batch:
+ return batch[name]
+
+ raise ValueError('Variable ' + name + ' neither in outputs or in batch variables.')
+
+ yield {name: find_retvar(name) for name in self.return_vars}
+
+