diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-06-12 02:45:14 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-06-12 02:45:14 -0400 |
commit | 557d0fa74de74b8dbd8618a972725a7a9926e452 (patch) | |
tree | a3b63399c9353fc5e29ac793b754fa373a00b415 /train.py | |
parent | db0e57fc2a351cedef3b1270bf6047e9cae9fa9d (diff) | |
download | taxi-557d0fa74de74b8dbd8618a972725a7a9926e452.tar.gz taxi-557d0fa74de74b8dbd8618a972725a7a9926e452.zip |
Fix RNN validation
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 12 |
1 files changed, 9 insertions, 3 deletions
@@ -100,6 +100,12 @@ if __name__ == "__main__": cg = ComputationGraph(cost) monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables)) + valid_monitored = monitored + if hasattr(model, 'valid_cost'): + valid_cost = model.valid_cost(**inputs) + valid_cg = ComputationGraph(valid_cost) + valid_monitored = set([valid_cost] + VariableFilter(roles=[roles.COST])(valid_cg.variables)) + if hasattr(config, 'dropout') and config.dropout < 1.0: cg = apply_dropout(cg, config.dropout_inputs(cg), config.dropout) if hasattr(config, 'noise') and config.noise > 0.0: @@ -124,7 +130,7 @@ if __name__ == "__main__": ]), params=params) - plot_vars = [['valid_' + x.name for x in monitored]] + plot_vars = [['valid_' + x.name for x in valid_monitored]] logger.info('Plotted variables: %s' % str(plot_vars)) dump_path = os.path.join('model_data', model_name) @@ -136,13 +142,13 @@ if __name__ == "__main__": dump_ext.manager = CustomDumpManager(dump_path) extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=1000), - DataStreamMonitoring(monitored, valid_stream, + DataStreamMonitoring(valid_monitored, valid_stream, prefix='valid', every_n_batches=1000), Printing(every_n_batches=1000), Plot(model_name, channels=plot_vars, every_n_batches=500), load_dump_ext, - dump_ext + dump_ext ] main_loop = MainLoop( |