diff options
author | Thomas Mesnard <thomas.mesnard@ens.fr> | 2016-03-01 00:27:15 +0100 |
---|---|---|
committer | Thomas Mesnard <thomas.mesnard@ens.fr> | 2016-03-02 09:28:39 +0100 |
commit | f31caf61be87850f3afcd367d6eb9521b2f613da (patch) | |
tree | 2bcceeb702ef0d35bfdc925977797c40290b6966 /train.py | |
download | deepmind-qa-f31caf61be87850f3afcd367d6eb9521b2f613da.tar.gz deepmind-qa-f31caf61be87850f3afcd367d6eb9521b2f613da.zip |
Initial commit
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 112 |
1 files changed, 112 insertions, 0 deletions
diff --git a/train.py b/train.py new file mode 100755 index 0000000..84f9a4b --- /dev/null +++ b/train.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python + +import logging +import numpy +import sys +import os +import importlib + +import theano +from theano import tensor + +from blocks.extensions import Printing, SimpleExtension, FinishAfter, ProgressBar +from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring +from blocks.graph import ComputationGraph +from blocks.main_loop import MainLoop +from blocks.model import Model +from blocks.algorithms import GradientDescent + +try: + from blocks.extras.extensions.plot import Plot + plot_avail = True +except ImportError: + plot_avail = False + print "No plotting extension available." + +import data +from paramsaveload import SaveLoadParams + +logging.basicConfig(level='INFO') +logger = logging.getLogger(__name__) + +sys.setrecursionlimit(500000) + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, 'Usage: %s config' % sys.argv[0] + sys.exit(1) + model_name = sys.argv[1] + config = importlib.import_module('.%s' % model_name, 'config') + + # Build datastream + path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training") + valid_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/validation") + vocab_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt") + + ds, train_stream = data.setup_datastream(path, vocab_path, config) + _, valid_stream = data.setup_datastream(valid_path, vocab_path, config) + + dump_path = os.path.join("model_params", model_name+".pkl") + + # Build model + m = config.Model(config, ds.vocab_size) + + # Build the Blocks stuff for training + model = Model(m.sgd_cost) + + algorithm = GradientDescent(cost=m.sgd_cost, + step_rule=config.step_rule, + parameters=model.parameters) + + extensions = [ + TrainingDataMonitoring( + [v for l in m.monitor_vars for v in l], + prefix='train', + every_n_batches=config.print_freq) + ] + if config.save_freq is not None and dump_path is not None: + extensions += [ + SaveLoadParams(path=dump_path, + model=model, + before_training=True, + after_training=True, + after_epoch=True, + every_n_batches=config.save_freq) + ] + if valid_stream is not None and config.valid_freq != -1: + extensions += [ + DataStreamMonitoring( + [v for l in m.monitor_vars_valid for v in l], + valid_stream, + prefix='valid', + every_n_batches=config.valid_freq), + ] + if plot_avail: + plot_channels = [['train_' + v.name for v in lt] + ['valid_' + v.name for v in lv] + for lt, lv in zip(m.monitor_vars, m.monitor_vars_valid)] + extensions += [ + Plot(document='deepmind_qa_'+model_name, + channels=plot_channels, + # server_url='http://localhost:5006/', # If you need, change this + every_n_batches=config.print_freq) + ] + extensions += [ + Printing(every_n_batches=config.print_freq, + after_epoch=True), + ProgressBar() + ] + + main_loop = MainLoop( + model=model, + data_stream=train_stream, + algorithm=algorithm, + extensions=extensions + ) + + # Run the model ! + main_loop.run() + main_loop.profile.report() + + + +# vim: set sts=4 ts=4 sw=4 tw=0 et : |