summaryrefslogtreecommitdiff
path: root/irc.py
diff options
context:
space:
mode:
Diffstat (limited to 'irc.py')
-rw-r--r--irc.py173
1 files changed, 173 insertions, 0 deletions
diff --git a/irc.py b/irc.py
new file mode 100644
index 0000000..f8ca125
--- /dev/null
+++ b/irc.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python2
+
+import logging
+import sys
+import importlib
+
+import theano
+
+from blocks.extensions import Printing, SimpleExtension, FinishAfter
+from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
+
+from blocks.graph import ComputationGraph
+from blocks.main_loop import MainLoop
+from blocks.model import Model
+from blocks.algorithms import GradientDescent
+
+try:
+ from blocks.extras.extensions.plot import Plot
+ plot_avail = False
+except ImportError:
+ plot_avail = False
+
+
+import datastream
+from paramsaveload import SaveLoadParams
+from gentext import GenText
+from ircext import IRCClientExt
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+sys.setrecursionlimit(500000)
+
+
+class ResetStates(SimpleExtension):
+ def __init__(self, state_vars, **kwargs):
+ super(ResetStates, self).__init__(**kwargs)
+
+ self.f = theano.function(
+ inputs=[], outputs=[],
+ updates=[(v, v.zeros_like()) for v in state_vars])
+
+ def do(self, which_callback, *args):
+ self.f()
+
+if __name__ == "__main__":
+ if len(sys.argv) < 2:
+ print >> sys.stderr, 'Usage: %s [options] config' % sys.argv[0]
+ sys.exit(1)
+ model_name = sys.argv[-1]
+ config = importlib.import_module('%s' % model_name)
+
+
+ # Build datastream
+ train_stream = datastream.setup_datastream('data/logcompil.txt',
+ config.num_seqs,
+ config.seq_len,
+ config.seq_div_size)
+
+ # Build model
+ m = config.Model()
+ m.pred.name = 'pred'
+
+ # Train the model
+ saveloc = 'model_data/%s-%s' % (model_name, config.param_desc)
+ train_model(m, train_stream, dump_path=saveloc)
+
+
+ # Define the model
+ model = Model(m.sgd_cost)
+
+ # IRC mode : just load the parameters and run an IRC server
+ if '--irc' in sys.argv:
+ try:
+ extensions.append(FinishAfter(before_training=True, after_n_batches=1))
+ print "Initializing main loop"
+ main_loop.run()
+ print "Jumping into IRC"
+ irc.run_forever()
+ except KeyboardInterrupt:
+ pass
+ sys.exit(0)
+
+ # Train the model
+
+ cg = ComputationGraph(m.sgd_cost)
+ algorithm = GradientDescent(cost=m.sgd_cost,
+ step_rule=config.step_rule,
+ parameters=cg.parameters)
+
+ algorithm.add_updates(m.states)
+
+ monitor_vars = [v for p in m.monitor_vars for v in p]
+ extensions = [
+ TrainingDataMonitoring(
+ monitor_vars,
+ prefix='train', every_n_epochs=1),
+ Printing(every_n_epochs=1, after_epoch=False),
+
+ ResetStates([v for v, _ in m.states], after_epoch=True)
+ ]
+ if plot_avail:
+ plot_channels = [['train_' + v.name for v in p] for p in m.monitor_vars]
+ extensions.append(
+ Plot(document='text_'+model_name,
+ channels=plot_channels,
+ server_url='http://localhost:5006',
+ every_n_epochs=1, after_epoch=False)
+ )
+ if config.save_freq is not None and dump_path is not None:
+ extensions.append(
+ SaveLoadParams(path=dump_path+'.pkl',
+ model=model,
+ before_training=True,
+ after_training=True,
+ after_epoch=False,
+ every_n_epochs=config.save_freq)
+ )
+ if config.sample_freq is not None:
+ extensions.append(
+ GenText(m, '\nalex\ttu crois ?\n',
+ config.sample_len, config.sample_temperature,
+ every_n_epochs=config.sample_freq,
+ after_epoch=False, before_training=True)
+ )
+ if config.on_irc:
+ irc = IRCClientExt(m, config.sample_temperature,
+ server='clipper.ens.fr',
+ port=6667,
+ nick='frigo',
+ channels=['#frigotest', '#courssysteme'],
+ after_batch=True)
+ irc.do('before_training')
+ extensions.append(irc)
+
+ if config.on_irc:
+ irc = IRCClientExt(m, config.sample_temperature,
+ server='clipper.ens.fr',
+ port=6667,
+ nick='frigo',
+ channels=['#frigotest', '#courssysteme'],
+ after_batch=True)
+ irc.do('before_training')
+ extensions.append(irc)
+
+ main_loop = MainLoop(
+ model=model,
+ data_stream=train_stream,
+ algorithm=algorithm,
+ extensions=extensions
+ )
+ main_loop.run()
+
+ # IRC mode : just load the parameters and run an IRC server
+ if '--irc' in sys.argv:
+ try:
+ extensions.append(FinishAfter(before_training=True, after_n_batches=1))
+ print "Initializing main loop"
+ main_loop.run()
+ print "Jumping into IRC"
+ irc.run_forever()
+ except KeyboardInterrupt:
+ pass
+ sys.exit(0)
+
+
+
+
+
+
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :