aboutsummaryrefslogtreecommitdiff
path: root/apply_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'apply_model.py')
-rw-r--r--apply_model.py43
1 files changed, 0 insertions, 43 deletions
diff --git a/apply_model.py b/apply_model.py
deleted file mode 100644
index f0156fa..0000000
--- a/apply_model.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import theano
-
-from blocks.graph import ComputationGraph
-
-class Apply(object):
- def __init__(self, outputs, return_vars, stream):
- if not isinstance(outputs, list):
- outputs = [outputs]
- if not isinstance(return_vars, list):
- return_vars = [return_vars]
-
- self.outputs = outputs
- self.return_vars = return_vars
- self.stream = stream
-
- cg = ComputationGraph(self.outputs)
- self.input_names = [i.name for i in cg.inputs]
- self.f = theano.function(inputs=cg.inputs, outputs=self.outputs)
-
- def __iter__(self):
- self.iterator = self.stream.get_epoch_iterator(as_dict=True)
- while True:
- try:
- batch = next(self.iterator)
- except StopIteration:
- return
-
- inputs = [batch[n] for n in self.input_names]
- outputs = self.f(*inputs)
-
- def find_retvar(name):
- for idx, ov in enumerate(self.outputs):
- if ov.name == name:
- return outputs[idx]
-
- if name in batch:
- return batch[name]
-
- raise ValueError('Variable ' + name + ' neither in outputs or in batch variables.')
-
- yield {name: find_retvar(name) for name in self.return_vars}
-
-