diff options
Diffstat (limited to 'apply_model.py')
-rw-r--r-- | apply_model.py | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/apply_model.py b/apply_model.py new file mode 100644 index 0000000..f0156fa --- /dev/null +++ b/apply_model.py @@ -0,0 +1,43 @@ +import theano + +from blocks.graph import ComputationGraph + +class Apply(object): + def __init__(self, outputs, return_vars, stream): + if not isinstance(outputs, list): + outputs = [outputs] + if not isinstance(return_vars, list): + return_vars = [return_vars] + + self.outputs = outputs + self.return_vars = return_vars + self.stream = stream + + cg = ComputationGraph(self.outputs) + self.input_names = [i.name for i in cg.inputs] + self.f = theano.function(inputs=cg.inputs, outputs=self.outputs) + + def __iter__(self): + self.iterator = self.stream.get_epoch_iterator(as_dict=True) + while True: + try: + batch = next(self.iterator) + except StopIteration: + return + + inputs = [batch[n] for n in self.input_names] + outputs = self.f(*inputs) + + def find_retvar(name): + for idx, ov in enumerate(self.outputs): + if ov.name == name: + return outputs[idx] + + if name in batch: + return batch[name] + + raise ValueError('Variable ' + name + ' neither in outputs or in batch variables.') + + yield {name: find_retvar(name) for name in self.return_vars} + + |