1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
|
from irc.client import SimpleIRCClient
import logging
import numpy
import theano
from theano import tensor
from blocks.extensions import SimpleExtension
from blocks.graph import ComputationGraph
# Module-level logging setup: configure the root logger at INFO level (an
# import-time side effect) and create the named logger used throughout this
# file.
logging.basicConfig(level='INFO')
logger = logging.getLogger('irc_ext')
class IRCClient(SimpleIRCClient):
    """IRC client that replies on channels with text sampled from a model.

    Parameters
    ----------
    chans : dict
        Maps a channel name to a pair whose first element is a prediction
        function. That function is called with a 2-D integer array of
        character codes and must return a one-element sequence holding a
        probability array; row 0 of it is used to sample the next character.
        The second element of the pair is not used by this class.
    nick : str
        The bot's nickname; a reply is generated when it appears in a
        message.

    Notes
    -----
    The string handling here (``encode`` on incoming text, ``decode`` on
    outgoing text, concatenation with plain ``str`` literals) relies on
    Python 2 string semantics.
    """

    def __init__(self, chans, nick):
        super(IRCClient, self).__init__()
        self.chans = chans
        self.nick = nick
        # Connection used for sending; set once a join event is seen.
        self.server = None

    def on_welcome(self, server, ev):
        """Join every configured channel whose name starts with ``#``."""
        logger.info("Welcomed to %r", server)
        for ch in self.chans:
            # Keys that are not channel names (e.g. the empty string) are
            # skipped.
            if ch.startswith('#'):
                server.join(ch)

    def on_join(self, server, ev):
        """Remember the connection so that ``privmsg`` can send on it."""
        self.server = server

    def str2data(self, s):
        """Return the character codes of ``s`` as a ``(1, len(s))`` int16 array."""
        return numpy.array([ord(x) for x in s], dtype='int16')[None, :]

    def pred_until(self, pred_f, prob, delim='\n'):
        """Sample characters one at a time until ``delim`` is drawn.

        Parameters
        ----------
        pred_f : callable
            Called with a ``(1, 1)`` array holding the character code just
            sampled; returns a one-element sequence containing the
            probabilities for the next character.
        prob : array
            Probabilities for the first character; row 0 is used.
        delim : str
            Single character that ends sampling.

        Returns
        -------
        str
            The sampled characters, without the terminating delimiter.
        """
        chars = []
        while True:
            # Shrink the probabilities slightly before sampling; presumably
            # this keeps their sum from exceeding 1 through float rounding,
            # which numpy.random.multinomial rejects.
            prob = prob / 1.00001
            pred = numpy.random.multinomial(1, prob[0, :]).nonzero()[0][0]
            ch = chr(int(pred))
            # Every sampled character is fed back to the predictor,
            # including the delimiter itself.
            prob, = pred_f(pred[None, None])
            if ch == delim:
                break
            chars.append(ch)
        return ''.join(chars)

    def privmsg(self, chan, msg):
        """Send ``msg`` to ``chan``, replacing over-long messages.

        Messages longer than 500 characters are replaced by a fixed
        placeholder. If no join event has been seen yet there is no
        connection to send on, and the message is dropped with a warning.
        """
        if self.server is None:
            logger.warning("No server connection yet, dropping message to %s",
                           chan)
            return
        # NOTE(review): the 500 cutoff counts only the message text; confirm
        # it leaves enough room for the protocol prefix under the IRC line
        # length limit.
        if len(msg) > 500:
            msg = 'blip bloup'
        logger.info("%s >> %s", chan, msg)
        self.server.privmsg(chan, msg.decode('utf-8', 'ignore'))

    def on_pubmsg(self, server, ev):
        """Feed a channel message to the channel's predictor and maybe reply.

        A message ending in the two characters ``^I`` is treated as a
        completion request: the text before the marker is fed to the
        predictor and the line is continued by sampling. Any other message
        is fed to the predictor in full, and a reply is sampled only when
        the bot's nick occurs in the message.
        """
        chan = ev.target.encode('utf-8')
        nick = ev.source.split('!')[0].encode('utf-8')
        msg = ev.arguments[0].encode('utf-8')
        logger.info("%s <%s> %s", chan, nick, msg)

        if chan not in self.chans:
            return
        pred_f, _ = self.chans[chan]

        # Lines are presented to the predictor as "<nick>\t<message>\n".
        s0 = nick + '\t' + msg
        rep = None
        if s0[-2:] == '^I':
            # Completion request: continue the sender's own line and send
            # back only the message part (after the first tab).
            prefix = s0[:-2]
            prob, = pred_f(self.str2data(prefix))
            rep = prefix + self.pred_until(pred_f, prob)
            rep = rep.split('\t', 1)[-1]
        else:
            # Feed the whole phrase to the bot.
            prob, = pred_f(self.str2data(s0 + '\n'))
            if self.nick in msg:
                # Sample and discard everything up to the next tab
                # (presumably the speaker-name field of the next line), then
                # force the reply to start by addressing the sender.
                self.pred_until(pred_f, prob, '\t')
                prob, = pred_f(self.str2data(nick + ': '))
                rep = nick + ': ' + self.pred_until(pred_f, prob)

        if rep is not None:
            self.privmsg(chan, rep)
class IRCClientExt(SimpleExtension):
    """Extension that serves a model's sampling function over IRC.

    For each channel (plus one extra entry keyed by the empty string) a
    separate set of shared state variables is created, so every channel keeps
    its own model state between calls.

    Parameters
    ----------
    model : object
        Must provide ``out`` (indexed as ``[:, -1, :]`` to get the values fed
        to the softmax) and ``states``, an iterable of
        ``(state_variable, updated_state)`` pairs.
    sample_temperature : float
        The model output is divided by this value before the softmax.
    server : str
        IRC server to connect to.
    port : int
        IRC server port.
    nick : str
        Nickname used for the connection and for detecting mentions.
    channels : iterable of str
        Channels to build prediction functions for.
    **kwargs
        Forwarded to ``SimpleExtension``.

    Notes
    -----
    The IRC client is excluded from the pickled state. An instance restored
    by unpickling therefore has ``irc`` set to ``None`` and does not poll.
    """

    def __init__(self, model, sample_temperature, server, port, nick,
                 channels, **kwargs):
        super(IRCClientExt, self).__init__(**kwargs)

        # Model output -> character probabilities.
        out = model.out[:, -1, :] / numpy.float32(sample_temperature)
        prob = tensor.nnet.softmax(out)

        cg = ComputationGraph([prob])
        assert len(cg.inputs) == 1, "expected exactly one graph input"
        assert cg.inputs[0].name == 'bytes', "graph input must be 'bytes'"

        # Per-channel functions and state. list() lets callers pass any
        # iterable of channel names, not only a list.
        chfun = {}
        for ch in list(channels) + ['']:
            # One shared variable per model state, initialised to zeros with
            # the shape of the first row of that state.
            state_vars = [
                theano.shared(v[0:1, :].zeros_like().eval(),
                              v.name + '-' + ch)
                for v, _ in model.states
            ]
            # Read the state from the shared variables and write the updated
            # state back after every call.
            givens = [(v, x)
                      for (v, _), x in zip(model.states, state_vars)]
            updates = [(x, upd)
                       for x, (_, upd) in zip(state_vars, model.states)]

            pred = theano.function(inputs=cg.inputs, outputs=[prob],
                                   givens=givens, updates=updates)
            reset_states = theano.function(
                inputs=[], outputs=[],
                updates=[(v, v.zeros_like()) for v in state_vars])
            chfun[ch] = (pred, reset_states)

        self.irc = IRCClient(chfun, nick)
        self.irc.connect(server, port, nick)

    def __getstate__(self):
        """Return the instance state without the IRC client."""
        state = dict(self.__dict__)
        # pop() rather than del: the key may be absent on some instances.
        state.pop('irc', None)
        return state

    def __setstate__(self, state):
        """Restore ``state`` while keeping any IRC client already attached.

        When called during unpickling the instance has no ``irc`` attribute
        yet (``__init__`` is not run), so the client is set to ``None``
        instead of raising ``AttributeError``.
        """
        irc = getattr(self, 'irc', None)
        self.__dict__.update(state)
        self.irc = irc
        if irc is None:
            logger.warning('IRC client was not restored; polling is disabled')

    def do(self, which_callback, *args):
        """Process pending IRC events once; no-op without an IRC client."""
        if self.irc is None:
            return
        logger.info('Polling...')
        self.irc.reactor.process_once()

    def run_forever(self):
        """Process IRC events until interrupted.

        Raises
        ------
        RuntimeError
            If the instance has no IRC client (e.g. after unpickling).
        """
        if self.irc is None:
            raise RuntimeError('IRC client is not available')
        self.irc.reactor.process_forever()
# vim: set sts=4 ts=4 sw=4 tw=0 et :
|