From 0e52813a4e435cddeb80ec9972b2bf1fc791c0cc Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Sun, 21 Jun 2015 14:33:47 -0400 Subject: Updates for Blocks compatibility --- data/hdf5.py | 2 +- train.py | 76 ++++++++++++++++++++++++++++++++++-------------------------- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/data/hdf5.py b/data/hdf5.py index 6e2f9a4..f3d6da2 100644 --- a/data/hdf5.py +++ b/data/hdf5.py @@ -13,7 +13,7 @@ class TaxiDataset(H5PYDataset): def __init__(self, which_set, filename='data.hdf5', **kwargs): self.filename = filename kwargs.setdefault('load_in_memory', True) - super(TaxiDataset, self).__init__(self.data_path, which_set, **kwargs) + super(TaxiDataset, self).__init__(self.data_path, (which_set,), **kwargs) @property def data_path(self): diff --git a/train.py b/train.py index 94c00d2..0bc697b 100755 --- a/train.py +++ b/train.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python2 import importlib import logging @@ -11,11 +11,15 @@ from theano import tensor from blocks import roles from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule -from blocks.extensions import Printing, FinishAfter +from blocks.extensions import Printing, FinishAfter, SimpleExtension from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring -from blocks.extensions.plot import Plot -from blocks.extensions.saveload import LoadFromDump, Dump -from blocks.dump import MainLoopDumpManager + +try: + from blocks.extensions.plot import Plot + use_plot = True +except ImportError: + use_plot = False + from blocks.filter import VariableFilter from blocks.graph import ComputationGraph, apply_dropout, apply_noise from blocks.main_loop import MainLoop @@ -54,25 +58,31 @@ class ElementwiseRemoveNotFinite(StepRule): return step, [] -class CustomDumpManager(MainLoopDumpManager): - def dump(self, main_loop): - """Dumps the main loop to the root folder. - See :mod:`blocks.dump`. - Overwrites the old data if present. - """ - if not os.path.exists(self.folder): - os.mkdir(self.folder) - self.dump_parameters(main_loop) - self.dump_log(main_loop) - - def load(self): - return (self.load_parameters(), self.load_log()) - - def load_to(self, main_loop): - """Loads the dump from the root folder into the main loop.""" - parameters, log = self.load() - main_loop.model.set_param_values(parameters) - main_loop.log = log +class SaveLoadParams(SimpleExtension): + def __init__(self, path, model, **kwargs): + super(SaveLoadParams, self).__init__(**kwargs) + + self.path = path + self.model = model + + def do_save(self): + with open(self.path, 'w') as f: + logger.info('Saving parameters to %s...'%self.path) + cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) + + def do_load(self): + try: + with open(self.path, 'r') as f: + logger.info('Loading parameters from %s...'%self.path) + self.model.set_param_values(cPickle.load(f)) + except IOError: + pass + + def do(self, which_callback, *args): + if which_callback == 'before_training': + self.do_load() + else: + self.do_save() if __name__ == "__main__": if len(sys.argv) != 2: @@ -133,23 +143,23 @@ if __name__ == "__main__": plot_vars = [['valid_' + x.name for x in valid_monitored]] logger.info('Plotted variables: %s' % str(plot_vars)) - dump_path = os.path.join('model_data', model_name) + dump_path = os.path.join('model_data', model_name) + '.pkl' logger.info('Dump path: %s' % dump_path) - load_dump_ext = LoadFromDump(dump_path) - dump_ext = Dump(dump_path, every_n_batches=1000) - load_dump_ext.manager = CustomDumpManager(dump_path) - dump_ext.manager = CustomDumpManager(dump_path) - extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=1000), DataStreamMonitoring(valid_monitored, valid_stream, prefix='valid', every_n_batches=1000), Printing(every_n_batches=1000), - Plot(model_name, channels=plot_vars, every_n_batches=500), - load_dump_ext, - dump_ext + + SaveLoadParams(dump_path, cg, + before_training=True, # before training -> load params + every_n_batches=1000, # every N batches -> save params + ), ] + + if use_plot: + extensions.append(Plot(model_name, channels=plot_vars, every_n_batches=500)) main_loop = MainLoop( model=cg, -- cgit v1.2.3