diff options
-rw-r--r-- | REMARKS | 1 | ||||
-rw-r--r-- | config/lstm-frigo-irc.py | 3 | ||||
-rw-r--r-- | config/lstm-xreg.py | 4 | ||||
-rwxr-xr-x | ircbot.py | 16 |
4 files changed, 16 insertions, 8 deletions
@@ -1 +1,2 @@ - lstm-xreg-relu : does not converge at all, cost is stuck around 3.6 and error rate is 86% +- lstm-xreg : converges very badly diff --git a/config/lstm-frigo-irc.py b/config/lstm-frigo-irc.py index a0f5b5c..04468c9 100644 --- a/config/lstm-frigo-irc.py +++ b/config/lstm-frigo-irc.py @@ -34,7 +34,8 @@ monitor_freq = 100 save_freq = 100 # used for sample generation and IRC mode -sample_temperature = 0.7 #0.5 +#sample_temperature = 0.7 #0.5 +sample_temperature = 0.9 #0.5 # do we want to generate samples at times during training? sample_len = 1000 diff --git a/config/lstm-xreg.py b/config/lstm-xreg.py index f8c5094..a5d5b46 100644 --- a/config/lstm-xreg.py +++ b/config/lstm-xreg.py @@ -14,10 +14,10 @@ seq_div_size = 200 layers = [ {'dim': 1024, - 'xreg': (768, 0.1, 10, 10, 10, 2) + 'xreg': (768, 0.1, 10, 20, 10, 0) }, {'dim': 1024, - 'xreg': (768, 0.1, 10, 10, 10, 5) + 'xreg': (768, 0.1, 10, 20, 10, 0) }, {'dim': 1024, }, @@ -16,6 +16,8 @@ import datastream from paramsaveload import SaveLoadParams from blocks.graph import ComputationGraph +import random + logging.basicConfig(level='INFO') logger = logging.getLogger(__name__) @@ -67,11 +69,13 @@ class IRCClient(SimpleIRCClient): def str2data(self, s): return numpy.array([ord(x) for x in s], dtype='int16')[None, :] - def pred_until(self, pred_f, prob, delim='\n'): + def pred_until(self, pred_f, prob, delim='\n', forbid_first=None): s = '' while True: prob = prob / 1.00001 pred = numpy.random.multinomial(1, prob[0, :]).nonzero()[0][0].astype('int16') + if forbid_first is not None and s == '' and int(pred) == forbid_first: + continue # try again s = s + chr(int(pred)) @@ -108,9 +112,11 @@ class IRCClient(SimpleIRCClient): # feed phrase to bot prob, = pred_f(self.str2data(s0+'\n')) if any(x in msg.lower() for x in [self.nick, 'frigal']): - self.pred_until(pred_f, prob, '\t') - prob, = pred_f(self.str2data(nick+': ')) - rep = nick + ': ' + self.pred_until(pred_f, prob) + if random.uniform(0, 1) < 0.3: + fromnick = self.pred_until(pred_f, prob, '\t', forbid_first=ord(nick[0])) + logger.info("from '%s'"%fromnick) + prob, = pred_f(self.str2data(nick+': ')) + rep = nick + ': ' + self.pred_until(pred_f, prob) else: logger.warn('Recieved message on unknown channel: %s'%chan) @@ -141,7 +147,7 @@ if __name__ == "__main__": server = 'clipper.ens.fr' port = 6667 nick = 'frigo' - chans = ['#frigotest', '#courssysteme'] + chans = ['#frigotest'] irc = IRCClient(model=m, sample_temperature=config.sample_temperature, |