diff options
author | Alex Auvolat <alex@adnab.me> | 2016-03-08 13:26:28 +0100 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2016-03-08 13:26:28 +0100 |
commit | 2f479926c16d2911d0dd878c21de082abfc5b237 (patch) | |
tree | b399e9ad9af04a9449334dff1a47449808b7ca13 /irc.py | |
parent | 23093608e0edc43477c3a2ed804ae1016790f7e4 (diff) | |
download | text-rnn-2f479926c16d2911d0dd878c21de082abfc5b237.tar.gz text-rnn-2f479926c16d2911d0dd878c21de082abfc5b237.zip |
Revive project
Diffstat (limited to 'irc.py')
-rw-r--r-- | irc.py | 173 |
1 files changed, 173 insertions, 0 deletions
@@ -0,0 +1,173 @@ +#!/usr/bin/env python2 + +import logging +import sys +import importlib + +import theano + +from blocks.extensions import Printing, SimpleExtension, FinishAfter +from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring + +from blocks.graph import ComputationGraph +from blocks.main_loop import MainLoop +from blocks.model import Model +from blocks.algorithms import GradientDescent + +try: + from blocks.extras.extensions.plot import Plot + plot_avail = False +except ImportError: + plot_avail = False + + +import datastream +from paramsaveload import SaveLoadParams +from gentext import GenText +from ircext import IRCClientExt + +logging.basicConfig(level='INFO') +logger = logging.getLogger(__name__) + +sys.setrecursionlimit(500000) + + +class ResetStates(SimpleExtension): + def __init__(self, state_vars, **kwargs): + super(ResetStates, self).__init__(**kwargs) + + self.f = theano.function( + inputs=[], outputs=[], + updates=[(v, v.zeros_like()) for v in state_vars]) + + def do(self, which_callback, *args): + self.f() + +if __name__ == "__main__": + if len(sys.argv) < 2: + print >> sys.stderr, 'Usage: %s [options] config' % sys.argv[0] + sys.exit(1) + model_name = sys.argv[-1] + config = importlib.import_module('%s' % model_name) + + + # Build datastream + train_stream = datastream.setup_datastream('data/logcompil.txt', + config.num_seqs, + config.seq_len, + config.seq_div_size) + + # Build model + m = config.Model() + m.pred.name = 'pred' + + # Train the model + saveloc = 'model_data/%s-%s' % (model_name, config.param_desc) + train_model(m, train_stream, dump_path=saveloc) + + + # Define the model + model = Model(m.sgd_cost) + + # IRC mode : just load the parameters and run an IRC server + if '--irc' in sys.argv: + try: + extensions.append(FinishAfter(before_training=True, after_n_batches=1)) + print "Initializing main loop" + main_loop.run() + print "Jumping into IRC" + irc.run_forever() + except KeyboardInterrupt: + pass + sys.exit(0) + + # Train the model + + cg = ComputationGraph(m.sgd_cost) + algorithm = GradientDescent(cost=m.sgd_cost, + step_rule=config.step_rule, + parameters=cg.parameters) + + algorithm.add_updates(m.states) + + monitor_vars = [v for p in m.monitor_vars for v in p] + extensions = [ + TrainingDataMonitoring( + monitor_vars, + prefix='train', every_n_epochs=1), + Printing(every_n_epochs=1, after_epoch=False), + + ResetStates([v for v, _ in m.states], after_epoch=True) + ] + if plot_avail: + plot_channels = [['train_' + v.name for v in p] for p in m.monitor_vars] + extensions.append( + Plot(document='text_'+model_name, + channels=plot_channels, + server_url='http://localhost:5006', + every_n_epochs=1, after_epoch=False) + ) + if config.save_freq is not None and dump_path is not None: + extensions.append( + SaveLoadParams(path=dump_path+'.pkl', + model=model, + before_training=True, + after_training=True, + after_epoch=False, + every_n_epochs=config.save_freq) + ) + if config.sample_freq is not None: + extensions.append( + GenText(m, '\nalex\ttu crois ?\n', + config.sample_len, config.sample_temperature, + every_n_epochs=config.sample_freq, + after_epoch=False, before_training=True) + ) + if config.on_irc: + irc = IRCClientExt(m, config.sample_temperature, + server='clipper.ens.fr', + port=6667, + nick='frigo', + channels=['#frigotest', '#courssysteme'], + after_batch=True) + irc.do('before_training') + extensions.append(irc) + + if config.on_irc: + irc = IRCClientExt(m, config.sample_temperature, + server='clipper.ens.fr', + port=6667, + nick='frigo', + channels=['#frigotest', '#courssysteme'], + after_batch=True) + irc.do('before_training') + extensions.append(irc) + + main_loop = MainLoop( + model=model, + data_stream=train_stream, + algorithm=algorithm, + extensions=extensions + ) + main_loop.run() + + # IRC mode : just load the parameters and run an IRC server + if '--irc' in sys.argv: + try: + extensions.append(FinishAfter(before_training=True, after_n_batches=1)) + print "Initializing main loop" + main_loop.run() + print "Jumping into IRC" + irc.run_forever() + except KeyboardInterrupt: + pass + sys.exit(0) + + + + + + + + +# vim: set sts=4 ts=4 sw=4 tw=0 et : |