summaryrefslogtreecommitdiff
path: root/ircext.py
blob: 1af2ba87419874777c772507d3790e5264994eea (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from irc.client import SimpleIRCClient

import logging

import numpy

import theano
from theano import tensor

from blocks.extensions import SimpleExtension
from blocks.graph import ComputationGraph

# Configure root logging at import time with the INFO threshold.
logging.basicConfig(level='INFO')
# Module-wide logger shared by the IRC client and the training extension.
logger = logging.getLogger('irc_ext')

class IRCClient(SimpleIRCClient):
    """IRC client that feeds channel traffic to predictors and posts samples.

    Every public message seen on a configured channel is converted to an
    array of character codes and pushed through that channel's prediction
    function.  Depending on the message, a reply is sampled character by
    character from the returned probabilities and sent back to the channel.
    """

    def __init__(self, chans, nick):
        """Store the channel table and the bot's nickname.

        chans: mapping from channel name to a pair whose first element is
            the prediction function used for that channel (the second
            element is not used by this class).  The prediction function is
            called with an integer array and must return a one-element
            sequence holding a 2-D probability array.
        nick: the bot's own nickname; a message containing it triggers a
            reply.
        """
        super(IRCClient, self).__init__()

        self.chans = chans
        self.nick = nick
        # Set on the first join event; privmsg() needs it to send anything.
        self.server = None

    def on_welcome(self, server, ev):
        """Join every configured channel whose name starts with '#'."""
        logger.info("Welcomed to %r", server)
        for ch in self.chans:
            # Entries that are not channel names (e.g. '') are skipped.
            if ch.startswith('#'):
                server.join(ch)

    def on_join(self, server, ev):
        """Remember the connection object so replies can be sent later."""
        self.server = server

    def str2data(self, s):
        """Return the character codes of `s` as an int16 array of shape (1, len(s))."""
        return numpy.array([ord(x) for x in s], dtype='int16')[None, :]

    def pred_until(self, pred_f, prob, delim='\n'):
        """Sample characters until `delim` is drawn and return the text before it.

        pred_f: prediction function; it is fed each sampled character (as a
            1x1 array) and returns the probabilities for the next one.
        prob: 2-D probability array for the first character to sample; only
            row 0 is used.
        delim: single character that ends the sampling.  It is fed to
            `pred_f` like any other character but is not part of the result.

        There is no length limit: the loop only stops once `delim` is drawn.
        """
        sampled = []
        while True:
            # Scale the probabilities down slightly before sampling --
            # presumably so rounding error cannot make them sum to more
            # than 1, which numpy.random.multinomial rejects (TODO confirm).
            prob = prob / 1.00001
            pred = numpy.random.multinomial(1, prob[0, :]).nonzero()[0][0]

            char = chr(int(pred))
            sampled.append(char)

            # Feed the sampled character back (delimiter included) to get
            # the distribution over the next one.
            prob, = pred_f(pred[None, None])

            if char == delim:
                break
        return ''.join(sampled[:-1])

    def privmsg(self, chan, msg):
        """Send `msg` to `chan`, replacing over-long messages by a placeholder.

        If no join event has been seen yet there is no connection object to
        send through; the message is then dropped with a warning instead of
        raising AttributeError inside the event loop.
        """
        if self.server is None:
            logger.warning("No server connection yet; dropping message to %s",
                           chan)
            return
        if len(msg) > 500:
            msg = 'blip bloup'
        logger.info("%s >> %s", chan, msg)
        self.server.privmsg(chan, msg.decode('utf-8', 'ignore'))

    def on_pubmsg(self, server, ev):
        """Feed a public channel message to the predictor and maybe reply.

        Two cases produce a reply on a configured channel:

        * the message ends with the two characters '^I': the line (without
          that suffix) is fed to the predictor, the rest of the line is
          sampled, and the completed message text is posted;
        * the message contains the bot's nickname: the full line is fed,
          text is sampled up to the next tab and discarded, then
          "<sender>: " is fed and the sampled continuation is posted.

        Any other message on a configured channel is only fed to the
        predictor.  Messages on other channels are logged and ignored.
        """
        chan = ev.target.encode('utf-8')
        nick = ev.source.split('!')[0].encode('utf-8')
        msg = ev.arguments[0].encode('utf-8')

        logger.info("%s <%s> %s", chan, nick, msg)

        # Lines are presented to the predictor as "<nick>\t<message>".
        line = nick + '\t' + msg

        if chan not in self.chans:
            return

        pred_f, _ = self.chans[chan]
        rep = None

        if line.endswith('^I'):
            # Completion request: continue the sender's own line.
            prefix = line[:-2]
            prob, = pred_f(self.str2data(prefix))
            completed = prefix + self.pred_until(pred_f, prob)
            # Drop the leading "<nick>\t" so only the message text is sent.
            rep = completed.split('\t', 1)[-1]
        else:
            # Feed the finished line (newline-terminated) to the predictor.
            prob, = pred_f(self.str2data(line + '\n'))
            if self.nick in msg:
                # Sample and discard everything up to the next tab, then
                # force the reply to start by addressing the sender.
                self.pred_until(pred_f, prob, '\t')
                address = nick + ': '
                prob, = pred_f(self.str2data(address))
                rep = address + self.pred_until(pred_f, prob)

        if rep is not None:
            self.privmsg(chan, rep)

class IRCClientExt(SimpleExtension):
    """Training extension that runs an IRC bot backed by the model being trained.

    On construction it compiles, for every channel, a prediction function
    with its own copy of the model state, connects an `IRCClient` to the
    server, and afterwards processes pending IRC events each time the
    extension is invoked.
    """

    def __init__(self, model, sample_temperature, server, port, nick, channels, **kwargs):
        """Compile the per-channel functions and connect to the IRC server.

        model: object exposing `out` (indexed here as ``out[:, -1, :]``) and
            `states`, an iterable of ``(state_variable, update_expression)``
            pairs.
        sample_temperature: divisor applied to the model output before the
            softmax.
        server, port, nick: passed to the IRC client's `connect`.
        channels: iterable of channel names to build predictors for.
        **kwargs: forwarded to `SimpleExtension`.
        """
        super(IRCClientExt, self).__init__(**kwargs)

        # model output: keep only the last position along axis 1, scale it
        # by the temperature and turn it into probabilities
        out = model.out[:, -1, :] / numpy.float32(sample_temperature)
        prob = tensor.nnet.softmax(out)

        cg = ComputationGraph([prob])
        assert len(cg.inputs) == 1, "expected exactly one graph input"
        assert cg.inputs[0].name == 'bytes', "graph input must be named 'bytes'"

        # channel functions & state
        # An extra entry keyed by '' is built as well; IRCClient.on_welcome
        # never joins it because it does not start with '#'.
        chfun = {}
        for ch in list(channels) + ['']:
            # One shared variable per model state, initialised to zeros with
            # the shape of the state's first row (this evaluates `v`, so
            # presumably the states are shared variables -- TODO confirm).
            state_vars = [theano.shared(v[0:1, :].zeros_like().eval(), v.name+'-'+ch)
                          for v, _ in model.states]
            # Read the state from the channel's shared copy ...
            givens = [(v, x) for (v, _), x in zip(model.states, state_vars)]
            # ... and write the updated state back into it after each call.
            updates = [(x, upd) for x, (_, upd) in zip(state_vars, model.states)]

            pred = theano.function(inputs=cg.inputs, outputs=[prob],
                                   givens=givens, updates=updates)
            reset_states = theano.function(inputs=[], outputs=[],
                                           updates=[(v, v.zeros_like()) for v in state_vars])
            chfun[ch] = (pred, reset_states)

        self.irc = IRCClient(chfun, nick)
        self.irc.connect(server, port, nick)

    def __getstate__(self):
        """Return the instance state without the IRC client."""
        state = dict(self.__dict__)
        # The client is left out of the pickled state.
        state.pop('irc', None)
        return state

    def __setstate__(self, state):
        """Restore pickled state, keeping an already attached client if any.

        A freshly unpickled instance has no `irc` attribute yet, so it is
        looked up defensively; in that case `irc` ends up as None and `do`
        skips polling.
        """
        irc = getattr(self, 'irc', None)
        self.__dict__.update(state)
        self.irc = irc

    def do(self, which_callback, *args):
        """Process pending IRC events once."""
        if self.irc is None:
            # Restored from a pickle: there is no client to poll.
            logger.warning('No IRC client attached; skipping poll')
            return
        logger.info('Polling...')
        self.irc.reactor.process_once()