From 9be0db7523abdfa59c19115585f1ee96d73d08c6 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 10 Jun 2015 15:22:18 -0400 Subject: Changes --- .gitignore | 1 + lstm.py | 75 ++++++++++++++++++++++++++++++++++++++++---------------------- train.py | 64 ++++++++++++++++++++++++++++++++--------------------- 3 files changed, 88 insertions(+), 52 deletions(-) diff --git a/.gitignore b/.gitignore index cec3f1e..3d6982a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.pyc *.swp data/* +model_data/* diff --git a/lstm.py b/lstm.py index 32cdb9b..e294793 100644 --- a/lstm.py +++ b/lstm.py @@ -13,36 +13,48 @@ from blocks.graph import ComputationGraph, apply_noise, apply_dropout # An epoch will be composed of 'num_seqs' sequences of len 'seq_len' # divided in chunks of lengh 'seq_div_size' -num_seqs = 10 +num_seqs = 20 seq_len = 2000 seq_div_size = 100 io_dim = 256 -hidden_dims = [512, 512] +hidden_dims = [512, 512, 512] activation_function = Tanh() -all_hidden_for_output = False +i2h_all = True # input to all hidden layers or only first layer +h2o_all = True # all hiden layers to output or only last layer w_noise_std = 0.01 i_dropout = 0.5 -step_rule = 'adadelta' +step_rule = 'momentum' +learning_rate = 0.1 +momentum = 0.9 -param_desc = '%s-%sHO-n%s-d%s-%dx%d(%d)-%s' % ( +param_desc = '%s-%sIH,%sHO-n%s-d%s-%dx%d(%d)-%s' % ( repr(hidden_dims), - 'all' if all_hidden_for_output else 'last', + 'all' if i2h_all else 'first', + 'all' if h2o_all else 'last', repr(w_noise_std), repr(i_dropout), num_seqs, seq_len, seq_div_size, step_rule ) +save_freq = 5 + +# parameters for sample generation +sample_len = 60 +sample_temperature = 0.3 + if step_rule == 'rmsprop': step_rule = RMSProp() elif step_rule == 'adadelta': step_rule = AdaDelta() +elif step_rule == 'momentum': + step_rule = Momentum(learning_rate=learning_rate, momentum=momentum) else: assert(False) @@ -52,7 +64,9 @@ class Model(): in_onehot = tensor.eq(tensor.arange(io_dim, dtype='int16').reshape((1, 1, io_dim)), inp[:, :, None]) + in_onehot.name = 'in_onehot' + # Construct hidden states dims = [io_dim] + hidden_dims states = [in_onehot.dimshuffle(1, 0, 2)] bricks = [] @@ -65,38 +79,44 @@ class Model(): linear = Linear(input_dim=dims[i-1], output_dim=4*dims[i], name="lstm_in_%d"%i) + bricks.append(linear) + inter = linear.apply(states[-1]) + + if i2h_all and i > 1: + linear2 = Linear(input_dim=dims[0], output_dim=4*dims[i], + name="lstm_in0_%d"%i) + bricks.append(linear2) + inter = inter + linear2.apply(states[0]) + inter.name = 'inter_bis_%d'%i + lstm = LSTM(dim=dims[i], activation=activation_function, name="lstm_rec_%d"%i) + bricks.append(lstm) - new_states, new_cells = lstm.apply(linear.apply(states[-1]), + new_states, new_cells = lstm.apply(inter, states=init_state, cells=init_cell) updates.append((init_state, new_states[-1, :, :])) updates.append((init_cell, new_cells[-1, :, :])) states.append(new_states) - bricks = bricks + [linear, lstm] - states = [s.dimshuffle(1, 0, 2).reshape((inp.shape[0] * inp.shape[1], dim)) - for dim, s in zip(dims, states)] + states = [s.dimshuffle(1, 0, 2) for s in states] - if all_hidden_for_output: - top_linear = MLP(dims=[sum(hidden_dims), io_dim], - activations=[Softmax()], - name="pred_mlp") + # Construct output from hidden states + out = None + layers = zip(dims, states)[1:] + if not h2o_all: + layers = [layers[-1]] + for i, (dim, state) in enumerate(layers): + top_linear = Linear(input_dim=dim, output_dim=io_dim, + name='top_linear_%d'%i) bricks.append(top_linear) + out_i = top_linear.apply(state) + out = out_i if out is None else out + out_i + out.name = 'out_part_%d'%i - out = top_linear.apply(tensor.concatenate(states[1:], axis=1)) - else: - top_linear = MLP(dims=[hidden_dims[-1], io_dim], - activations=[None], - name="pred_mlp") - bricks.append(top_linear) - - out = top_linear.apply(states[-1]) - - out = out.reshape((inp.shape[0], inp.shape[1], io_dim)) - + # Do prediction and calculate cost pred = out.argmax(axis=2) cost = Softmax().categorical_cross_entropy(inp[:, 1:].flatten(), @@ -104,13 +124,13 @@ class Model(): io_dim))) error_rate = tensor.neq(inp[:, 1:].flatten(), pred[:, :-1].flatten()).mean() - # Initialize + # Initialize all bricks for brick in bricks: brick.weights_init = IsotropicGaussian(0.1) brick.biases_init = Constant(0.) brick.initialize() - # apply noise + # Apply noise and dropout cg = ComputationGraph([cost, error_rate]) if w_noise_std > 0: noise_vars = VariableFilter(roles=[WEIGHT])(cg) @@ -123,6 +143,7 @@ class Model(): self.error_rate = error_rate self.cost_reg = cost_reg self.error_rate_reg = error_rate_reg + self.out = out self.pred = pred self.updates = updates diff --git a/train.py b/train.py index 7857f3f..a8e9ef2 100755 --- a/train.py +++ b/train.py @@ -5,14 +5,17 @@ import numpy import sys import importlib +from contextlib import closing + import theano from theano import tensor +from theano.tensor.shared_randomstreams import RandomStreams -from blocks.dump import load_parameter_values -from blocks.dump import MainLoopDumpManager +from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER from blocks.extensions import Printing, SimpleExtension from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring -from blocks.extensions.plot import Plot +from blocks.extras.extensions.plot import Plot +from blocks.extensions.saveload import Checkpoint, Load from blocks.graph import ComputationGraph from blocks.main_loop import MainLoop from blocks.model import Model @@ -37,10 +40,14 @@ class GenText(SimpleExtension): self.init_text = init_text self.max_bytes = max_bytes - cg = ComputationGraph([model.pred]) + + out = model.out[:, -1, :] / numpy.float32(config.sample_temperature) + prob = tensor.nnet.softmax(out) + + cg = ComputationGraph([prob]) assert(len(cg.inputs) == 1) assert(cg.inputs[0].name == 'bytes') - self.f = theano.function(inputs=cg.inputs, outputs=[model.pred]) + self.f = theano.function(inputs=cg.inputs, outputs=[prob]) super(GenText, self).__init__(**kwargs) @@ -49,22 +56,21 @@ class GenText(SimpleExtension): dtype='int16')[None, :].repeat(axis=0, repeats=config.num_seqs) while v.shape[1] < self.max_bytes: - pred, = self.f(v) - v = numpy.concatenate([v, pred[:, -1:]], axis=1) + prob, = self.f(v) + prob = prob / 1.00001 + pred = numpy.zeros((prob.shape[0],), dtype='int16') + for i in range(prob.shape[0]): + pred[i] = numpy.random.multinomial(1, prob[i, :]).nonzero()[0][0] + v = numpy.concatenate([v, pred[:, None]], axis=1) for i in range(v.shape[0]): print "Sample:", ''.join([chr(int(v[i, j])) for j in range(v.shape[1])]) -def train_model(m, train_stream, load_location=None, save_location=None): +def train_model(m, train_stream, dump_path=None): # Define the model model = Model(m.cost) - # Load the parameters from a dumped model - if load_location is not None: - logger.info('Loading parameters...') - model.set_param_values(load_parameter_values(load_location)) - cg = ComputationGraph(m.cost_reg) algorithm = GradientDescent(cost=m.cost_reg, step_rule=config.step_rule, @@ -72,11 +78,26 @@ def train_model(m, train_stream, load_location=None, save_location=None): algorithm.add_updates(m.updates) + # Load the parameters from a dumped model + if dump_path is not None: + try: + logger.info('Loading parameters...') + with closing(numpy.load(dump_path)) as source: + param_values = {'/' + name.replace(BRICK_DELIMITER, '/'): source[name] + for name in source.keys() + if name != 'pkl' and not 'None' in name} + model.set_param_values(param_values) + except IOError: + pass + main_loop = MainLoop( model=model, data_stream=train_stream, algorithm=algorithm, extensions=[ + Checkpoint(path=dump_path, + after_epoch=False, every_n_epochs=config.save_freq), + TrainingDataMonitoring( [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate], prefix='train', every_n_epochs=1), @@ -84,19 +105,14 @@ def train_model(m, train_stream, load_location=None, save_location=None): Plot(document='tr_'+model_name+'_'+config.param_desc, channels=[['train_cost', 'train_cost_reg'], ['train_error_rate', 'train_error_rate_reg']], + server_url='http://eos21:4201/', every_n_epochs=1, after_epoch=False), - GenText(m, '\t', 20, every_n_epochs=1, after_epoch=False) + + GenText(m, ' ', config.sample_len, every_n_epochs=1, after_epoch=False) ] ) main_loop.run() - # Save the main loop - if save_location is not None: - logger.info('Saving the main loop...') - dump_manager = MainLoopDumpManager(save_location) - dump_manager.dump(main_loop) - logger.info('Saved') - if __name__ == "__main__": # Build datastream @@ -114,8 +130,6 @@ if __name__ == "__main__": m.pred.name = 'pred' # Train the model - saveloc = 'model_data/%s' % model_name - train_model(m, train_stream, - load_location=None, - save_location=None) + saveloc = 'model_data/%s-%s' % (model_name, config.param_desc) + train_model(m, train_stream, dump_path=saveloc) -- cgit v1.2.3