summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--REMARKS1
-rw-r--r--config/lstm-frigo-irc.py3
-rw-r--r--config/lstm-xreg.py4
-rwxr-xr-xircbot.py16
4 files changed, 16 insertions, 8 deletions
diff --git a/REMARKS b/REMARKS
index c44ca5f..e27cfdf 100644
--- a/REMARKS
+++ b/REMARKS
@@ -1 +1,2 @@
- lstm-xreg-relu : does not converge at all, cost is stuck around 3.6 and error rate is 86%
+- lstm-xreg : converges very badly
diff --git a/config/lstm-frigo-irc.py b/config/lstm-frigo-irc.py
index a0f5b5c..04468c9 100644
--- a/config/lstm-frigo-irc.py
+++ b/config/lstm-frigo-irc.py
@@ -34,7 +34,8 @@ monitor_freq = 100
save_freq = 100
# used for sample generation and IRC mode
-sample_temperature = 0.7 #0.5
+#sample_temperature = 0.7 #0.5
+sample_temperature = 0.9 #0.5
# do we want to generate samples at times during training?
sample_len = 1000
diff --git a/config/lstm-xreg.py b/config/lstm-xreg.py
index f8c5094..a5d5b46 100644
--- a/config/lstm-xreg.py
+++ b/config/lstm-xreg.py
@@ -14,10 +14,10 @@ seq_div_size = 200
layers = [
{'dim': 1024,
- 'xreg': (768, 0.1, 10, 10, 10, 2)
+ 'xreg': (768, 0.1, 10, 20, 10, 0)
},
{'dim': 1024,
- 'xreg': (768, 0.1, 10, 10, 10, 5)
+ 'xreg': (768, 0.1, 10, 20, 10, 0)
},
{'dim': 1024,
},
diff --git a/ircbot.py b/ircbot.py
index 7477691..07f6ff3 100755
--- a/ircbot.py
+++ b/ircbot.py
@@ -16,6 +16,8 @@ import datastream
from paramsaveload import SaveLoadParams
from blocks.graph import ComputationGraph
+import random
+
logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)
@@ -67,11 +69,13 @@ class IRCClient(SimpleIRCClient):
def str2data(self, s):
return numpy.array([ord(x) for x in s], dtype='int16')[None, :]
- def pred_until(self, pred_f, prob, delim='\n'):
+ def pred_until(self, pred_f, prob, delim='\n', forbid_first=None):
s = ''
while True:
prob = prob / 1.00001
pred = numpy.random.multinomial(1, prob[0, :]).nonzero()[0][0].astype('int16')
+ if forbid_first is not None and s == '' and int(pred) == forbid_first:
+ continue # try again
s = s + chr(int(pred))
@@ -108,9 +112,11 @@ class IRCClient(SimpleIRCClient):
# feed phrase to bot
prob, = pred_f(self.str2data(s0+'\n'))
if any(x in msg.lower() for x in [self.nick, 'frigal']):
- self.pred_until(pred_f, prob, '\t')
- prob, = pred_f(self.str2data(nick+': '))
- rep = nick + ': ' + self.pred_until(pred_f, prob)
+ if random.uniform(0, 1) < 0.3:
+ fromnick = self.pred_until(pred_f, prob, '\t', forbid_first=ord(nick[0]))
+ logger.info("from '%s'"%fromnick)
+ prob, = pred_f(self.str2data(nick+': '))
+ rep = nick + ': ' + self.pred_until(pred_f, prob)
else:
logger.warn('Recieved message on unknown channel: %s'%chan)
@@ -141,7 +147,7 @@ if __name__ == "__main__":
server = 'clipper.ens.fr'
port = 6667
nick = 'frigo'
- chans = ['#frigotest', '#courssysteme']
+ chans = ['#frigotest']
irc = IRCClient(model=m,
sample_temperature=config.sample_temperature,