diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 12:49:03 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 12:49:18 -0400 |
commit | a4b190516d00428b1d8a81686a3291e5fa5f9865 (patch) | |
tree | 230f04cbb664d4f7138ca4f22839e6bf501b32be | |
parent | 859bee7196c78e9828d9182b5fea2ad2eab01f1d (diff) | |
download | taxi-a4b190516d00428b1d8a81686a3291e5fa5f9865.tar.gz taxi-a4b190516d00428b1d8a81686a3291e5fa5f9865.zip |
Make the testing into an extension run at each validation
-rw-r--r-- | ext_saveload.py | 33 | ||||
-rw-r--r-- | ext_test.py | 66 | ||||
-rwxr-xr-x | test.py | 54 | ||||
-rwxr-xr-x | train.py | 69 |
4 files changed, 108 insertions, 114 deletions
diff --git a/ext_saveload.py b/ext_saveload.py new file mode 100644 index 0000000..cc7c47a --- /dev/null +++ b/ext_saveload.py @@ -0,0 +1,33 @@ +import cPickle +import logging + +from blocks.extensions import SimpleExtension + +logger = logging.getLogger(__name__) + +class SaveLoadParams(SimpleExtension): + def __init__(self, path, model, **kwargs): + super(SaveLoadParams, self).__init__(**kwargs) + + self.path = path + self.model = model + + def do_save(self): + with open(self.path, 'w') as f: + logger.info('Saving parameters to %s...'%self.path) + cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) + logger.info('Done saving.') + + def do_load(self): + try: + with open(self.path, 'r') as f: + logger.info('Loading parameters from %s...'%self.path) + self.model.set_param_values(cPickle.load(f)) + except IOError: + pass + + def do(self, which_callback, *args): + if which_callback == 'before_training': + self.do_load() + else: + self.do_save() diff --git a/ext_test.py b/ext_test.py new file mode 100644 index 0000000..6a3fa0a --- /dev/null +++ b/ext_test.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python + +import logging +import os +import csv + +from blocks.model import Model +from blocks.extensions import SimpleExtension + +logger = logging.getLogger(__name__) + +class RunOnTest(SimpleExtension): + def __init__(self, model_name, model, stream, **kwargs): + super(RunOnTest, self).__init__(**kwargs) + + self.model_name = model_name + + cg = Model(model.predict(**stream.inputs())) + + self.inputs = cg.inputs + self.outputs = model.predict.outputs + + req_vars_test = model.predict.inputs + ['trip_id'] + self.test_stream = stream.test(req_vars_test) + + self.function = cg.get_theano_function() + + def do(self, which_callback, *args): + iter_no = repr(self.main_loop.log.status['iterations_done']) + if 'valid_destination_cost' in self.main_loop.log.current_row: + dvc = self.main_loop.log.current_row['valid_destination_cost'] + else: + dvc = self.main_loop.log.current_row['valid_model_cost_cost'] + if 'valid_time_cost' in self.main_loop.log.current_row: + tvc = self.main_loop.log.current_row['valid_time_cost'] + else: + tvc = self.main_loop.log.current_row['valid_model_cost_cost'] + + if 'destination' in self.outputs: + dest_outname = 'test-dest-%s-it%s-cost%.3f.csv' % (self.model_name, iter_no, dvc) + dest_outfile = open(os.path.join('output', dest_outname), 'w') + dest_outcsv = csv.writer(dest_outfile) + dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) + logger.info("Generating output for test set: %s" % dest_outname) + if 'duration' in self.outputs: + time_outname = 'test-time-%s-it%s-cost%.3f.csv' % (self.model_name, iter_no, tvc) + time_outfile = open(os.path.join('output', time_outname), 'w') + time_outcsv = csv.writer(time_outfile) + time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"]) + logger.info("Generating output for test set: %s" % time_outname) + + for d in self.test_stream.get_epoch_iterator(as_dict=True): + input_values = [d[k.name] for k in self.inputs] + output_values = self.function(*input_values) + if 'destination' in self.outputs: + destination = output_values[self.outputs.index('destination')] + dest_outcsv.writerow([d['trip_id'][0], destination[0, 0], destination[0, 1]]) + if 'duration' in self.outputs: + duration = output_values[self.outputs.index('duration')] + time_outcsv.writerow([d['trip_id'][0], int(round(duration[0]))]) + + if 'destination' in self.outputs: + dest_outfile.close() + if 'duration' in self.outputs: + time_outfile.close() + diff --git a/test.py b/test.py deleted file mode 100755 index 4925b27..0000000 --- a/test.py +++ /dev/null @@ -1,54 +0,0 @@ -#!/usr/bin/env python - -import cPickle -import sys -import os -import importlib -import csv - -from blocks.model import Model - - -if __name__ == "__main__": - if len(sys.argv) != 2: - print >> sys.stderr, 'Usage: %s config' % sys.argv[0] - sys.exit(1) - model_name = sys.argv[1] - config = importlib.import_module('.%s' % model_name, 'config') - model_config = config.Model(config) - - stream = config.Stream(config) - inputs = stream.inputs() - outputs = model_config.predict.outputs - req_vars_test = model_config.predict.inputs + ['trip_id'] - test_stream = stream.test(req_vars_test) - - model = Model(model_config.predict(**inputs)) - with open(os.path.join('model_data', "{}.pkl".format(model_name))) as f: - parameters = cPickle.load(f) - model.set_param_values(parameters) - - if 'destination' in outputs: - dest_outfile = open(os.path.join('output', 'test-dest-output-%s.csv' % model_name), 'w') - dest_outcsv = csv.writer(dest_outfile) - dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"]) - if 'duration' in outputs: - time_outfile = open(os.path.join('output', 'test-time-output-%s.csv' % model_name), 'w') - time_outcsv = csv.writer(time_outfile) - time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"]) - - function = model.get_theano_function() - for d in test_stream.get_epoch_iterator(as_dict=True): - input_values = [d[k.name] for k in model.inputs] - output_values = function(*input_values) - if 'destination' in outputs: - destination = output_values[outputs.index('destination')] - dest_outcsv.writerow([d['trip_id'][0], destination[0, 0], destination[0, 1]]) - if 'duration' in outputs: - duration = output_values[outputs.index('duration')] - time_outcsv.writerow([d['trip_id'][0], int(round(duration[0]))]) - - if 'destination' in outputs: - dest_outfile.close() - if 'duration' in outputs: - time_outfile.close() @@ -1,6 +1,5 @@ #!/usr/bin/env python2 -import cPickle import importlib import logging import operator @@ -12,7 +11,7 @@ from theano import tensor from blocks import roles from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum -from blocks.extensions import Printing, FinishAfter, SimpleExtension +from blocks.extensions import Printing, FinishAfter from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring import blocks blocks.config.default_seed = 123 @@ -28,66 +27,11 @@ from blocks.graph import ComputationGraph, apply_dropout, apply_noise from blocks.main_loop import MainLoop from blocks.model import Model +from ext_saveload import SaveLoadParams +from ext_test import RunOnTest logger = logging.getLogger(__name__) - -class ElementwiseRemoveNotFinite(StepRule): - """A step rule that replaces non-finite coefficients by zeros. - - Replaces non-finite elements (such as ``inf`` or ``NaN``) in a step - (the parameter update of a single shared variable) - with a scaled version of the parameters being updated instead. - - Parameters - ---------- - scaler : float, optional - The scaling applied to the parameter in case the step contains - non-finite elements. Defaults to 0.1. - - Notes - ----- - This trick was originally used in the GroundHog_ framework. - - .. _GroundHog: https://github.com/lisa-groundhog/GroundHog - - """ - def __init__(self, scaler=0.1): - self.scaler = scaler - - def compute_step(self, param, previous_step): - not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step) - step = tensor.switch(not_finite, self.scaler * param, previous_step) - - return step, [] - -class SaveLoadParams(SimpleExtension): - def __init__(self, path, model, **kwargs): - super(SaveLoadParams, self).__init__(**kwargs) - - self.path = path - self.model = model - - def do_save(self): - with open(self.path, 'w') as f: - logger.info('Saving parameters to %s...'%self.path) - cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) - logger.info('Done saving.') - - def do_load(self): - try: - with open(self.path, 'r') as f: - logger.info('Loading parameters from %s...'%self.path) - self.model.set_param_values(cPickle.load(f)) - except IOError: - pass - - def do(self, which_callback, *args): - if which_callback == 'before_training': - self.do_load() - else: - self.do_save() - if __name__ == "__main__": if len(sys.argv) != 2: print >> sys.stderr, 'Usage: %s config' % sys.argv[0] @@ -143,7 +87,7 @@ if __name__ == "__main__": algorithm = GradientDescent( cost=cost, step_rule=CompositeRule([ - ElementwiseRemoveNotFinite(), + RemoveNotFinite(), step_rule ]), params=params) @@ -166,6 +110,11 @@ if __name__ == "__main__": after_epoch=True, # after epoch -> save params after_training=True, # after training -> save params ), + + RunOnTest(model_name, + model, + stream, + every_n_batches=1000), ] if use_plot: |