summaryrefslogtreecommitdiff
path: root/ircbot.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2016-03-08 20:21:28 +0100
committerAlex Auvolat <alex@adnab.me>2016-03-08 20:21:28 +0100
commitbb9ebdeee88409a209d1bcc04a374e4b7d7e13d2 (patch)
tree1fc4cbafe8e443e40101551a809f2473d10ba22a /ircbot.py
parent0cc79a9d41d3829a75e1d0f953d706e933e9d194 (diff)
downloadtext-rnn-bb9ebdeee88409a209d1bcc04a374e4b7d7e13d2.tar.gz
text-rnn-bb9ebdeee88409a209d1bcc04a374e4b7d7e13d2.zip
Fix IRC mode
Diffstat (limited to 'ircbot.py')
-rwxr-xr-xircbot.py157
1 files changed, 157 insertions, 0 deletions
diff --git a/ircbot.py b/ircbot.py
new file mode 100755
index 0000000..ac329eb
--- /dev/null
+++ b/ircbot.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python2
+
+import logging
+import sys
+import importlib
+
+from irc.client import SimpleIRCClient
+
+import numpy
+import theano
+from theano import tensor
+
+from blocks.model import Model
+
+import datastream
+from paramsaveload import SaveLoadParams
+from blocks.graph import ComputationGraph
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+
+class IRCClient(SimpleIRCClient):
+ def __init__(self, model, sample_temperature, server, port, nick, channels, saveload):
+ super(IRCClient, self).__init__()
+
+ out = model.out[:, -1, :] / numpy.float32(sample_temperature)
+ prob = tensor.nnet.softmax(out)
+
+ cg = ComputationGraph([prob])
+ assert(len(cg.inputs) == 1)
+ assert(cg.inputs[0].name == 'bytes')
+
+ # channel functions & state
+ chfun = {}
+ for ch in channels + ['']:
+ logger.info("Building theano function for channel '%s'"%ch)
+ state_vars = [theano.shared(v[0:1, :].zeros_like().eval(), v.name+'-'+ch)
+ for v, _ in model.states]
+ givens = [(v, x) for (v, _), x in zip(model.states, state_vars)]
+ updates= [(x, upd) for x, (_, upd) in zip(state_vars, model.states)]
+
+ pred = theano.function(inputs=cg.inputs, outputs=[prob],
+ givens=givens, updates=updates)
+ reset_states = theano.function(inputs=[], outputs=[],
+ updates=[(v, v.zeros_like()) for v in state_vars])
+ chfun[ch] = (pred, reset_states)
+
+ self.saveload = saveload
+
+ self.chfuns = chfun
+
+ self.chans = chans
+ self.nick = nick
+ self.server = None
+
+
+ def on_welcome(self, server, ev):
+ logger.info("Welcomed to " + repr(server))
+ for ch in self.chans:
+ if ch != '' and ch[0] == '#':
+ server.join(ch)
+
+ def on_join(self, server, ev):
+ self.server = server
+
+ def str2data(self, s):
+ return numpy.array([ord(x) for x in s], dtype='int16')[None, :]
+
+ def pred_until(self, pred_f, prob, delim='\n'):
+ s = ''
+ while True:
+ prob = prob / 1.00001
+ pred = numpy.random.multinomial(1, prob[0, :]).nonzero()[0][0].astype('int16')
+
+ s = s + chr(int(pred))
+
+ prob, = pred_f(pred[None, None])
+
+ if s[-1] == delim:
+ break
+ return s[:-1]
+
+ def privmsg(self, chan, msg):
+ if len(msg) > 500:
+ msg = 'blip bloup'
+ logger.info("%s >> %s" % (chan, msg))
+ self.server.privmsg(chan, msg.decode('utf-8', 'ignore'))
+
+ def on_pubmsg(self, server, ev):
+ chan = ev.target.encode('utf-8')
+ nick = ev.source.split('!')[0].encode('utf-8')
+ msg = ev.arguments[0].encode('utf-8')
+
+ logger.info("%s <%s> %s" % (chan, nick, msg))
+
+ s0 = nick+'\t'+msg
+
+ rep = None
+
+ if chan in self.chfuns:
+ pred_f, _ = self.chfuns[chan]
+ if s0[-2:] == '^I':
+ prob, = pred_f(self.str2data(s0[:-2]))
+ rep = s0[:-2] + self.pred_until(pred_f, prob)
+ rep = rep.split('\t', 1)[-1]
+ else:
+ # feed phrase to bot
+ prob, = pred_f(self.str2data(s0+'\n'))
+ if self.nick in msg:
+ self.pred_until(pred_f, prob, '\t')
+ prob, = pred_f(self.str2data(nick+': '))
+ rep = nick + ': ' + self.pred_until(pred_f, prob)
+ else:
+ logger.warn('Recieved message on unknown channel: %s'%chan)
+
+ if rep != None:
+ self.privmsg(chan, rep)
+
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 2:
+ print >> sys.stderr, 'Usage: %s [options] config' % sys.argv[0]
+ sys.exit(1)
+ model_name = sys.argv[-1]
+ config = importlib.import_module('.%s' % model_name, 'config')
+
+ # Build model
+ logger.info('Building model...')
+ m = config.Model(config)
+
+ # Define the computation graph && load parameters
+ logger.info('Building computation graph...')
+ dump_path = 'params/%s-use_on_irc.pkl' % model_name
+ saveload = SaveLoadParams(path=dump_path,
+ model=Model(m.sgd_cost))
+ saveload.do_load()
+
+ # Build IRC client
+ server = 'clipper.ens.fr'
+ port = 6667
+ nick = 'frigo'
+ chans = ['#frigotest', '#courssysteme']
+
+ irc = IRCClient(model=m,
+ sample_temperature=config.sample_temperature,
+ server=server,
+ port=port,
+ nick=nick,
+ channels=chans,
+ saveload=saveload)
+ irc.connect(server, port, nick)
+ irc.reactor.process_forever()
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :