summaryrefslogtreecommitdiff
path: root/irc.py
diff options
context:
space:
mode:
Diffstat (limited to 'irc.py')
-rw-r--r--irc.py173
1 files changed, 0 insertions, 173 deletions
diff --git a/irc.py b/irc.py
deleted file mode 100644
index f8ca125..0000000
--- a/irc.py
+++ /dev/null
@@ -1,173 +0,0 @@
-#!/usr/bin/env python2
-
-import logging
-import sys
-import importlib
-
-import theano
-
-from blocks.extensions import Printing, SimpleExtension, FinishAfter
-from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-
-from blocks.graph import ComputationGraph
-from blocks.main_loop import MainLoop
-from blocks.model import Model
-from blocks.algorithms import GradientDescent
-
-try:
- from blocks.extras.extensions.plot import Plot
- plot_avail = False
-except ImportError:
- plot_avail = False
-
-
-import datastream
-from paramsaveload import SaveLoadParams
-from gentext import GenText
-from ircext import IRCClientExt
-
-logging.basicConfig(level='INFO')
-logger = logging.getLogger(__name__)
-
-sys.setrecursionlimit(500000)
-
-
-class ResetStates(SimpleExtension):
- def __init__(self, state_vars, **kwargs):
- super(ResetStates, self).__init__(**kwargs)
-
- self.f = theano.function(
- inputs=[], outputs=[],
- updates=[(v, v.zeros_like()) for v in state_vars])
-
- def do(self, which_callback, *args):
- self.f()
-
-if __name__ == "__main__":
- if len(sys.argv) < 2:
- print >> sys.stderr, 'Usage: %s [options] config' % sys.argv[0]
- sys.exit(1)
- model_name = sys.argv[-1]
- config = importlib.import_module('%s' % model_name)
-
-
- # Build datastream
- train_stream = datastream.setup_datastream('data/logcompil.txt',
- config.num_seqs,
- config.seq_len,
- config.seq_div_size)
-
- # Build model
- m = config.Model()
- m.pred.name = 'pred'
-
- # Train the model
- saveloc = 'model_data/%s-%s' % (model_name, config.param_desc)
- train_model(m, train_stream, dump_path=saveloc)
-
-
- # Define the model
- model = Model(m.sgd_cost)
-
- # IRC mode : just load the parameters and run an IRC server
- if '--irc' in sys.argv:
- try:
- extensions.append(FinishAfter(before_training=True, after_n_batches=1))
- print "Initializing main loop"
- main_loop.run()
- print "Jumping into IRC"
- irc.run_forever()
- except KeyboardInterrupt:
- pass
- sys.exit(0)
-
- # Train the model
-
- cg = ComputationGraph(m.sgd_cost)
- algorithm = GradientDescent(cost=m.sgd_cost,
- step_rule=config.step_rule,
- parameters=cg.parameters)
-
- algorithm.add_updates(m.states)
-
- monitor_vars = [v for p in m.monitor_vars for v in p]
- extensions = [
- TrainingDataMonitoring(
- monitor_vars,
- prefix='train', every_n_epochs=1),
- Printing(every_n_epochs=1, after_epoch=False),
-
- ResetStates([v for v, _ in m.states], after_epoch=True)
- ]
- if plot_avail:
- plot_channels = [['train_' + v.name for v in p] for p in m.monitor_vars]
- extensions.append(
- Plot(document='text_'+model_name,
- channels=plot_channels,
- server_url='http://localhost:5006',
- every_n_epochs=1, after_epoch=False)
- )
- if config.save_freq is not None and dump_path is not None:
- extensions.append(
- SaveLoadParams(path=dump_path+'.pkl',
- model=model,
- before_training=True,
- after_training=True,
- after_epoch=False,
- every_n_epochs=config.save_freq)
- )
- if config.sample_freq is not None:
- extensions.append(
- GenText(m, '\nalex\ttu crois ?\n',
- config.sample_len, config.sample_temperature,
- every_n_epochs=config.sample_freq,
- after_epoch=False, before_training=True)
- )
- if config.on_irc:
- irc = IRCClientExt(m, config.sample_temperature,
- server='clipper.ens.fr',
- port=6667,
- nick='frigo',
- channels=['#frigotest', '#courssysteme'],
- after_batch=True)
- irc.do('before_training')
- extensions.append(irc)
-
- if config.on_irc:
- irc = IRCClientExt(m, config.sample_temperature,
- server='clipper.ens.fr',
- port=6667,
- nick='frigo',
- channels=['#frigotest', '#courssysteme'],
- after_batch=True)
- irc.do('before_training')
- extensions.append(irc)
-
- main_loop = MainLoop(
- model=model,
- data_stream=train_stream,
- algorithm=algorithm,
- extensions=extensions
- )
- main_loop.run()
-
- # IRC mode : just load the parameters and run an IRC server
- if '--irc' in sys.argv:
- try:
- extensions.append(FinishAfter(before_training=True, after_n_batches=1))
- print "Initializing main loop"
- main_loop.run()
- print "Jumping into IRC"
- irc.run_forever()
- except KeyboardInterrupt:
- pass
- sys.exit(0)
-
-
-
-
-
-
-
-
-# vim: set sts=4 ts=4 sw=4 tw=0 et :