diff options
Diffstat (limited to 'ircbot.py')
-rwxr-xr-x | ircbot.py | 157 |
1 files changed, 157 insertions, 0 deletions
diff --git a/ircbot.py b/ircbot.py new file mode 100755 index 0000000..ac329eb --- /dev/null +++ b/ircbot.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python2 + +import logging +import sys +import importlib + +from irc.client import SimpleIRCClient + +import numpy +import theano +from theano import tensor + +from blocks.model import Model + +import datastream +from paramsaveload import SaveLoadParams +from blocks.graph import ComputationGraph + +logging.basicConfig(level='INFO') +logger = logging.getLogger(__name__) + + +class IRCClient(SimpleIRCClient): + def __init__(self, model, sample_temperature, server, port, nick, channels, saveload): + super(IRCClient, self).__init__() + + out = model.out[:, -1, :] / numpy.float32(sample_temperature) + prob = tensor.nnet.softmax(out) + + cg = ComputationGraph([prob]) + assert(len(cg.inputs) == 1) + assert(cg.inputs[0].name == 'bytes') + + # channel functions & state + chfun = {} + for ch in channels + ['']: + logger.info("Building theano function for channel '%s'"%ch) + state_vars = [theano.shared(v[0:1, :].zeros_like().eval(), v.name+'-'+ch) + for v, _ in model.states] + givens = [(v, x) for (v, _), x in zip(model.states, state_vars)] + updates= [(x, upd) for x, (_, upd) in zip(state_vars, model.states)] + + pred = theano.function(inputs=cg.inputs, outputs=[prob], + givens=givens, updates=updates) + reset_states = theano.function(inputs=[], outputs=[], + updates=[(v, v.zeros_like()) for v in state_vars]) + chfun[ch] = (pred, reset_states) + + self.saveload = saveload + + self.chfuns = chfun + + self.chans = chans + self.nick = nick + self.server = None + + + def on_welcome(self, server, ev): + logger.info("Welcomed to " + repr(server)) + for ch in self.chans: + if ch != '' and ch[0] == '#': + server.join(ch) + + def on_join(self, server, ev): + self.server = server + + def str2data(self, s): + return numpy.array([ord(x) for x in s], dtype='int16')[None, :] + + def pred_until(self, pred_f, prob, delim='\n'): + s = '' + while True: + prob = prob / 1.00001 + pred = numpy.random.multinomial(1, prob[0, :]).nonzero()[0][0].astype('int16') + + s = s + chr(int(pred)) + + prob, = pred_f(pred[None, None]) + + if s[-1] == delim: + break + return s[:-1] + + def privmsg(self, chan, msg): + if len(msg) > 500: + msg = 'blip bloup' + logger.info("%s >> %s" % (chan, msg)) + self.server.privmsg(chan, msg.decode('utf-8', 'ignore')) + + def on_pubmsg(self, server, ev): + chan = ev.target.encode('utf-8') + nick = ev.source.split('!')[0].encode('utf-8') + msg = ev.arguments[0].encode('utf-8') + + logger.info("%s <%s> %s" % (chan, nick, msg)) + + s0 = nick+'\t'+msg + + rep = None + + if chan in self.chfuns: + pred_f, _ = self.chfuns[chan] + if s0[-2:] == '^I': + prob, = pred_f(self.str2data(s0[:-2])) + rep = s0[:-2] + self.pred_until(pred_f, prob) + rep = rep.split('\t', 1)[-1] + else: + # feed phrase to bot + prob, = pred_f(self.str2data(s0+'\n')) + if self.nick in msg: + self.pred_until(pred_f, prob, '\t') + prob, = pred_f(self.str2data(nick+': ')) + rep = nick + ': ' + self.pred_until(pred_f, prob) + else: + logger.warn('Recieved message on unknown channel: %s'%chan) + + if rep != None: + self.privmsg(chan, rep) + + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print >> sys.stderr, 'Usage: %s [options] config' % sys.argv[0] + sys.exit(1) + model_name = sys.argv[-1] + config = importlib.import_module('.%s' % model_name, 'config') + + # Build model + logger.info('Building model...') + m = config.Model(config) + + # Define the computation graph && load parameters + logger.info('Building computation graph...') + dump_path = 'params/%s-use_on_irc.pkl' % model_name + saveload = SaveLoadParams(path=dump_path, + model=Model(m.sgd_cost)) + saveload.do_load() + + # Build IRC client + server = 'clipper.ens.fr' + port = 6667 + nick = 'frigo' + chans = ['#frigotest', '#courssysteme'] + + irc = IRCClient(model=m, + sample_temperature=config.sample_temperature, + server=server, + port=port, + nick=nick, + channels=chans, + saveload=saveload) + irc.connect(server, port, nick) + irc.reactor.process_forever() + + +# vim: set sts=4 ts=4 sw=4 tw=0 et : |