diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-22 16:06:25 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-22 16:06:25 -0400 |
commit | 6744c6e7b69c206c19954d65981a5139ee61363e (patch) | |
tree | 7cbf0d9be24efb37aad25f49a979db6ae8bd74af | |
parent | 57fe795d14e70c06c9bdbe6fe903588b5f75474e (diff) | |
download | taxi-6744c6e7b69c206c19954d65981a5139ee61363e.tar.gz taxi-6744c6e7b69c206c19954d65981a5139ee61363e.zip |
Fix validation cost computation to not use dropout/noise regularization
-rwxr-xr-x | train.py | 6 |
1 files changed, 2 insertions, 4 deletions
@@ -78,12 +78,11 @@ if __name__ == "__main__": cost = model.cost(**inputs) cg = ComputationGraph(cost) - unmonitor = set() + monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables)) + if hasattr(config, 'dropout') and config.dropout < 1.0: - unmonitor.update(VariableFilter(roles=[roles.COST])(cg.variables)) cg = apply_dropout(cg, config.dropout_inputs(cg), config.dropout) if hasattr(config, 'noise') and config.noise > 0.0: - unmonitor.update(VariableFilter(roles=[roles.COST])(cg.variables)) cg = apply_noise(cg, config.noise_inputs(cg), config.noise) cost = cg.outputs[0] cg = Model(cost) @@ -105,7 +104,6 @@ if __name__ == "__main__": ]), params=params) - monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables)) - unmonitor plot_vars = [['valid_' + x.name for x in monitored]] logger.info('Plotted variables: %s' % str(plot_vars)) |