# summary | refs | log | blame | commit | diff
# path: root/irc.py
# blob: f8ca125f6d2e651fc1cdbf8979761befd733efb4 (plain) (tree)












































































































































































                                                                                     
#!/usr/bin/env python2

import logging
import sys
import importlib

import theano

from blocks.extensions import Printing, SimpleExtension, FinishAfter
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring

from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
from blocks.algorithms import GradientDescent

# The Plot extension comes from the optional blocks-extras package; fall back
# gracefully when it is not installed.
# NOTE(review): plot_avail is set to False in BOTH branches, so the Plot
# extension is never added even when the import succeeds. Confirm whether
# this is a deliberate switch-off; set it to True in the try branch to
# re-enable plotting.
try:
    from blocks.extras.extensions.plot import Plot
    plot_avail = False
except ImportError:
    plot_avail = False


# Modules presumably local to this project (not part of Blocks/Theano).
import datastream
from paramsaveload import SaveLoadParams
from gentext import GenText
from ircext import IRCClientExt

logging.basicConfig(level='INFO')
# NOTE(review): `logger` is not used anywhere in this file.
logger = logging.getLogger(__name__)

# Raise the recursion limit far above the interpreter default.
# NOTE(review): presumably needed for recursion-heavy work on deep Theano
# graphs (compilation or pickling) — confirm.
sys.setrecursionlimit(500000)


class ResetStates(SimpleExtension):
    """Blocks extension that sets a group of state variables back to zero.

    A Theano function with no inputs and no outputs is compiled once in the
    constructor. Its only effect is the update ``var <- var.zeros_like()``
    for every variable in ``state_vars``. Theano only accepts shared
    variables as ``updates`` targets, so these must be shared variables.
    Which callbacks trigger the reset is controlled by the usual
    ``SimpleExtension`` keyword arguments, e.g. ``after_epoch=True``.
    """

    def __init__(self, state_vars, **kwargs):
        super(ResetStates, self).__init__(**kwargs)

        zeroing_updates = []
        for var in state_vars:
            zeroing_updates.append((var, var.zeros_like()))

        self.f = theano.function(inputs=[], outputs=[],
                                 updates=zeroing_updates)

    def do(self, which_callback, *args):
        # The reset is the same whichever callback triggered it.
        self.f()

def train_model(m, train_stream, config, model_name, dump_path=None):
    """Build the Blocks main loop for ``m`` and train it, or serve IRC.

    :param m: model object built by ``config.Model()``. This function reads:
        ``sgd_cost``; ``states``, a list of ``(variable, update)`` pairs;
        and ``monitor_vars``, a list of lists of variables.
    :param train_stream: data stream handed to the ``MainLoop``.
    :param config: the imported configuration module. It is read for
        ``step_rule``, ``save_freq``, ``sample_freq``, ``sample_len``,
        ``sample_temperature`` and ``on_irc``.
    :param model_name: name of the configuration module; used to name the
        plot document.
    :param dump_path: path prefix for the parameter file (``'.pkl'`` is
        appended). ``None`` disables ``SaveLoadParams``.

    Without ``--irc``, the main loop trains normally. With ``--irc`` on the
    command line:

    * the main loop is started only so that its ``before_training``
      callbacks run; a ``FinishAfter(before_training=True)`` is appended
      last to stop it;
    * the IRC client then runs until interrupted;
    * the process exits with status 0.
    """
    irc_mode = '--irc' in sys.argv

    model = Model(m.sgd_cost)

    cg = ComputationGraph(m.sgd_cost)
    algorithm = GradientDescent(cost=m.sgd_cost,
                                step_rule=config.step_rule,
                                parameters=cg.parameters)
    # Apply the model's own state updates together with the gradient step.
    algorithm.add_updates(m.states)

    monitor_vars = [v for p in m.monitor_vars for v in p]
    extensions = [
        TrainingDataMonitoring(
            monitor_vars,
            prefix='train', every_n_epochs=1),
        Printing(every_n_epochs=1, after_epoch=False),
        ResetStates([v for v, _ in m.states], after_epoch=True),
    ]
    if plot_avail:
        plot_channels = [['train_' + v.name for v in p]
                         for p in m.monitor_vars]
        extensions.append(
            Plot(document='text_' + model_name,
                 channels=plot_channels,
                 server_url='http://localhost:5006',
                 every_n_epochs=1, after_epoch=False)
        )
    if config.save_freq is not None and dump_path is not None:
        extensions.append(
            SaveLoadParams(path=dump_path + '.pkl',
                           model=model,
                           before_training=True,
                           after_training=True,
                           after_epoch=False,
                           every_n_epochs=config.save_freq)
        )
    if config.sample_freq is not None:
        extensions.append(
            GenText(m, '\nalex\ttu crois ?\n',
                    config.sample_len, config.sample_temperature,
                    every_n_epochs=config.sample_freq,
                    after_epoch=False, before_training=True)
        )

    # Build a single IRC client. Two identical clients (same server and
    # nick) would presumably open two connections. IRC mode needs the
    # client even when the config does not enable it during training.
    irc = None
    if config.on_irc or irc_mode:
        irc = IRCClientExt(m, config.sample_temperature,
                           server='clipper.ens.fr',
                           port=6667,
                           nick='frigo',
                           channels=['#frigotest', '#courssysteme'],
                           after_batch=True)
        irc.do('before_training')
        extensions.append(irc)

    if irc_mode:
        # Appended last, before the MainLoop is built, so every other
        # before_training callback (e.g. SaveLoadParams) runs first.
        extensions.append(FinishAfter(before_training=True,
                                      after_n_batches=1))

    main_loop = MainLoop(
        model=model,
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=extensions
    )

    # IRC mode: run only the before_training callbacks, then the IRC client.
    if irc_mode:
        try:
            print("Initializing main loop")
            main_loop.run()
            print("Jumping into IRC")
            irc.run_forever()
        except KeyboardInterrupt:
            pass
        sys.exit(0)

    main_loop.run()


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.stderr.write('Usage: %s [options] config\n' % sys.argv[0])
        sys.exit(1)
    # The configuration module name is always the last argument; options
    # such as --irc come before it.
    model_name = sys.argv[-1]
    config = importlib.import_module(model_name)

    # Build datastream
    train_stream = datastream.setup_datastream('data/logcompil.txt',
                                               config.num_seqs,
                                               config.seq_len,
                                               config.seq_div_size)

    # Build model
    m = config.Model()
    m.pred.name = 'pred'

    # Train the model (or run the IRC client when --irc is given)
    saveloc = 'model_data/%s-%s' % (model_name, config.param_desc)
    train_model(m, train_stream, config, model_name, dump_path=saveloc)








#  vim: set sts=4 ts=4 sw=4 tw=0 et :