diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-22 15:50:07 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-22 15:51:46 -0400 |
commit | c2c88c48a0404de0eb834df71fa53ae63fdfd1c7 (patch) | |
tree | 85df14a0e73a93710c3508a5657da1a4034abe2d /apply_model.py | |
parent | 6c45eb6e48775dcbbbd3177f02c1d1b0c161ba1e (diff) | |
download | taxi-c2c88c48a0404de0eb834df71fa53ae63fdfd1c7.tar.gz taxi-c2c88c48a0404de0eb834df71fa53ae63fdfd1c7.zip |
Delete useless file apply_model.py
Diffstat (limited to 'apply_model.py')
-rw-r--r-- | apply_model.py | 43 |
1 files changed, 0 insertions, 43 deletions
diff --git a/apply_model.py b/apply_model.py deleted file mode 100644 index f0156fa..0000000 --- a/apply_model.py +++ /dev/null @@ -1,43 +0,0 @@ -import theano - -from blocks.graph import ComputationGraph - -class Apply(object): - def __init__(self, outputs, return_vars, stream): - if not isinstance(outputs, list): - outputs = [outputs] - if not isinstance(return_vars, list): - return_vars = [return_vars] - - self.outputs = outputs - self.return_vars = return_vars - self.stream = stream - - cg = ComputationGraph(self.outputs) - self.input_names = [i.name for i in cg.inputs] - self.f = theano.function(inputs=cg.inputs, outputs=self.outputs) - - def __iter__(self): - self.iterator = self.stream.get_epoch_iterator(as_dict=True) - while True: - try: - batch = next(self.iterator) - except StopIteration: - return - - inputs = [batch[n] for n in self.input_names] - outputs = self.f(*inputs) - - def find_retvar(name): - for idx, ov in enumerate(self.outputs): - if ov.name == name: - return outputs[idx] - - if name in batch: - return batch[name] - - raise ValueError('Variable ' + name + ' neither in outputs or in batch variables.') - - yield {name: find_retvar(name) for name in self.return_vars} - - |