diff options
-rw-r--r-- | cchlstm.py | 234 | ||||
-rw-r--r-- | dgsrnn.py | 33 | ||||
-rw-r--r-- | paramsaveload.py | 4 | ||||
-rwxr-xr-x | train.py | 51 |
4 files changed, 286 insertions, 36 deletions
diff --git a/cchlstm.py b/cchlstm.py new file mode 100644 index 0000000..9ff2016 --- /dev/null +++ b/cchlstm.py @@ -0,0 +1,234 @@ +import theano +from theano import tensor +import numpy + +from blocks.algorithms import Momentum, AdaDelta, RMSProp +from blocks.bricks import Tanh, Softmax, Linear, MLP, Initializable +from blocks.bricks.lookup import LookupTable +from blocks.bricks.recurrent import LSTM, BaseRecurrent, recurrent +from blocks.initialization import IsotropicGaussian, Constant + +from blocks.filter import VariableFilter +from blocks.roles import WEIGHT +from blocks.graph import ComputationGraph, apply_noise, apply_dropout + +# An epoch will be composed of 'num_seqs' sequences of len 'seq_len' +# divided in chunks of lengh 'seq_div_size' +num_seqs = 10 +seq_len = 2000 +seq_div_size = 200 + +io_dim = 256 + +# Model structure +hidden_dims = [256, 256, 256] +activation_function = Tanh() + +cond_cert = [0.5, 0.5] + +# Regularization +w_noise_std = 0.02 + +# Step rule +step_rule = 'adadelta' +learning_rate = 0.1 +momentum = 0.9 + + +param_desc = '%s(p%s)-n%s-%dx%d(%d)-%s' % ( + repr(hidden_dims), repr(cond_cert), + repr(w_noise_std), + num_seqs, seq_len, seq_div_size, + step_rule + ) + +save_freq = 5 +on_irc = False + +# parameters for sample generation +sample_len = 200 +sample_temperature = 0.7 #0.5 +sample_freq = 1 + +if step_rule == 'rmsprop': + step_rule = RMSProp() +elif step_rule == 'adadelta': + step_rule = AdaDelta() +elif step_rule == 'momentum': + step_rule = Momentum(learning_rate=learning_rate, momentum=momentum) +else: + assert(False) + +class CCHLSTM(BaseRecurrent, Initializable): + def __init__(self, io_dim, hidden_dims, cond_cert, activation=None, **kwargs): + super(CCHLSTM, self).__init__(**kwargs) + + self.cond_cert = cond_cert + + self.io_dim = io_dim + self.hidden_dims = hidden_dims + + self.children = [] + self.layers = [] + + self.softmax = Softmax() + self.children.append(self.softmax) + + for i, d in enumerate(hidden_dims): + i0 = LookupTable(length=io_dim, + dim=4*d, + name='i0-%d'%i) + self.children.append(i0) + + if i > 0: + i1 = Linear(input_dim=hidden_dims[i-1], + output_dim=4*d, + name='i1-%d'%i) + self.children.append(i1) + else: + i1 = None + + lstm = LSTM(dim=d, activation=activation, + name='LSTM-%d'%i) + self.children.append(lstm) + + o = Linear(input_dim=d, + output_dim=io_dim, + name='o-%d'%i) + self.children.append(o) + + self.layers.append((i0, i1, lstm, o)) + + + @recurrent(sequences=['inputs'], contexts=[]) + def apply(self, inputs, **kwargs): + + l0i, _, l0l, l0o = self.layers[0] + l0iv = l0i.apply(inputs) + new_states0, new_cells0 = l0l.apply(states=kwargs['states0'], + cells=kwargs['cells0'], + inputs=l0iv, + iterate=False) + l0ov = l0o.apply(new_states0) + + pos = l0ov + ps = new_states0 + + passnext = tensor.ones((inputs.shape[0], 1)) + out_sc = [new_states0, new_cells0, passnext] + + for i, (cch, (i0, i1, l, o)) in enumerate(zip(self.cond_cert, self.layers[1:])): + pop = self.softmax.apply(pos) + best = pop.max(axis=1) + passnext = passnext * tensor.le(best, cch)[:, None] + + i0v = i0.apply(inputs) + i1v = i1.apply(ps) + + prev_states = kwargs['states%d'%i] + prev_cells = kwargs['cells%d'%i] + new_states, new_cells = l.apply(inputs=i0v + i1v, + states=prev_states, + cells=prev_cells, + iterate=False) + new_states = tensor.switch(passnext, new_states, prev_states) + new_cells = tensor.switch(passnext, new_cells, prev_cells) + out_sc += [new_states, new_cells, passnext] + + ov = o.apply(new_states) + pos = tensor.switch(passnext, pos + ov, pos) + ps = new_states + + return [pos] + out_sc + + def get_dim(self, name): + dims = {'pred': self.io_dim} + for i, d in enumerate(self.hidden_dims): + dims['states%d'%i] = dims['cells%d'%i] = d + if name in dims: + return dims[name] + return super(CCHLSTM, self).get_dim(name) + + @apply.property('states') + def apply_states(self): + ret = [] + for i in range(len(self.hidden_dims)): + ret += ['states%d'%i, 'cells%d'%i] + return ret + + @apply.property('outputs') + def apply_outputs(self): + ret = ['pred'] + for i in range(len(self.hidden_dims)): + ret += ['states%d'%i, 'cells%d'%i, 'active%d'%i] + return ret + + +class Model(): + def __init__(self): + inp = tensor.lmatrix('bytes') + + # Make state vars + state_vars = {} + for i, d in enumerate(hidden_dims): + state_vars['states%d'%i] = theano.shared(numpy.zeros((num_seqs, d)) + .astype(theano.config.floatX), + name='states%d'%i) + state_vars['cells%d'%i] = theano.shared(numpy.zeros((num_seqs, d)) + .astype(theano.config.floatX), + name='cells%d'%i) + # Construct brick + cchlstm = CCHLSTM(io_dim=io_dim, + hidden_dims=hidden_dims, + cond_cert=cond_cert, + activation=activation_function) + + # Apply it + outs = cchlstm.apply(inputs=inp.dimshuffle(1, 0), + **state_vars) + states = [] + active_prop = [] + for i in range(len(hidden_dims)): + states.append((state_vars['states%d'%i], outs[3*i+1][-1, :, :])) + states.append((state_vars['cells%d'%i], outs[3*i+2][-1, :, :])) + active_prop.append(outs[3*i+3].mean()) + active_prop[-1].name = 'active_prop_%d'%i + + out = outs[0].dimshuffle(1, 0, 2) + + # Do prediction and calculate cost + pred = out.argmax(axis=2) + + cost = Softmax().categorical_cross_entropy(inp[:, 1:].flatten(), + out[:, :-1, :].reshape((inp.shape[0]*(inp.shape[1]-1), + io_dim))) + error_rate = tensor.neq(inp[:, 1:].flatten(), pred[:, :-1].flatten()).mean() + + # Initialize all bricks + for brick in [cchlstm]: + brick.weights_init = IsotropicGaussian(0.1) + brick.biases_init = Constant(0.) + brick.initialize() + + # Apply noise and dropoutvars + cg = ComputationGraph([cost, error_rate]) + if w_noise_std > 0: + noise_vars = VariableFilter(roles=[WEIGHT])(cg) + cg = apply_noise(cg, noise_vars, w_noise_std) + [cost_reg, error_rate_reg] = cg.outputs + + self.sgd_cost = cost_reg + self.monitor_vars = [[cost, cost_reg], + [error_rate, error_rate_reg], + active_prop] + + cost.name = 'cost' + cost_reg.name = 'cost_reg' + error_rate.name = 'error_rate' + error_rate_reg.name = 'error_rate_reg' + + self.out = out + self.pred = pred + + self.states = states + @@ -2,6 +2,8 @@ import theano from theano import tensor import numpy +from theano.tensor.shared_randomstreams import RandomStreams + from blocks.algorithms import Momentum, AdaDelta, RMSProp, Adam from blocks.bricks import Activation, Tanh, Logistic, Softmax, Rectifier, Linear, MLP, Initializable, Identity from blocks.bricks.base import application, lazy @@ -13,6 +15,8 @@ from blocks.filter import VariableFilter from blocks.roles import WEIGHT, INITIAL_STATE, add_role from blocks.graph import ComputationGraph, apply_noise, apply_dropout +rng = RandomStreams() + class TRectifier(Activation): @application(inputs=['input_'], outputs=['output']) def apply(self, input_): @@ -21,8 +25,8 @@ class TRectifier(Activation): # An epoch will be composed of 'num_seqs' sequences of len 'seq_len' # divided in chunks of lengh 'seq_div_size' num_seqs = 10 -seq_len = 2000 -seq_div_size = 100 +seq_len = 1000 +seq_div_size = 5 io_dim = 256 @@ -36,20 +40,21 @@ output_hidden_activations = [] weight_noise_std = 0.05 -output_h_dropout = 0.5 +output_h_dropout = 0.0 +drop_update = 0.0 -l1_state = 0.01 -l1_reset = 0.01 +l1_state = 0.00 +l1_reset = 0.1 step_rule = 'momentum' -learning_rate = 0.01 +learning_rate = 0.001 momentum = 0.99 -param_desc = '%s,t%s,o%s-n%s-d%s-L1:%s,%s-%s' % ( +param_desc = '%s,t%s,o%s-n%s-d%s,%s-L1:%s,%s-%s' % ( repr(state_dim), repr(transition_hidden), repr(output_hidden), repr(weight_noise_std), - repr(output_h_dropout), + repr(output_h_dropout), repr(drop_update), repr(l1_state), repr(l1_reset), step_rule ) @@ -105,13 +110,15 @@ class DGSRNN(BaseRecurrent, Initializable): return self.state_dim return super(GFGRU, self).get_dim(name) - @recurrent(sequences=['inputs'], states=['state'], + @recurrent(sequences=['inputs', 'drop_updates_mask'], states=['state'], outputs=['state', 'reset'], contexts=[]) - def apply(self, inputs=None, state=None): + def apply(self, inputs=None, drop_updates_mask=None, state=None): inter_v = self.inter.apply(tensor.concatenate([inputs, state], axis=1)) reset_v = self.reset.apply(inter_v) update_v = self.update.apply(inter_v) + reset_v = reset_v * drop_updates_mask + new_state = state * (1 - reset_v) + reset_v * update_v return new_state, reset_v @@ -141,7 +148,11 @@ class Model(): prev_state = theano.shared(numpy.zeros((num_seqs, state_dim)).astype(theano.config.floatX), name='state') - states, resets = dgsrnn.apply(in_onehot.dimshuffle(1, 0, 2), state=prev_state) + states, resets = dgsrnn.apply(inputs=in_onehot.dimshuffle(1, 0, 2), + drop_updates_mask=rng.binomial(size=(inp.shape[1], inp.shape[0], state_dim), + p=1-drop_update, + dtype=theano.config.floatX), + state=prev_state) states = states.dimshuffle(1, 0, 2) resets = resets.dimshuffle(1, 0, 2) diff --git a/paramsaveload.py b/paramsaveload.py index e44889d..9c05926 100644 --- a/paramsaveload.py +++ b/paramsaveload.py @@ -19,13 +19,13 @@ class SaveLoadParams(SimpleExtension): def do_save(self): with open(self.path, 'w') as f: logger.info('Saving parameters to %s...'%self.path) - cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) + cPickle.dump(self.model.get_parameter_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) def do_load(self): try: with open(self.path, 'r') as f: logger.info('Loading parameters from %s...'%self.path) - self.model.set_param_values(cPickle.load(f)) + self.model.set_parameter_values(cPickle.load(f)) except IOError: pass @@ -14,13 +14,18 @@ from theano.tensor.shared_randomstreams import RandomStreams from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER from blocks.extensions import Printing, SimpleExtension from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring -from blocks.extras.extensions.plot import Plot from blocks.extensions.saveload import Checkpoint, Load from blocks.graph import ComputationGraph from blocks.main_loop import MainLoop from blocks.model import Model from blocks.algorithms import GradientDescent, StepRule, CompositeRule +try: + from blocks.extras.extensions.plot import Plot + plot_avail = True +except ImportError: + plot_avail = False + import datastream from paramsaveload import SaveLoadParams from gentext import GenText @@ -63,18 +68,34 @@ class ResetStates(SimpleExtension): def train_model(m, train_stream, dump_path=None): # Define the model - model = Model(m.cost) + model = Model(m.sgd_cost) - cg = ComputationGraph(m.cost_reg) - algorithm = GradientDescent(cost=m.cost_reg, + cg = ComputationGraph(m.sgd_cost) + algorithm = GradientDescent(cost=m.sgd_cost, step_rule=CompositeRule([ ElementwiseRemoveNotFinite(), config.step_rule]), - params=cg.parameters) + parameters=cg.parameters) algorithm.add_updates(m.states) - extensions = [] + monitor_vars = [v for p in m.monitor_vars for v in p] + extensions = [ + TrainingDataMonitoring( + monitor_vars, + prefix='train', every_n_epochs=1), + Printing(every_n_epochs=1, after_epoch=False), + + ResetStates([v for v, _ in m.states], after_epoch=True) + ] + if plot_avail: + plot_channels = [['train_' + v.name for v in p] for p in m.monitor_vars] + extensions.append( + Plot(document='text_'+model_name+'_'+config.param_desc, + channels=plot_channels, + server_url='http://eos6:5006/', + every_n_epochs=1, after_epoch=False) + ) if config.save_freq is not None and dump_path is not None: extensions.append( SaveLoadParams(path=dump_path+'.pkl', @@ -105,19 +126,7 @@ def train_model(m, train_stream, dump_path=None): model=model, data_stream=train_stream, algorithm=algorithm, - extensions=extensions + [ - TrainingDataMonitoring( - [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate], - prefix='train', every_n_epochs=1), - Printing(every_n_epochs=1, after_epoch=False), - Plot(document='text_'+model_name+'_'+config.param_desc, - channels=[['train_cost', 'train_cost_reg'], - ['train_error_rate', 'train_error_rate_reg']], - server_url='http://eos21:4201/', - every_n_epochs=1, after_epoch=False), - - ResetStates([v for v, _ in m.states], after_epoch=True) - ] + extensions=extensions ) main_loop.run() @@ -131,10 +140,6 @@ if __name__ == "__main__": # Build model m = config.Model() - m.cost.name = 'cost' - m.cost_reg.name = 'cost_reg' - m.error_rate.name = 'error_rate' - m.error_rate_reg.name = 'error_rate_reg' m.pred.name = 'pred' # Train the model |