aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-22 16:06:25 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-22 16:06:25 -0400
commit6744c6e7b69c206c19954d65981a5139ee61363e (patch)
tree7cbf0d9be24efb37aad25f49a979db6ae8bd74af
parent57fe795d14e70c06c9bdbe6fe903588b5f75474e (diff)
downloadtaxi-6744c6e7b69c206c19954d65981a5139ee61363e.tar.gz
taxi-6744c6e7b69c206c19954d65981a5139ee61363e.zip
Fix validation cost computation to not use dropout/noise regularization
-rwxr-xr-xtrain.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/train.py b/train.py
index 9f636c0..65f9050 100755
--- a/train.py
+++ b/train.py
@@ -78,12 +78,11 @@ if __name__ == "__main__":
cost = model.cost(**inputs)
cg = ComputationGraph(cost)
- unmonitor = set()
+ monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables))
+
if hasattr(config, 'dropout') and config.dropout < 1.0:
- unmonitor.update(VariableFilter(roles=[roles.COST])(cg.variables))
cg = apply_dropout(cg, config.dropout_inputs(cg), config.dropout)
if hasattr(config, 'noise') and config.noise > 0.0:
- unmonitor.update(VariableFilter(roles=[roles.COST])(cg.variables))
cg = apply_noise(cg, config.noise_inputs(cg), config.noise)
cost = cg.outputs[0]
cg = Model(cost)
@@ -105,7 +104,6 @@ if __name__ == "__main__":
]),
params=params)
- monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables)) - unmonitor
plot_vars = [['valid_' + x.name for x in monitored]]
logger.info('Plotted variables: %s' % str(plot_vars))