aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/train.py b/train.py
index 876fcba..94c00d2 100755
--- a/train.py
+++ b/train.py
@@ -100,6 +100,12 @@ if __name__ == "__main__":
cg = ComputationGraph(cost)
monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables))
+ valid_monitored = monitored
+ if hasattr(model, 'valid_cost'):
+ valid_cost = model.valid_cost(**inputs)
+ valid_cg = ComputationGraph(valid_cost)
+ valid_monitored = set([valid_cost] + VariableFilter(roles=[roles.COST])(valid_cg.variables))
+
if hasattr(config, 'dropout') and config.dropout < 1.0:
cg = apply_dropout(cg, config.dropout_inputs(cg), config.dropout)
if hasattr(config, 'noise') and config.noise > 0.0:
@@ -124,7 +130,7 @@ if __name__ == "__main__":
]),
params=params)
- plot_vars = [['valid_' + x.name for x in monitored]]
+ plot_vars = [['valid_' + x.name for x in valid_monitored]]
logger.info('Plotted variables: %s' % str(plot_vars))
dump_path = os.path.join('model_data', model_name)
@@ -136,13 +142,13 @@ if __name__ == "__main__":
dump_ext.manager = CustomDumpManager(dump_path)
extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=1000),
- DataStreamMonitoring(monitored, valid_stream,
+ DataStreamMonitoring(valid_monitored, valid_stream,
prefix='valid',
every_n_batches=1000),
Printing(every_n_batches=1000),
Plot(model_name, channels=plot_vars, every_n_batches=500),
load_dump_ext,
- dump_ext
+ dump_ext
]
main_loop = MainLoop(