diff options
-rw-r--r-- | ext_saveload.py | 4 | ||||
-rwxr-xr-x | train.py | 20 |
2 files changed, 14 insertions, 10 deletions
diff --git a/ext_saveload.py b/ext_saveload.py index cc7c47a..059c5cf 100644 --- a/ext_saveload.py +++ b/ext_saveload.py @@ -15,14 +15,14 @@ class SaveLoadParams(SimpleExtension): def do_save(self): with open(self.path, 'w') as f: logger.info('Saving parameters to %s...'%self.path) - cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) + cPickle.dump(self.model.get_parameter_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) logger.info('Done saving.') def do_load(self): try: with open(self.path, 'r') as f: logger.info('Loading parameters from %s...'%self.path) - self.model.set_param_values(cPickle.load(f)) + self.model.set_parameter_values(cPickle.load(f)) except IOError: pass @@ -21,7 +21,7 @@ blocks.config.default_seed = 123 fuel.config.default_seed = 123 try: - from blocks.extras.extensions.plotting import Plot + from blocks.extras.extensions.plot import Plot use_plot = True except ImportError: use_plot = False @@ -77,10 +77,10 @@ if __name__ == "__main__": logger.info('# Parameter shapes:') parameters_size = 0 - for key, value in cg.get_params().iteritems(): - logger.info(' %20s %s' % (value.get_value().shape, key)) + for value in cg.parameters: + logger.info(' %20s %s' % (value.get_value().shape, value.name)) parameters_size += reduce(operator.mul, value.get_value().shape, 1) - logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.get_params()))) + logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.parameters))) if hasattr(config, 'step_rule'): step_rule = config.step_rule @@ -97,9 +97,10 @@ if __name__ == "__main__": RemoveNotFinite(), step_rule ]), - params=params) + parameters=params) - plot_vars = [['valid_' + x.name for x in valid_monitored]] + plot_vars = [['valid_' + x.name for x in valid_monitored] + + ['train_' + x.name for x in valid_monitored]] logger.info('Plotted variables: %s' % str(plot_vars)) dump_path = os.path.join('model_data', model_name) + '.pkl' @@ -110,7 +111,7 @@ if __name__ == "__main__": prefix='valid', every_n_batches=1000), Printing(every_n_batches=1000), - # FinishAfter(every_n_batches=10), + FinishAfter(every_n_batches=10000000), SaveLoadParams(dump_path, cg, before_training=True, # before training -> load params @@ -126,7 +127,10 @@ if __name__ == "__main__": ] if use_plot: - extensions.append(Plot(model_name, channels=plot_vars, every_n_batches=500)) + extensions.append(Plot(model_name, + channels=plot_vars, + every_n_batches=500, + server_url='http://eos6:5006/')) main_loop = MainLoop( model=cg, |