summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- irc.py 173
-rwxr-xr-x [-rw-r--r--] ircbot.py (renamed from ircext.py) 124
-rw-r--r-- paramsaveload.py 2
3 files changed, 75 insertions, 224 deletions
diff --git a/irc.py b/irc.py
deleted file mode 100644
index f8ca125..0000000
--- a/irc.py
+++ /dev/null
@@ -1,173 +0,0 @@
-#!/usr/bin/env python2
-
-import logging
-import sys
-import importlib
-
-import theano
-
-from blocks.extensions import Printing, SimpleExtension, FinishAfter
-from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-
-from blocks.graph import ComputationGraph
-from blocks.main_loop import MainLoop
-from blocks.model import Model
-from blocks.algorithms import GradientDescent
-
-try:
- from blocks.extras.extensions.plot import Plot
- plot_avail = False
-except ImportError:
- plot_avail = False
-
-
-import datastream
-from paramsaveload import SaveLoadParams
-from gentext import GenText
-from ircext import IRCClientExt
-
-logging.basicConfig(level='INFO')
-logger = logging.getLogger(__name__)
-
-sys.setrecursionlimit(500000)
-
-
-class ResetStates(SimpleExtension):
- def __init__(self, state_vars, **kwargs):
- super(ResetStates, self).__init__(**kwargs)
-
- self.f = theano.function(
- inputs=[], outputs=[],
- updates=[(v, v.zeros_like()) for v in state_vars])
-
- def do(self, which_callback, *args):
- self.f()
-
-if __name__ == "__main__":
- if len(sys.argv) < 2:
- print >> sys.stderr, 'Usage: %s [options] config' % sys.argv[0]
- sys.exit(1)
- model_name = sys.argv[-1]
- config = importlib.import_module('%s' % model_name)
-
-
- # Build datastream
- train_stream = datastream.setup_datastream('data/logcompil.txt',
- config.num_seqs,
- config.seq_len,
- config.seq_div_size)
-
- # Build model
- m = config.Model()
- m.pred.name = 'pred'
-
- # Train the model
- saveloc = 'model_data/%s-%s' % (model_name, config.param_desc)
- train_model(m, train_stream, dump_path=saveloc)
-
-
- # Define the model
- model = Model(m.sgd_cost)
-
- # IRC mode : just load the parameters and run an IRC server
- if '--irc' in sys.argv:
- try:
- extensions.append(FinishAfter(before_training=True, after_n_batches=1))
- print "Initializing main loop"
- main_loop.run()
- print "Jumping into IRC"
- irc.run_forever()
- except KeyboardInterrupt:
- pass
- sys.exit(0)
-
- # Train the model
-
- cg = ComputationGraph(m.sgd_cost)
- algorithm = GradientDescent(cost=m.sgd_cost,
- step_rule=config.step_rule,
- parameters=cg.parameters)
-
- algorithm.add_updates(m.states)
-
- monitor_vars = [v for p in m.monitor_vars for v in p]
- extensions = [
- TrainingDataMonitoring(
- monitor_vars,
- prefix='train', every_n_epochs=1),
- Printing(every_n_epochs=1, after_epoch=False),
-
- ResetStates([v for v, _ in m.states], after_epoch=True)
- ]
- if plot_avail:
- plot_channels = [['train_' + v.name for v in p] for p in m.monitor_vars]
- extensions.append(
- Plot(document='text_'+model_name,
- channels=plot_channels,
- server_url='http://localhost:5006',
- every_n_epochs=1, after_epoch=False)
- )
- if config.save_freq is not None and dump_path is not None:
- extensions.append(
- SaveLoadParams(path=dump_path+'.pkl',
- model=model,
- before_training=True,
- after_training=True,
- after_epoch=False,
- every_n_epochs=config.save_freq)
- )
- if config.sample_freq is not None:
- extensions.append(
- GenText(m, '\nalex\ttu crois ?\n',
- config.sample_len, config.sample_temperature,
- every_n_epochs=config.sample_freq,
- after_epoch=False, before_training=True)
- )
- if config.on_irc:
- irc = IRCClientExt(m, config.sample_temperature,
- server='clipper.ens.fr',
- port=6667,
- nick='frigo',
- channels=['#frigotest', '#courssysteme'],
- after_batch=True)
- irc.do('before_training')
- extensions.append(irc)
-
- if config.on_irc:
- irc = IRCClientExt(m, config.sample_temperature,
- server='clipper.ens.fr',
- port=6667,
- nick='frigo',
- channels=['#frigotest', '#courssysteme'],
- after_batch=True)
- irc.do('before_training')
- extensions.append(irc)
-
- main_loop = MainLoop(
- model=model,
- data_stream=train_stream,
- algorithm=algorithm,
- extensions=extensions
- )
- main_loop.run()
-
- # IRC mode : just load the parameters and run an IRC server
- if '--irc' in sys.argv:
- try:
- extensions.append(FinishAfter(before_training=True, after_n_batches=1))
- print "Initializing main loop"
- main_loop.run()
- print "Jumping into IRC"
- irc.run_forever()
- except KeyboardInterrupt:
- pass
- sys.exit(0)
-
-
-
-
-
-
-
-
-# vim: set sts=4 ts=4 sw=4 tw=0 et :
diff --git a/ircext.py b/ircbot.py
index d8580ad..ac329eb 100644..100755
--- a/ircext.py
+++ b/ircbot.py
@@ -1,26 +1,60 @@
-from irc.client import SimpleIRCClient
+#!/usr/bin/env python2
import logging
+import sys
+import importlib
-import numpy
+from irc.client import SimpleIRCClient
+import numpy
import theano
from theano import tensor
-from blocks.extensions import SimpleExtension
+from blocks.model import Model
+
+import datastream
+from paramsaveload import SaveLoadParams
from blocks.graph import ComputationGraph
logging.basicConfig(level='INFO')
-logger = logging.getLogger('irc_ext')
+logger = logging.getLogger(__name__)
+
class IRCClient(SimpleIRCClient):
- def __init__(self, chans, nick):
+ def __init__(self, model, sample_temperature, server, port, nick, channels, saveload):
super(IRCClient, self).__init__()
+ out = model.out[:, -1, :] / numpy.float32(sample_temperature)
+ prob = tensor.nnet.softmax(out)
+
+ cg = ComputationGraph([prob])
+ assert(len(cg.inputs) == 1)
+ assert(cg.inputs[0].name == 'bytes')
+
+ # channel functions & state
+ chfun = {}
+ for ch in channels + ['']:
+ logger.info("Building theano function for channel '%s'"%ch)
+ state_vars = [theano.shared(v[0:1, :].zeros_like().eval(), v.name+'-'+ch)
+ for v, _ in model.states]
+ givens = [(v, x) for (v, _), x in zip(model.states, state_vars)]
+ updates= [(x, upd) for x, (_, upd) in zip(state_vars, model.states)]
+
+ pred = theano.function(inputs=cg.inputs, outputs=[prob],
+ givens=givens, updates=updates)
+ reset_states = theano.function(inputs=[], outputs=[],
+ updates=[(v, v.zeros_like()) for v in state_vars])
+ chfun[ch] = (pred, reset_states)
+
+ self.saveload = saveload
+
+ self.chfuns = chfun
+
self.chans = chans
self.nick = nick
self.server = None
+
def on_welcome(self, server, ev):
logger.info("Welcomed to " + repr(server))
for ch in self.chans:
@@ -37,7 +71,7 @@ class IRCClient(SimpleIRCClient):
s = ''
while True:
prob = prob / 1.00001
- pred = numpy.random.multinomial(1, prob[0, :]).nonzero()[0][0]
+ pred = numpy.random.multinomial(1, prob[0, :]).nonzero()[0][0].astype('int16')
s = s + chr(int(pred))
@@ -64,8 +98,8 @@ class IRCClient(SimpleIRCClient):
rep = None
- if chan in self.chans:
- pred_f, _ = self.chans[chan]
+ if chan in self.chfuns:
+ pred_f, _ = self.chfuns[chan]
if s0[-2:] == '^I':
prob, = pred_f(self.str2data(s0[:-2]))
rep = s0[:-2] + self.pred_until(pred_f, prob)
@@ -78,56 +112,46 @@ class IRCClient(SimpleIRCClient):
prob, = pred_f(self.str2data(nick+': '))
rep = nick + ': ' + self.pred_until(pred_f, prob)
else:
- pass
-
+ logger.warn('Recieved message on unknown channel: %s'%chan)
+
if rep != None:
self.privmsg(chan, rep)
-class IRCClientExt(SimpleExtension):
- def __init__(self, model, sample_temperature, server, port, nick, channels, **kwargs):
- super(IRCClientExt, self).__init__(**kwargs)
- # model output
- out = model.out[:, -1, :] / numpy.float32(sample_temperature)
- prob = tensor.nnet.softmax(out)
-
- cg = ComputationGraph([prob])
- assert(len(cg.inputs) == 1)
- assert(cg.inputs[0].name == 'bytes')
- # channel functions & state
- chfun = {}
- for ch in channels + ['']:
- state_vars = [theano.shared(v[0:1, :].zeros_like().eval(), v.name+'-'+ch)
- for v, _ in model.states]
- givens = [(v, x) for (v, _), x in zip(model.states, state_vars)]
- updates= [(x, upd) for x, (_, upd) in zip(state_vars, model.states)]
-
- pred = theano.function(inputs=cg.inputs, outputs=[prob],
- givens=givens, updates=updates)
- reset_states = theano.function(inputs=[], outputs=[],
- updates=[(v, v.zeros_like()) for v in state_vars])
- chfun[ch] = (pred, reset_states)
+if __name__ == "__main__":
+ if len(sys.argv) < 2:
+ print >> sys.stderr, 'Usage: %s [options] config' % sys.argv[0]
+ sys.exit(1)
+ model_name = sys.argv[-1]
+ config = importlib.import_module('.%s' % model_name, 'config')
- self.irc = IRCClient(chfun, nick)
- self.irc.connect(server, port, nick)
+ # Build model
+ logger.info('Building model...')
+ m = config.Model(config)
- def __getstate__(self):
- state = dict(self.__dict__)
- del state['irc']
- return state
+ # Define the computation graph && load parameters
+ logger.info('Building computation graph...')
+ dump_path = 'params/%s-use_on_irc.pkl' % model_name
+ saveload = SaveLoadParams(path=dump_path,
+ model=Model(m.sgd_cost))
+ saveload.do_load()
- def __setstate__(self, state):
- irc = self.irc
- self.__dict__.update(state)
- self.irc = irc
+ # Build IRC client
+ server = 'clipper.ens.fr'
+ port = 6667
+ nick = 'frigo'
+ chans = ['#frigotest', '#courssysteme']
- def do(self, which_callback, *args):
- logger.info('Polling...')
- self.irc.reactor.process_once()
-
- def run_forever(self):
- self.irc.reactor.process_forever()
+ irc = IRCClient(model=m,
+ sample_temperature=config.sample_temperature,
+ server=server,
+ port=port,
+ nick=nick,
+ channels=chans,
+ saveload=saveload)
+ irc.connect(server, port, nick)
+ irc.reactor.process_forever()
-# vim: set sts=4 ts=4 sw=4 tw=0 et :
+# vim: set sts=4 ts=4 sw=4 tw=0 et :
diff --git a/paramsaveload.py b/paramsaveload.py
index 9c05926..345cfbb 100644
--- a/paramsaveload.py
+++ b/paramsaveload.py
@@ -27,7 +27,7 @@ class SaveLoadParams(SimpleExtension):
logger.info('Loading parameters from %s...'%self.path)
self.model.set_parameter_values(cPickle.load(f))
except IOError:
- pass
+ logger.warn('No previous parameters found at %s...'%self.path)
def do(self, which_callback, *args):
if which_callback == 'before_training':