diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-17 09:23:47 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-17 09:23:47 -0400 |
commit | e91e14e894196642532c0b7be50b01c1354ad702 (patch) | |
tree | fa87266d998dceb1df8e37b30c71fbccc291d25f /ircext.py | |
parent | 211c2272c544ab0bbf7b87b374736a71c790ac8e (diff) | |
download | text-rnn-e91e14e894196642532c0b7be50b01c1354ad702.tar.gz text-rnn-e91e14e894196642532c0b7be50b01c1354ad702.zip |
Connect it to IRC ; add GFGRU model
Diffstat (limited to 'ircext.py')
-rw-r--r-- | ircext.py | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/ircext.py b/ircext.py new file mode 100644 index 0000000..3e38ac2 --- /dev/null +++ b/ircext.py @@ -0,0 +1,125 @@ +from irc.client import SimpleIRCClient + +import logging + +import numpy + +import theano +from theano import tensor + +from blocks.extensions import SimpleExtension +from blocks.graph import ComputationGraph + +logging.basicConfig(level='INFO') +logger = logging.getLogger('irc_ext') + +class IRCClient(SimpleIRCClient): + def __init__(self, chans, nick): + super(IRCClient, self).__init__() + + self.chans = chans + self.nick = nick + self.server = None + + def on_welcome(self, server, ev): + logger.info("Welcomed to " + repr(server)) + for ch in self.chans: + if ch != '' and ch[0] == '#': + server.join(ch) + + def on_join(self, server, ev): + self.server = server + + def pred_line(self, pred_f, prob): + s = '' + while True: + prob = prob / 1.00001 + pred = numpy.random.multinomial(1, prob[0, :]).nonzero()[0][0] + + s = s + chr(int(pred)) + + prob, = pred_f(pred[None, None]) + + if s[-1] == '\n': + break + return s[:-1] + + def privmsg(self, chan, msg): + logger.info("%s >> %s" % (chan, msg)) + self.server.privmsg(chan, msg.decode('utf-8', 'ignore')) + + def on_pubmsg(self, server, ev): + chan = ev.target.encode('utf-8') + nick = ev.source.split('!')[0].encode('utf-8') + msg = ev.arguments[0].encode('utf-8') + + logger.info("%s <%s> %s" % (chan, nick, msg)) + + s0 = nick+'\t'+msg + + rep = None + + if chan in self.chans: + pred_f, _ = self.chans[chan] + if s0[-2:] == '^I': + prob, = pred_f(numpy.array([ord(x) for x in s0[:-2]], dtype='int16')[None, :]) + s = self.pred_line(pred_f, prob) + rep = s0[:-2] + s + else: + # feed phrase to bot + prob, = pred_f(numpy.array([ord(x) for x in s0+'\n'], dtype='int16')[None, :]) + if msg[:len(self.nick)+1] == self.nick+':': + #TODO: make it so that it predicts a message beginning by 'nick: ' + # (ie it responds to the person who pinged it) + rep = self.pred_line(pred_f, prob) + else: + pass + + if rep != None: + rep = rep.split('\t', 1)[1] + self.privmsg(chan, rep) + +class IRCClientExt(SimpleExtension): + def __init__(self, model, sample_temperature, server, port, nick, channels, **kwargs): + super(IRCClientExt, self).__init__(**kwargs) + + # model output + out = model.out[:, -1, :] / numpy.float32(sample_temperature) + prob = tensor.nnet.softmax(out) + + cg = ComputationGraph([prob]) + assert(len(cg.inputs) == 1) + assert(cg.inputs[0].name == 'bytes') + + # channel functions & state + chfun = {} + for ch in channels + ['']: + state_vars = [theano.shared(v[0:1, :].zeros_like().eval(), v.name+'-'+ch) + for v, _ in model.states] + givens = [(v, x) for (v, _), x in zip(model.states, state_vars)] + updates= [(x, upd) for x, (_, upd) in zip(state_vars, model.states)] + + pred = theano.function(inputs=cg.inputs, outputs=[prob], + givens=givens, updates=updates) + reset_states = theano.function(inputs=[], outputs=[], + updates=[(v, v.zeros_like()) for v in state_vars]) + chfun[ch] = (pred, reset_states) + + self.irc = IRCClient(chfun, nick) + self.irc.connect(server, port, nick) + + def __getstate__(self): + state = dict(self.__dict__) + del state['irc'] + return state + + def __setstate__(self, state): + irc = self.irc + self.__dict__.update(state) + self.irc = irc + + def do(self, which_callback, *args): + logger.info('Polling...') + self.irc.reactor.process_once() + + |