diff options
Diffstat (limited to 'apply_model.py')
-rw-r--r-- | apply_model.py | 43 |
1 files changed, 0 insertions, 43 deletions
diff --git a/apply_model.py b/apply_model.py deleted file mode 100644 index f0156fa..0000000 --- a/apply_model.py +++ /dev/null @@ -1,43 +0,0 @@ -import theano - -from blocks.graph import ComputationGraph - -class Apply(object): - def __init__(self, outputs, return_vars, stream): - if not isinstance(outputs, list): - outputs = [outputs] - if not isinstance(return_vars, list): - return_vars = [return_vars] - - self.outputs = outputs - self.return_vars = return_vars - self.stream = stream - - cg = ComputationGraph(self.outputs) - self.input_names = [i.name for i in cg.inputs] - self.f = theano.function(inputs=cg.inputs, outputs=self.outputs) - - def __iter__(self): - self.iterator = self.stream.get_epoch_iterator(as_dict=True) - while True: - try: - batch = next(self.iterator) - except StopIteration: - return - - inputs = [batch[n] for n in self.input_names] - outputs = self.f(*inputs) - - def find_retvar(name): - for idx, ov in enumerate(self.outputs): - if ov.name == name: - return outputs[idx] - - if name in batch: - return batch[name] - - raise ValueError('Variable ' + name + ' neither in outputs or in batch variables.') - - yield {name: find_retvar(name) for name in self.return_vars} - - |