#!/usr/bin/env python
"""Train a model on the deepmind-qa CNN question set with Blocks.

Usage::

    <script> config

``config`` is the name of a module inside the ``config`` package.  The module
must provide ``Model``, ``step_rule``, ``print_freq``, ``save_freq`` and
``valid_freq``; the object returned by ``config.Model(config, vocab_size)``
must expose ``sgd_cost``, ``monitor_vars`` and ``monitor_vars_valid`` (the
latter two being lists of lists of named variables).

The ``DATAPATH`` environment variable must point at the directory that
contains the ``deepmind-qa`` data set.
"""

import logging
import numpy
import sys
import os
import importlib

import theano
from theano import tensor

from blocks.extensions import Printing, SimpleExtension, FinishAfter, ProgressBar
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
from blocks.algorithms import GradientDescent

try:
    from blocks.extras.extensions.plot import Plot
    plot_avail = True
except ImportError:
    plot_avail = False
    # Single-argument call form: prints the same text under Python 2 and 3.
    print("No plotting extension available.")

import data
from paramsaveload import SaveLoadParams

logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)

# NOTE(review): presumably raised so that very deep computation graphs can be
# traversed/serialized without hitting the default limit -- confirm.
sys.setrecursionlimit(500000)


def _flatten(groups):
    """Return the items of a list of lists as one flat list."""
    return [v for group in groups for v in group]


def _build_extensions(config, m, model, model_name, valid_stream, dump_path):
    """Assemble the list of Blocks extensions for the main loop.

    Parameters
    ----------
    config : module
        Configuration module; ``print_freq``, ``save_freq`` and
        ``valid_freq`` are read from it.
    m : object
        Model wrapper exposing ``monitor_vars`` and ``monitor_vars_valid``.
    model : blocks.model.Model
        Blocks model handed to ``SaveLoadParams``.
    model_name : str
        Used to name the plot document.
    valid_stream : data stream or None
        Validation stream; validation monitoring is skipped when ``None``.
    dump_path : str or None
        Path handed to ``SaveLoadParams``; saving is skipped when ``None``.

    Returns
    -------
    list
        Extensions in the order: training monitoring, parameter
        saving/loading, validation monitoring, plotting, printing,
        progress bar.
    """
    extensions = [
        TrainingDataMonitoring(
            _flatten(m.monitor_vars),
            prefix='train', every_n_batches=config.print_freq)
    ]

    # A save_freq of None disables parameter saving/loading altogether.
    if config.save_freq is not None and dump_path is not None:
        extensions.append(
            SaveLoadParams(path=dump_path,
                           model=model,
                           before_training=True,
                           after_training=True,
                           after_epoch=True,
                           every_n_batches=config.save_freq))

    # A valid_freq of -1 disables validation monitoring.
    if valid_stream is not None and config.valid_freq != -1:
        extensions.append(
            DataStreamMonitoring(
                _flatten(m.monitor_vars_valid),
                valid_stream,
                prefix='valid',
                every_n_batches=config.valid_freq))

    if plot_avail:
        # One plot per group of variables, pairing the train and valid
        # channels of the same group.
        plot_channels = [
            ['train_' + v.name for v in lt] + ['valid_' + v.name for v in lv]
            for lt, lv in zip(m.monitor_vars, m.monitor_vars_valid)
        ]
        extensions.append(
            Plot(document='deepmind_qa_' + model_name,
                 channels=plot_channels,
                 # server_url='http://localhost:5006/', # If you need, change this
                 every_n_batches=config.print_freq))

    extensions += [
        Printing(every_n_batches=config.print_freq,
                 after_epoch=True),
        ProgressBar()
    ]
    return extensions


def main(argv):
    """Run training for the configuration named on the command line.

    Parameters
    ----------
    argv : list of str
        Full argument vector (program name followed by the config name).

    Exits with status 1 when the arguments are wrong or when the
    ``DATAPATH`` environment variable is not set.
    """
    if len(argv) != 2:
        sys.stderr.write('Usage: %s config\n' % argv[0])
        sys.exit(1)
    model_name = argv[1]
    config = importlib.import_module('.%s' % model_name, 'config')

    # os.getenv returns None for an unset variable, which would otherwise
    # surface as an obscure error inside os.path.join.
    data_root = os.getenv("DATAPATH")
    if data_root is None:
        sys.stderr.write('DATAPATH environment variable is not set\n')
        sys.exit(1)

    # Build datastream
    path = os.path.join(data_root, "deepmind-qa/cnn/questions/training")
    valid_path = os.path.join(data_root, "deepmind-qa/cnn/questions/validation")
    vocab_path = os.path.join(data_root, "deepmind-qa/cnn/stats/training/vocab.txt")

    ds, train_stream = data.setup_datastream(path, vocab_path, config)
    _, valid_stream = data.setup_datastream(valid_path, vocab_path, config)

    dump_path = os.path.join("model_params", model_name + ".pkl")

    # Build model
    m = config.Model(config, ds.vocab_size)

    # Build the Blocks stuff for training
    model = Model(m.sgd_cost)

    algorithm = GradientDescent(cost=m.sgd_cost,
                                step_rule=config.step_rule,
                                parameters=model.parameters)

    extensions = _build_extensions(config, m, model, model_name,
                                   valid_stream, dump_path)

    main_loop = MainLoop(
        model=model,
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=extensions
    )

    # Run the model !
    main_loop.run()
    main_loop.profile.report()


if __name__ == "__main__":
    main(sys.argv)

#  vim: set sts=4 ts=4 sw=4 tw=0 et :