author    | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-17 14:58:20 -0400
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-17 14:58:20 -0400
commit    | 701c407b8c87a9270b31d34ac54e683341be661e (patch)
tree      | 1ef80f0859a7f592d44e9f1c4010da02ee68fb10
parent    | e91e14e894196642532c0b7be50b01c1354ad702 (diff)
download  | text-rnn-701c407b8c87a9270b31d34ac54e683341be661e.tar.gz / text-rnn-701c407b8c87a9270b31d34ac54e683341be661e.zip
xoxo
-rw-r--r-- | gfgru.py  | 103
-rw-r--r-- | ircext.py |  26
-rw-r--r-- | lstm.py   |   3
-rwxr-xr-x | train.py  |  42
4 files changed, 104 insertions, 70 deletions
diff --git a/gfgru.py b/gfgru.py
--- a/gfgru.py
+++ b/gfgru.py
@@ -15,24 +15,26 @@ from blocks.graph import ComputationGraph, apply_noise, apply_dropout
 
 # An epoch will be composed of 'num_seqs' sequences of len 'seq_len'
 # divided in chunks of lengh 'seq_div_size'
-num_seqs = 10
-seq_len = 2000
-seq_div_size = 100
+num_seqs = 2
+seq_len = 2
+seq_div_size = 2
 
 io_dim = 256
 
 recurrent_blocks = [
 #    (256, Tanh(), [2048], [Rectifier()]),
-    (256, Tanh(), [], []),
-    (256, Tanh(), [], []),
-    (256, Tanh(), [512], [Rectifier()]),
-    (256, Tanh(), [512], [Rectifier()]),
+    (384, Tanh(), [], []),
+    (384, Tanh(), [], []),
+    (384, Tanh(), [1024], [Rectifier()]),
+#    (384, Tanh(), [1024], [Rectifier()]),
+#    (2, Tanh(), [2], [Rectifier()]),
+#    (2, Tanh(), [], []),
 ]
 
-control_hidden = [512]
-control_hidden_activations = [Tanh()]
+control_hidden = [1024]
+control_hidden_activations = [Rectifier()]
 
-output_hidden = [512]
+output_hidden = [1024]
 output_hidden_activations = [Rectifier()]
 
 weight_noise_std = 0.02
@@ -54,11 +56,12 @@ param_desc = '%s,c%s,o%s-n%s-d%s,%s-%dx%d(%d)-%s' % (
 )
 
 save_freq = 1
+on_irc = False
 
 # parameters for sample generation
 sample_len = 100
 sample_temperature = 0.7 #0.5
-sample_freq = 1
+sample_freq = None
 
 if step_rule == 'rmsprop':
     step_rule = RMSProp()
@@ -71,8 +74,7 @@ else:
 
 
 class GFGRU(BaseRecurrent, Initializable):
-    @lazy(allocation=['input_dim', 'recurrent_blocks', 'control_hidden', 'control_hidden_activations'])
-    def __init__(self, input_dim=None, recurrent_blocks=None, control_hidden=None, control_hidden_activations=None, **kwargs):
+    def __init__(self, input_dim, recurrent_blocks, control_hidden, control_hidden_activations, **kwargs):
         super(GFGRU, self).__init__(**kwargs)
 
         self.input_dim = input_dim
@@ -80,9 +82,8 @@ class GFGRU(BaseRecurrent, Initializable):
         self.control_hidden = control_hidden
         self.control_hidden_activations = control_hidden_activations
 
+        # setup children
         self.children = control_hidden_activations
-
-    def _allocate(self):
         for (_, a, _, b) in recurrent_blocks:
             self.children.append(a)
             for c in b:
@@ -93,17 +94,20 @@ class GFGRU(BaseRecurrent, Initializable):
 
         self.hidden_total_dim = sum(x for (x, _, _, _) in self.recurrent_blocks)
 
+        # control block
+        self.cblocklen = len(self.recurrent_blocks) + 2
+
         control_idim = self.hidden_total_dim + self.input_dim
-        control_odim = len(self.recurrent_blocks) * (len(self.recurrent_blocks) + 2)
+        control_odim = len(self.recurrent_blocks) * self.cblocklen
         self.control = MLP(dims=[control_idim] + self.control_hidden + [control_odim],
                            activations=self.control_hidden_activations + [logistic],
                            name='control')
         self.children.append(self.control)
 
+        # recurrent blocks
         self.blocks = []
         self.params = []
-        self.initial_states = {}
 
         for i, (dim, act, hdim, hact) in enumerate(self.recurrent_blocks):
             idim = self.input_dim + self.hidden_total_dim
             if i > 0:
@@ -122,33 +126,37 @@ class GFGRU(BaseRecurrent, Initializable):
             self.children.append(brick)
             self.blocks.append((rgate, inter, zgate, nstate))
 
-        init_states = shared_floatx_zeros((self.hidden_total_dim,), name='initial_states')
-        self.params = [init_states]
-        add_role(self.params[0], INITIAL_STATE)
+        # init state zeros
+        self.init_states_names = []
+        self.init_states_dict = {}
+        self.params = []
 
-    def get_dim(self, name):
-        if name == 'states':
-            return self.hidden_total_dim
-        return super(GFLSTM, self).get_dim(name)
+        for i, (dim, _, _, _) in enumerate(self.recurrent_blocks):
+            name = 'init_state_%d'%i
+            svar = shared_floatx_zeros((dim,), name=name)
+            add_role(svar, INITIAL_STATE)
 
-    @recurrent(sequences=['inputs'], states=['states'],
-               outputs=['states'], contexts=[])
-    def apply(self, inputs=None, states=None):
-        concat_states = states
+            self.init_states_names.append(name)
+            self.init_states_dict[name] = svar
+            self.params.append(svar)
 
-        states = []
-        offset = 0
-        for (dim, _, _, _) in self.recurrent_blocks:
-            states.append(concat_states[:, offset:offset+dim])
-            offset += dim
+    def get_dim(self, name):
+        if name in self.init_states_dict:
+            return self.init_states_dict[name].shape.eval()
+        return super(GFGRU, self).get_dim(name)
+
+    @recurrent(sequences=['inputs'], contexts=[])
+    def apply(self, inputs=None, **kwargs):
+        states = [kwargs[i] for i in self.init_states_names]
+        concat_states = tensor.concatenate(states, axis=1)
 
         concat_input_states = tensor.concatenate([inputs, concat_states], axis=1)
 
-        control = self.control.apply(concat_input_states)
+        control_v = self.control.apply(concat_input_states)
 
         new_states = []
         for i, (rgate, inter, zgate, nstate) in enumerate(self.blocks):
-            controls = control[:, i * (len(self.recurrent_blocks)+2):(i+1) * (len(self.recurrent_blocks)+2)]
+            controls = control_v[:, i * self.cblocklen:(i+1) * self.cblocklen]
             rgate_v = rgate.apply(concat_states)
             r_inputs = tensor.concatenate([s * controls[:, j][:, None] for j, s in enumerate(states)], axis=1)
             r_inputs = r_inputs * (1 - rgate_v * controls[:, -1][:, None])
@@ -162,18 +170,28 @@ class GFGRU(BaseRecurrent, Initializable):
             zgate_v = zgate.apply(inter_v)
             nstate_v = nstate.apply(inter_v)
 
-            nstate_v = nstate_v * (1 - zgate_v * controls[:, -2][:, None])
+            zctl = zgate_v * controls[:, -2][:, None]
+            nstate_v = zctl * nstate_v + (1 - zctl) * states[i]
 
             new_states.append(nstate_v)
 
-        return tensor.concatenate(new_states, axis=1)
+        return new_states
+
+    @apply.property('states')
+    def apply_states(self):
+        return self.init_states_names
+
+    @apply.property('outputs')
+    def apply_outputs(self):
+        return self.init_states_names
 
     @application
     def initial_state(self, state_name, batch_size, *args, **kwargs):
-        return tensor.repeat(self.params[0][None, :], repeats=batch_size, axis=0)
+        return tensor.repeat(self.init_states_dict[state_name][None, :],
+                             repeats=batch_size,
+                             axis=0)
 
 
-
 class Model():
     def __init__(self):
         inp = tensor.lmatrix('bytes')
@@ -191,14 +209,15 @@ class Model():
 
         prev_states = theano.shared(numpy.zeros((num_seqs, hidden_total_dim)).astype(theano.config.floatX),
                                     name='states_save')
-        states = gfgru.apply(in_onehot.dimshuffle(1, 0, 2),
-                             states=prev_states).dimshuffle(1, 0, 2)
+        states = [x.dimshuffle(1, 0, 2) for x in gfgru.apply(in_onehot.dimshuffle(1, 0, 2), states=prev_states)]
+        states = tensor.concatenate(states, axis=2)
         new_states = states[:, -1, :]
 
         out_mlp = MLP(dims=[hidden_total_dim] + output_hidden + [io_dim],
                       activations=output_hidden_activations + [None],
                       name='output_mlp')
-        out = out_mlp.apply(states.reshape((inp.shape[0]*inp.shape[1], hidden_total_dim))).reshape((inp.shape[0], inp.shape[1], io_dim))
+        states_sh = states.reshape((inp.shape[0]*inp.shape[1], hidden_total_dim))
+        out = out_mlp.apply(states_sh).reshape((inp.shape[0], inp.shape[1], io_dim))
 
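The two functional changes in gfgru.py are easier to see outside the Theano graph: the control MLP output is now sliced per recurrent block in chunks of cblocklen = len(recurrent_blocks) + 2 (one mixing coefficient per source block, plus one reset control and one update control), and the update gate now interpolates between the previous block state and the candidate state instead of only damping the candidate. A minimal NumPy sketch of those shapes and update rules follows; the values are made up and it mirrors the arithmetic, not the actual Blocks/Theano code.

```python
import numpy

rng = numpy.random.RandomState(0)

batch = 2
block_dims = [384, 384, 384]        # the three recurrent_blocks in this commit
n_blocks = len(block_dims)
cblocklen = n_blocks + 2            # per-block slice of the control MLP output

# stand-in for the control MLP output (one row per example, values in [0, 1])
control_v = rng.uniform(size=(batch, n_blocks * cblocklen))

# previous per-block states
states = [rng.randn(batch, d) for d in block_dims]

i = 1                               # look at the second block
controls = control_v[:, i * cblocklen:(i + 1) * cblocklen]

# controls[:, j] for j < n_blocks: how much of block j's state feeds block i
r_inputs = numpy.concatenate(
    [s * controls[:, j][:, None] for j, s in enumerate(states)], axis=1)

# controls[:, -1] scales the reset gate over the concatenated states
rgate_v = rng.uniform(size=(batch, sum(block_dims)))   # stand-in for rgate.apply(...)
r_inputs = r_inputs * (1 - rgate_v * controls[:, -1][:, None])

# controls[:, -2] scales the update gate of block i
zgate_v = rng.uniform(size=(batch, block_dims[i]))     # stand-in for zgate.apply(...)
candidate = rng.randn(batch, block_dims[i])            # stand-in for nstate.apply(...)

zctl = zgate_v * controls[:, -2][:, None]
old_next = candidate * (1 - zctl)                      # old rule: previous state discarded
new_next = zctl * candidate + (1 - zctl) * states[i]   # new rule: GRU-style interpolation

print(r_inputs.shape, old_next.shape, new_next.shape)  # (2, 1152) (2, 384) (2, 384)
```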
diff --git a/ircext.py b/ircext.py
--- a/ircext.py
+++ b/ircext.py
@@ -30,7 +30,10 @@ class IRCClient(SimpleIRCClient):
     def on_join(self, server, ev):
         self.server = server
 
-    def pred_line(self, pred_f, prob):
+    def str2data(self, s):
+        return numpy.array([ord(x) for x in s], dtype='int16')[None, :]
+
+    def pred_until(self, pred_f, prob, delim='\n'):
         s = ''
         while True:
             prob = prob / 1.00001
@@ -40,11 +43,13 @@ class IRCClient(SimpleIRCClient):
 
             prob, = pred_f(pred[None, None])
 
-            if s[-1] == '\n':
+            if s[-1] == delim:
                 break
         return s[:-1]
 
     def privmsg(self, chan, msg):
+        if len(msg) > 500:
+            msg = 'blip bloup'
         logger.info("%s >> %s" % (chan, msg))
         self.server.privmsg(chan, msg.decode('utf-8', 'ignore'))
 
@@ -62,21 +67,20 @@ class IRCClient(SimpleIRCClient):
         if chan in self.chans:
             pred_f, _ = self.chans[chan]
             if s0[-2:] == '^I':
-                prob, = pred_f(numpy.array([ord(x) for x in s0[:-2]], dtype='int16')[None, :])
-                s = self.pred_line(pred_f, prob)
-                rep = s0[:-2] + s
+                prob, = pred_f(self.str2data(s0[:-2]))
+                rep = s0[:-2] + self.pred_until(pred_f, prob)
+                rep = rep.split('\t', 1)[-1]
             else:
                 # feed phrase to bot
-                prob, = pred_f(numpy.array([ord(x) for x in s0+'\n'], dtype='int16')[None, :])
-                if msg[:len(self.nick)+1] == self.nick+':':
-                    #TODO: make it so that it predicts a message beginning by 'nick: '
-                    # (ie it responds to the person who pinged it)
-                    rep = self.pred_line(pred_f, prob)
+                prob, = pred_f(self.str2data(s0+'\n'))
+                if self.nick in msg:
+                    self.pred_until(pred_f, prob, '\t')
+                    prob, = pred_f(self.str2data(nick+': '))
+                    rep = nick + ': ' + self.pred_until(pred_f, prob)
                 else:
                     pass
 
         if rep != None:
-            rep = rep.split('\t', 1)[1]
             self.privmsg(chan, rep)
 
 class IRCClientExt(SimpleExtension):
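In ircext.py, str2data packs a string into the [1, len] int16 array of byte values that the compiled prediction function expects, and pred_line becomes pred_until, which keeps sampling until an arbitrary delimiter (a tab when completing the nick prefix, a newline for the message body). Below is a rough standalone sketch of that loop against a dummy distribution; the real pred_f is the compiled Theano function, and the sampling step is not visible in this hunk, so the multinomial draw here is an assumption.

```python
import numpy

def str2data(s):
    # one row of byte values, same dtype the model's prediction function expects
    return numpy.array([ord(x) for x in s], dtype='int16')[None, :]

def pred_until(pred_f, prob, delim='\n', max_len=200):
    # sample one byte at a time until the delimiter is produced
    s = ''
    rng = numpy.random.RandomState(0)
    while len(s) < max_len:
        prob = prob / 1.00001                      # keep multinomial pvals <= 1
        pred = rng.multinomial(1, prob.flatten()).argmax()
        s = s + chr(int(pred))
        prob, = pred_f(numpy.array([[pred]], dtype='int16'))
        if s[-1] == delim:
            break
    return s[:-1]

# dummy pred_f: ignores its input, favors newline so the loop terminates quickly
def dummy_pred_f(data):
    p = numpy.zeros(256)
    p[32:127] = 1.0
    p[ord('\n')] = 5.0
    return (p / p.sum(),)

prob, = dummy_pred_f(str2data('hello\n'))
print(repr(pred_until(dummy_pred_f, prob)))
```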
diff --git a/lstm.py b/lstm.py
--- a/lstm.py
+++ b/lstm.py
@@ -47,11 +47,12 @@ param_desc = '%s-%sIH,%sHO-n%s-d%s-l1r%s-%dx%d(%d)-%s' % (
 )
 
 save_freq = 5
+on_irc = True
 
 # parameters for sample generation
 sample_len = 1000
 sample_temperature = 0.7 #0.5
-sample_freq = 10
+sample_freq = None
 
 if step_rule == 'rmsprop':
     step_rule = RMSProp()

diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -72,16 +72,36 @@ def train_model(m, train_stream, dump_path=None):
     except IOError:
         pass
 
+    extensions = []
+    if config.save_freq is not None:
+        extensions.append(
+            Checkpoint(path=dump_path,
+                       after_epoch=False,
+                       use_cpickle=True,
+                       every_n_epochs=config.save_freq),
+        )
+    if config.sample_freq is not None:
+        extensions.append(
+            gentext.GenText(m, '\nalex\ttu crois ?\n',
+                            config.sample_len, config.sample_temperature,
+                            every_n_epochs=config.sample_freq,
+                            after_epoch=False, before_training=True),
+        )
+    if config.on_irc:
+        extensions.append(
+            ircext.IRCClientExt(m, config.sample_temperature,
+                                server='irc.ulminfo.fr',
+                                port=6667,
+                                nick='frigo',
+                                channels=['#frigotest', '#courssysteme'],
+                                after_batch=True),
+        )
+
     main_loop = MainLoop(
         model=model,
         data_stream=train_stream,
        algorithm=algorithm,
-        extensions=[
-            #Checkpoint(path=dump_path,
-            #           after_epoch=False,
-            #           use_cpickle=True,
-            #           every_n_epochs=config.save_freq),
-
+        extensions=extensions + [
             TrainingDataMonitoring(
                 [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
                 prefix='train', every_n_epochs=1),
@@ -92,16 +112,6 @@
                  server_url='http://eos21:4201/',
                  every_n_epochs=1,
                  after_epoch=False),
-            gentext.GenText(m, '\nalex\ttu crois ?\n',
-                            config.sample_len, config.sample_temperature,
-                            every_n_epochs=config.sample_freq,
-                            after_epoch=False, before_training=True),
-            #ircext.IRCClientExt(m, config.sample_temperature,
-            #                    server='irc.ulminfo.fr',
-            #                    port=6667,
-            #                    nick='frigo',
-            #                    channels=['#frigotest', '#courssysteme'],
-            #                    after_batch=True),
             ResetStates([v for v, _ in m.states], after_epoch=True)
         ]
     )
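train.py now builds the optional extensions up front, so each config module can toggle them independently: save_freq=None drops the Checkpoint, sample_freq=None drops GenText, and on_irc enables the IRC client (lstm.py sets it to True, gfgru.py to False). A stripped-down sketch of the same selection pattern, using stand-in tuples instead of the real Blocks extension objects:

```python
# stand-in config: the real values come from the config module (gfgru.py or lstm.py)
class Config(object):
    save_freq = 5        # lstm.py: 5, gfgru.py: 1
    sample_freq = None   # both configs now disable periodic sampling
    on_irc = True        # lstm.py: True, gfgru.py: False

config = Config()

extensions = []
if config.save_freq is not None:
    extensions.append(('Checkpoint', dict(every_n_epochs=config.save_freq)))
if config.sample_freq is not None:
    extensions.append(('GenText', dict(every_n_epochs=config.sample_freq)))
if config.on_irc:
    extensions.append(('IRCClientExt', dict(after_batch=True)))

# the always-on extensions are appended afterwards, as in
#   MainLoop(..., extensions=extensions + [TrainingDataMonitoring(...), ...])
always_on = [('TrainingDataMonitoring', {}), ('ResetStates', {})]
print(extensions + always_on)
```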