From 389d8001be77e6cacb35804236fe9d3f0930282b Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 6 Jul 2015 10:40:23 -0400 Subject: Blocks compatibility --- train.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index e3244ec..8260218 100755 --- a/train.py +++ b/train.py @@ -21,7 +21,7 @@ blocks.config.default_seed = 123 fuel.config.default_seed = 123 try: - from blocks.extras.extensions.plotting import Plot + from blocks.extras.extensions.plot import Plot use_plot = True except ImportError: use_plot = False @@ -77,10 +77,10 @@ if __name__ == "__main__": logger.info('# Parameter shapes:') parameters_size = 0 - for key, value in cg.get_params().iteritems(): - logger.info(' %20s %s' % (value.get_value().shape, key)) + for value in cg.parameters: + logger.info(' %20s %s' % (value.get_value().shape, value.name)) parameters_size += reduce(operator.mul, value.get_value().shape, 1) - logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.get_params()))) + logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.parameters))) if hasattr(config, 'step_rule'): step_rule = config.step_rule @@ -97,9 +97,10 @@ if __name__ == "__main__": RemoveNotFinite(), step_rule ]), - params=params) + parameters=params) - plot_vars = [['valid_' + x.name for x in valid_monitored]] + plot_vars = [['valid_' + x.name for x in valid_monitored] + + ['train_' + x.name for x in valid_monitored]] logger.info('Plotted variables: %s' % str(plot_vars)) dump_path = os.path.join('model_data', model_name) + '.pkl' @@ -110,7 +111,7 @@ if __name__ == "__main__": prefix='valid', every_n_batches=1000), Printing(every_n_batches=1000), - # FinishAfter(every_n_batches=10), + FinishAfter(every_n_batches=10000000), SaveLoadParams(dump_path, cg, before_training=True, # before training -> load params @@ -126,7 +127,10 @@ if __name__ == "__main__": ] if use_plot: - extensions.append(Plot(model_name, channels=plot_vars, every_n_batches=500)) + extensions.append(Plot(model_name, + channels=plot_vars, + every_n_batches=500, + server_url='http://eos6:5006/')) main_loop = MainLoop( model=cg, -- cgit v1.2.3