summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gfgru.py103
-rw-r--r--ircext.py26
-rw-r--r--lstm.py3
-rwxr-xr-xtrain.py42
4 files changed, 104 insertions, 70 deletions
diff --git a/gfgru.py b/gfgru.py
index 9a612e1..c4b4b48 100644
--- a/gfgru.py
+++ b/gfgru.py
@@ -15,24 +15,26 @@ from blocks.graph import ComputationGraph, apply_noise, apply_dropout
# An epoch will be composed of 'num_seqs' sequences of len 'seq_len'
# divided in chunks of lengh 'seq_div_size'
-num_seqs = 10
-seq_len = 2000
-seq_div_size = 100
+num_seqs = 2
+seq_len = 2
+seq_div_size = 2
io_dim = 256
recurrent_blocks = [
# (256, Tanh(), [2048], [Rectifier()]),
- (256, Tanh(), [], []),
- (256, Tanh(), [], []),
- (256, Tanh(), [512], [Rectifier()]),
- (256, Tanh(), [512], [Rectifier()]),
+ (384, Tanh(), [], []),
+ (384, Tanh(), [], []),
+ (384, Tanh(), [1024], [Rectifier()]),
+# (384, Tanh(), [1024], [Rectifier()]),
+# (2, Tanh(), [2], [Rectifier()]),
+# (2, Tanh(), [], []),
]
-control_hidden = [512]
-control_hidden_activations = [Tanh()]
+control_hidden = [1024]
+control_hidden_activations = [Rectifier()]
-output_hidden = [512]
+output_hidden = [1024]
output_hidden_activations = [Rectifier()]
weight_noise_std = 0.02
@@ -54,11 +56,12 @@ param_desc = '%s,c%s,o%s-n%s-d%s,%s-%dx%d(%d)-%s' % (
)
save_freq = 1
+on_irc = False
# parameters for sample generation
sample_len = 100
sample_temperature = 0.7 #0.5
-sample_freq = 1
+sample_freq = None
if step_rule == 'rmsprop':
step_rule = RMSProp()
@@ -71,8 +74,7 @@ else:
class GFGRU(BaseRecurrent, Initializable):
- @lazy(allocation=['input_dim', 'recurrent_blocks', 'control_hidden', 'control_hidden_activations'])
- def __init__(self, input_dim=None, recurrent_blocks=None, control_hidden=None, control_hidden_activations=None, **kwargs):
+ def __init__(self, input_dim, recurrent_blocks, control_hidden, control_hidden_activations, **kwargs):
super(GFGRU, self).__init__(**kwargs)
self.input_dim = input_dim
@@ -80,9 +82,8 @@ class GFGRU(BaseRecurrent, Initializable):
self.control_hidden = control_hidden
self.control_hidden_activations = control_hidden_activations
+ # setup children
self.children = control_hidden_activations
-
- def _allocate(self):
for (_, a, _, b) in recurrent_blocks:
self.children.append(a)
for c in b:
@@ -93,17 +94,20 @@ class GFGRU(BaseRecurrent, Initializable):
self.hidden_total_dim = sum(x for (x, _, _, _) in self.recurrent_blocks)
+ # control block
+ self.cblocklen = len(self.recurrent_blocks) + 2
+
control_idim = self.hidden_total_dim + self.input_dim
- control_odim = len(self.recurrent_blocks) * (len(self.recurrent_blocks) + 2)
+ control_odim = len(self.recurrent_blocks) * self.cblocklen
self.control = MLP(dims=[control_idim] + self.control_hidden + [control_odim],
activations=self.control_hidden_activations + [logistic],
name='control')
self.children.append(self.control)
+ # recurrent blocks
self.blocks = []
self.params = []
- self.initial_states = {}
for i, (dim, act, hdim, hact) in enumerate(self.recurrent_blocks):
idim = self.input_dim + self.hidden_total_dim
if i > 0:
@@ -122,33 +126,37 @@ class GFGRU(BaseRecurrent, Initializable):
self.children.append(brick)
self.blocks.append((rgate, inter, zgate, nstate))
- init_states = shared_floatx_zeros((self.hidden_total_dim,), name='initial_states')
- self.params = [init_states]
- add_role(self.params[0], INITIAL_STATE)
+ # init state zeros
+ self.init_states_names = []
+ self.init_states_dict = {}
+ self.params = []
- def get_dim(self, name):
- if name == 'states':
- return self.hidden_total_dim
- return super(GFLSTM, self).get_dim(name)
+ for i, (dim, _, _, _) in enumerate(self.recurrent_blocks):
+ name = 'init_state_%d'%i
+ svar = shared_floatx_zeros((dim,), name=name)
+ add_role(svar, INITIAL_STATE)
- @recurrent(sequences=['inputs'], states=['states'],
- outputs=['states'], contexts=[])
- def apply(self, inputs=None, states=None):
- concat_states = states
+ self.init_states_names.append(name)
+ self.init_states_dict[name] = svar
+ self.params.append(svar)
- states = []
- offset = 0
- for (dim, _, _, _) in self.recurrent_blocks:
- states.append(concat_states[:, offset:offset+dim])
- offset += dim
+ def get_dim(self, name):
+ if name in self.init_states_dict:
+ return self.init_states_dict[name].shape.eval()
+ return super(GFGRU, self).get_dim(name)
+
+ @recurrent(sequences=['inputs'], contexts=[])
+ def apply(self, inputs=None, **kwargs):
+ states = [kwargs[i] for i in self.init_states_names]
+ concat_states = tensor.concatenate(states, axis=1)
concat_input_states = tensor.concatenate([inputs, concat_states], axis=1)
- control = self.control.apply(concat_input_states)
+ control_v = self.control.apply(concat_input_states)
new_states = []
for i, (rgate, inter, zgate, nstate) in enumerate(self.blocks):
- controls = control[:, i * (len(self.recurrent_blocks)+2):(i+1) * (len(self.recurrent_blocks)+2)]
+ controls = control_v[:, i * self.cblocklen:(i+1) * self.cblocklen]
rgate_v = rgate.apply(concat_states)
r_inputs = tensor.concatenate([s * controls[:, j][:, None] for j, s in enumerate(states)], axis=1)
r_inputs = r_inputs * (1 - rgate_v * controls[:, -1][:, None])
@@ -162,18 +170,28 @@ class GFGRU(BaseRecurrent, Initializable):
zgate_v = zgate.apply(inter_v)
nstate_v = nstate.apply(inter_v)
- nstate_v = nstate_v * (1 - zgate_v * controls[:, -2][:, None])
+ zctl = zgate_v * controls[:, -2][:, None]
+ nstate_v = zctl * nstate_v + (1 - zctl) * states[i]
new_states.append(nstate_v)
- return tensor.concatenate(new_states, axis=1)
+ return new_states
+
+ @apply.property('states')
+ def apply_states(self):
+ return self.init_states_names
+
+ @apply.property('outputs')
+ def apply_outputs(self):
+ return self.init_states_names
@application
def initial_state(self, state_name, batch_size, *args, **kwargs):
- return tensor.repeat(self.params[0][None, :], repeats=batch_size, axis=0)
+ return tensor.repeat(self.init_states_dict[state_name][None, :],
+ repeats=batch_size,
+ axis=0)
-
class Model():
def __init__(self):
inp = tensor.lmatrix('bytes')
@@ -191,14 +209,15 @@ class Model():
prev_states = theano.shared(numpy.zeros((num_seqs, hidden_total_dim)).astype(theano.config.floatX),
name='states_save')
- states = gfgru.apply(in_onehot.dimshuffle(1, 0, 2),
- states=prev_states).dimshuffle(1, 0, 2)
+ states = [x.dimshuffle(1, 0, 2) for x in gfgru.apply(in_onehot.dimshuffle(1, 0, 2), states=prev_states)]
+ states = tensor.concatenate(states, axis=2)
new_states = states[:, -1, :]
out_mlp = MLP(dims=[hidden_total_dim] + output_hidden + [io_dim],
activations=output_hidden_activations + [None],
name='output_mlp')
- out = out_mlp.apply(states.reshape((inp.shape[0]*inp.shape[1], hidden_total_dim))).reshape((inp.shape[0], inp.shape[1], io_dim))
+ states_sh = states.reshape((inp.shape[0]*inp.shape[1], hidden_total_dim))
+ out = out_mlp.apply(states_sh).reshape((inp.shape[0], inp.shape[1], io_dim))
diff --git a/ircext.py b/ircext.py
index 3e38ac2..1af2ba8 100644
--- a/ircext.py
+++ b/ircext.py
@@ -30,7 +30,10 @@ class IRCClient(SimpleIRCClient):
def on_join(self, server, ev):
self.server = server
- def pred_line(self, pred_f, prob):
+ def str2data(self, s):
+ return numpy.array([ord(x) for x in s], dtype='int16')[None, :]
+
+ def pred_until(self, pred_f, prob, delim='\n'):
s = ''
while True:
prob = prob / 1.00001
@@ -40,11 +43,13 @@ class IRCClient(SimpleIRCClient):
prob, = pred_f(pred[None, None])
- if s[-1] == '\n':
+ if s[-1] == delim:
break
return s[:-1]
def privmsg(self, chan, msg):
+ if len(msg) > 500:
+ msg = 'blip bloup'
logger.info("%s >> %s" % (chan, msg))
self.server.privmsg(chan, msg.decode('utf-8', 'ignore'))
@@ -62,21 +67,20 @@ class IRCClient(SimpleIRCClient):
if chan in self.chans:
pred_f, _ = self.chans[chan]
if s0[-2:] == '^I':
- prob, = pred_f(numpy.array([ord(x) for x in s0[:-2]], dtype='int16')[None, :])
- s = self.pred_line(pred_f, prob)
- rep = s0[:-2] + s
+ prob, = pred_f(self.str2data(s0[:-2]))
+ rep = s0[:-2] + self.pred_until(pred_f, prob)
+ rep = rep.split('\t', 1)[-1]
else:
# feed phrase to bot
- prob, = pred_f(numpy.array([ord(x) for x in s0+'\n'], dtype='int16')[None, :])
- if msg[:len(self.nick)+1] == self.nick+':':
- #TODO: make it so that it predicts a message beginning by 'nick: '
- # (ie it responds to the person who pinged it)
- rep = self.pred_line(pred_f, prob)
+ prob, = pred_f(self.str2data(s0+'\n'))
+ if self.nick in msg:
+ self.pred_until(pred_f, prob, '\t')
+ prob, = pred_f(self.str2data(nick+': '))
+ rep = nick + ': ' + self.pred_until(pred_f, prob)
else:
pass
if rep != None:
- rep = rep.split('\t', 1)[1]
self.privmsg(chan, rep)
class IRCClientExt(SimpleExtension):
diff --git a/lstm.py b/lstm.py
index dbe46dc..1750d58 100644
--- a/lstm.py
+++ b/lstm.py
@@ -47,11 +47,12 @@ param_desc = '%s-%sIH,%sHO-n%s-d%s-l1r%s-%dx%d(%d)-%s' % (
)
save_freq = 5
+on_irc = True
# parameters for sample generation
sample_len = 1000
sample_temperature = 0.7 #0.5
-sample_freq = 10
+sample_freq = None
if step_rule == 'rmsprop':
step_rule = RMSProp()
diff --git a/train.py b/train.py
index b59cd8e..79b2116 100755
--- a/train.py
+++ b/train.py
@@ -72,16 +72,36 @@ def train_model(m, train_stream, dump_path=None):
except IOError:
pass
+ extensions = []
+ if config.save_freq is not None:
+ extensions.append(
+ Checkpoint(path=dump_path,
+ after_epoch=False,
+ use_cpickle=True,
+ every_n_epochs=config.save_freq),
+ )
+ if config.sample_freq is not None:
+ extensions.append(
+ gentext.GenText(m, '\nalex\ttu crois ?\n',
+ config.sample_len, config.sample_temperature,
+ every_n_epochs=config.sample_freq,
+ after_epoch=False, before_training=True),
+ )
+ if config.on_irc:
+ extensions.append(
+ ircext.IRCClientExt(m, config.sample_temperature,
+ server='irc.ulminfo.fr',
+ port=6667,
+ nick='frigo',
+ channels=['#frigotest', '#courssysteme'],
+ after_batch=True),
+ )
+
main_loop = MainLoop(
model=model,
data_stream=train_stream,
algorithm=algorithm,
- extensions=[
- #Checkpoint(path=dump_path,
- # after_epoch=False,
- # use_cpickle=True,
- # every_n_epochs=config.save_freq),
-
+ extensions=extensions + [
TrainingDataMonitoring(
[m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
prefix='train', every_n_epochs=1),
@@ -92,16 +112,6 @@ def train_model(m, train_stream, dump_path=None):
server_url='http://eos21:4201/',
every_n_epochs=1, after_epoch=False),
- gentext.GenText(m, '\nalex\ttu crois ?\n',
- config.sample_len, config.sample_temperature,
- every_n_epochs=config.sample_freq,
- after_epoch=False, before_training=True),
- #ircext.IRCClientExt(m, config.sample_temperature,
- # server='irc.ulminfo.fr',
- # port=6667,
- # nick='frigo',
- # channels=['#frigotest', '#courssysteme'],
- # after_batch=True),
ResetStates([v for v, _ in m.states], after_epoch=True)
]
)