diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 12:49:03 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 12:49:18 -0400 |
commit | a4b190516d00428b1d8a81686a3291e5fa5f9865 (patch) | |
tree | 230f04cbb664d4f7138ca4f22839e6bf501b32be /train.py | |
parent | 859bee7196c78e9828d9182b5fea2ad2eab01f1d (diff) | |
download | taxi-a4b190516d00428b1d8a81686a3291e5fa5f9865.tar.gz taxi-a4b190516d00428b1d8a81686a3291e5fa5f9865.zip |
Make the testing into an extension run at each validation
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 69 |
1 files changed, 9 insertions, 60 deletions
@@ -1,6 +1,5 @@ #!/usr/bin/env python2 -import cPickle import importlib import logging import operator @@ -12,7 +11,7 @@ from theano import tensor from blocks import roles from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum -from blocks.extensions import Printing, FinishAfter, SimpleExtension +from blocks.extensions import Printing, FinishAfter from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring import blocks blocks.config.default_seed = 123 @@ -28,66 +27,11 @@ from blocks.graph import ComputationGraph, apply_dropout, apply_noise from blocks.main_loop import MainLoop from blocks.model import Model +from ext_saveload import SaveLoadParams +from ext_test import RunOnTest logger = logging.getLogger(__name__) - -class ElementwiseRemoveNotFinite(StepRule): - """A step rule that replaces non-finite coefficients by zeros. - - Replaces non-finite elements (such as ``inf`` or ``NaN``) in a step - (the parameter update of a single shared variable) - with a scaled version of the parameters being updated instead. - - Parameters - ---------- - scaler : float, optional - The scaling applied to the parameter in case the step contains - non-finite elements. Defaults to 0.1. - - Notes - ----- - This trick was originally used in the GroundHog_ framework. - - .. _GroundHog: https://github.com/lisa-groundhog/GroundHog - - """ - def __init__(self, scaler=0.1): - self.scaler = scaler - - def compute_step(self, param, previous_step): - not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step) - step = tensor.switch(not_finite, self.scaler * param, previous_step) - - return step, [] - -class SaveLoadParams(SimpleExtension): - def __init__(self, path, model, **kwargs): - super(SaveLoadParams, self).__init__(**kwargs) - - self.path = path - self.model = model - - def do_save(self): - with open(self.path, 'w') as f: - logger.info('Saving parameters to %s...'%self.path) - cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) - logger.info('Done saving.') - - def do_load(self): - try: - with open(self.path, 'r') as f: - logger.info('Loading parameters from %s...'%self.path) - self.model.set_param_values(cPickle.load(f)) - except IOError: - pass - - def do(self, which_callback, *args): - if which_callback == 'before_training': - self.do_load() - else: - self.do_save() - if __name__ == "__main__": if len(sys.argv) != 2: print >> sys.stderr, 'Usage: %s config' % sys.argv[0] @@ -143,7 +87,7 @@ if __name__ == "__main__": algorithm = GradientDescent( cost=cost, step_rule=CompositeRule([ - ElementwiseRemoveNotFinite(), + RemoveNotFinite(), step_rule ]), params=params) @@ -166,6 +110,11 @@ if __name__ == "__main__": after_epoch=True, # after epoch -> save params after_training=True, # after training -> save params ), + + RunOnTest(model_name, + model, + stream, + every_n_batches=1000), ] if use_plot: |