From 12304944033d20bbc5c1b3f5cb90cf8dedebcdff Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 17 Jun 2015 14:58:38 -0400 Subject: paramsaveload --- paramsaveload.py | 37 +++++++++++++++++++++++++++++++++++++ train.py | 49 ++++++++++++++++++++++++++++++++----------------- 2 files changed, 69 insertions(+), 17 deletions(-) create mode 100644 paramsaveload.py diff --git a/paramsaveload.py b/paramsaveload.py new file mode 100644 index 0000000..7181e9a --- /dev/null +++ b/paramsaveload.py @@ -0,0 +1,37 @@ +import logging + +import numpy + +import cPickle + +from blocks.extensions import SimpleExtension + +logging.basicConfig(level='INFO') +logger = logging.getLogger('extensions.SaveLoadParams') + +class SaveLoadParams(SimpleExtension): + def __init__(self, path, model, **kwargs): + super(SaveLoadParams, self).__init__(**kwargs) + + self.path = path + self.model = model + + def do_save(self): + with open(self.path, 'w') as f: + logger.info('Saving parameters to %s...'%self.path) + cPickle.dump(model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) + + def do_load(self): + try: + with open(self.path, 'r') as f: + logger.info('Loading parameters from %s...'%self.path) + model.set_parma_values(cPickle.load(f)) + except IOError: + pass + + def do(self, which_callback, *args): + if which_callback == 'before_training': + self.do_load() + else: + self.do_save() + diff --git a/train.py b/train.py index 79b2116..525724f 100755 --- a/train.py +++ b/train.py @@ -14,7 +14,7 @@ from theano.tensor.shared_randomstreams import RandomStreams from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER from blocks.extensions import Printing, SimpleExtension from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring -from blocks.extras.extensions.plot import Plot +# from blocks.extras.extensions.plot import Plot from blocks.extensions.saveload import Checkpoint, Load from blocks.graph import ComputationGraph from blocks.main_loop import MainLoop @@ -22,6 +22,7 @@ from blocks.model import Model from blocks.algorithms import GradientDescent import datastream +import paramsaveload import gentext import ircext @@ -60,17 +61,31 @@ def train_model(m, train_stream, dump_path=None): algorithm.add_updates(m.states) - # Load the parameters from a dumped model - if dump_path is not None: - try: - with closing(numpy.load(dump_path)) as source: - logger.info('Loading parameters...') - param_values = {'/' + name.replace(BRICK_DELIMITER, '/'): source[name] - for name in source.keys() - if name != 'pkl' and not 'None' in name} - model.set_param_values(param_values) - except IOError: - pass + extensions = [] + if config.save_freq is not None and dump_path is not None: + extensions.append( + SaveLoadParams(path=dump_path, + model=model, + before_training=True, + after_epoch=False, + every_n_epochs=config.save_freq) + ) + if config.sample_freq is not None: + extensions.append( + gentext.GenText(m, '\nalex\ttu crois ?\n', + config.sample_len, config.sample_temperature, + every_n_epochs=config.sample_freq, + after_epoch=False, before_training=True) + ) + if config.on_irc: + extensions.append( + ircext.IRCClientExt(m, config.sample_temperature, + server='irc.ulminfo.fr', + port=6667, + nick='frigo', + channels=['#frigotest', '#courssysteme'], + after_batch=True) + ) extensions = [] if config.save_freq is not None: @@ -106,11 +121,11 @@ def train_model(m, train_stream, dump_path=None): [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate], prefix='train', every_n_epochs=1), Printing(every_n_epochs=1, after_epoch=False), - Plot(document='text_'+model_name+'_'+config.param_desc, - channels=[['train_cost', 'train_cost_reg'], - ['train_error_rate', 'train_error_rate_reg']], - server_url='http://eos21:4201/', - every_n_epochs=1, after_epoch=False), + # Plot(document='text_'+model_name+'_'+config.param_desc, + # channels=[['train_cost', 'train_cost_reg'], + # ['train_error_rate', 'train_error_rate_reg']], + # server_url='http://eos21:4201/', + # every_n_epochs=1, after_epoch=False), ResetStates([v for v, _ in m.states], after_epoch=True) ] -- cgit v1.2.3