Diffstat (limited to 'lstm.py')
-rw-r--r--  lstm.py | 75
1 file changed, 48 insertions, 27 deletions
@@ -13,36 +13,48 @@ from blocks.graph import ComputationGraph, apply_noise, apply_dropout
 
 # An epoch will be composed of 'num_seqs' sequences of len 'seq_len'
 # divided in chunks of length 'seq_div_size'
-num_seqs = 10
+num_seqs = 20
 seq_len = 2000
 seq_div_size = 100
 
 io_dim = 256
 
-hidden_dims = [512, 512]
+hidden_dims = [512, 512, 512]
 activation_function = Tanh()
 
-all_hidden_for_output = False
+i2h_all = True      # input to all hidden layers or only first layer
+h2o_all = True      # all hidden layers to output or only last layer
 
 w_noise_std = 0.01
 i_dropout = 0.5
 
-step_rule = 'adadelta'
+step_rule = 'momentum'
+learning_rate = 0.1
+momentum = 0.9
 
-param_desc = '%s-%sHO-n%s-d%s-%dx%d(%d)-%s' % (
+param_desc = '%s-%sIH,%sHO-n%s-d%s-%dx%d(%d)-%s' % (
                 repr(hidden_dims),
-                'all' if all_hidden_for_output else 'last',
+                'all' if i2h_all else 'first',
+                'all' if h2o_all else 'last',
                 repr(w_noise_std),
                 repr(i_dropout),
                 num_seqs, seq_len, seq_div_size,
                 step_rule
               )
 
+save_freq = 5
+
+# parameters for sample generation
+sample_len = 60
+sample_temperature = 0.3
+
 if step_rule == 'rmsprop':
     step_rule = RMSProp()
 elif step_rule == 'adadelta':
     step_rule = AdaDelta()
+elif step_rule == 'momentum':
+    step_rule = Momentum(learning_rate=learning_rate, momentum=momentum)
 else:
     assert(False)
@@ -52,7 +64,9 @@ class Model():
         in_onehot = tensor.eq(tensor.arange(io_dim, dtype='int16').reshape((1, 1, io_dim)),
                               inp[:, :, None])
+        in_onehot.name = 'in_onehot'
 
+        # Construct hidden states
         dims = [io_dim] + hidden_dims
         states = [in_onehot.dimshuffle(1, 0, 2)]
         bricks = []
@@ -65,38 +79,44 @@ class Model():
             linear = Linear(input_dim=dims[i-1], output_dim=4*dims[i],
                             name="lstm_in_%d"%i)
+            bricks.append(linear)
+            inter = linear.apply(states[-1])
+
+            if i2h_all and i > 1:
+                linear2 = Linear(input_dim=dims[0], output_dim=4*dims[i],
+                                 name="lstm_in0_%d"%i)
+                bricks.append(linear2)
+                inter = inter + linear2.apply(states[0])
+                inter.name = 'inter_bis_%d'%i
+
             lstm = LSTM(dim=dims[i], activation=activation_function,
                         name="lstm_rec_%d"%i)
+            bricks.append(lstm)
 
-            new_states, new_cells = lstm.apply(linear.apply(states[-1]),
+            new_states, new_cells = lstm.apply(inter,
                                                states=init_state,
                                                cells=init_cell)
             updates.append((init_state, new_states[-1, :, :]))
             updates.append((init_cell, new_cells[-1, :, :]))
 
             states.append(new_states)
 
-            bricks = bricks + [linear, lstm]
-
-        states = [s.dimshuffle(1, 0, 2).reshape((inp.shape[0] * inp.shape[1], dim))
-                  for dim, s in zip(dims, states)]
+        states = [s.dimshuffle(1, 0, 2) for s in states]
 
-        if all_hidden_for_output:
-            top_linear = MLP(dims=[sum(hidden_dims), io_dim],
-                             activations=[Softmax()],
-                             name="pred_mlp")
+        # Construct output from hidden states
+        out = None
+        layers = zip(dims, states)[1:]
+        if not h2o_all:
+            layers = [layers[-1]]
+        for i, (dim, state) in enumerate(layers):
+            top_linear = Linear(input_dim=dim, output_dim=io_dim,
+                                name='top_linear_%d'%i)
             bricks.append(top_linear)
+            out_i = top_linear.apply(state)
+            out = out_i if out is None else out + out_i
+            out.name = 'out_part_%d'%i
 
-            out = top_linear.apply(tensor.concatenate(states[1:], axis=1))
-        else:
-            top_linear = MLP(dims=[hidden_dims[-1], io_dim],
-                             activations=[None],
-                             name="pred_mlp")
-            bricks.append(top_linear)
-
-            out = top_linear.apply(states[-1])
-
-        out = out.reshape((inp.shape[0], inp.shape[1], io_dim))
-
+        # Do prediction and calculate cost
         pred = out.argmax(axis=2)
 
         cost = Softmax().categorical_cross_entropy(inp[:, 1:].flatten(),
@@ -104,13 +124,13 @@ class Model():
                                                                io_dim)))
         error_rate = tensor.neq(inp[:, 1:].flatten(), pred[:, :-1].flatten()).mean()
 
-        # Initialize
+        # Initialize all bricks
         for brick in bricks:
             brick.weights_init = IsotropicGaussian(0.1)
             brick.biases_init = Constant(0.)
             brick.initialize()
 
-        # apply noise
+        # Apply noise and dropout
        cg = ComputationGraph([cost, error_rate])
         if w_noise_std > 0:
             noise_vars = VariableFilter(roles=[WEIGHT])(cg)
@@ -123,6 +143,7 @@ class Model():
         self.error_rate = error_rate
         self.cost_reg = cost_reg
         self.error_rate_reg = error_rate_reg
+        self.out = out
         self.pred = pred
         self.updates = updates
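
The core architectural change is the pair of flags replacing all_hidden_for_output: with i2h_all, the one-hot input is projected into every LSTM layer instead of only the first, and with h2o_all, every hidden layer gets its own Linear readout whose outputs are summed into the logits. A minimal NumPy sketch of that wiring, outside Blocks/Theano, with a plain tanh standing in for the LSTM recurrence (all names below are illustrative, not from the commit):

    import numpy as np

    io_dim, hidden_dims = 256, [512, 512, 512]
    dims = [io_dim] + hidden_dims
    rng = np.random.default_rng(0)

    W_h = [rng.normal(0, 0.1, (dims[i - 1], dims[i])) for i in range(1, len(dims))]
    W_in0 = [rng.normal(0, 0.1, (io_dim, d)) for d in hidden_dims]  # i2h_all paths
    W_out = [rng.normal(0, 0.1, (d, io_dim)) for d in hidden_dims]  # h2o_all readouts

    x = np.eye(io_dim)[rng.integers(io_dim, size=4)]  # batch of 4 one-hot inputs
    state, logits = x, np.zeros((4, io_dim))
    for i in range(len(hidden_dims)):
        pre = state @ W_h[i]
        if i > 0:                    # i2h_all: the raw input reaches every layer
            pre = pre + x @ W_in0[i]
        state = np.tanh(pre)         # stand-in for the LSTM step
        logits = logits + state @ W_out[i]  # h2o_all: every layer feeds the output

Summing per-layer readouts, rather than concatenating states as the old all_hidden_for_output path did, keeps the output dimension fixed while still giving every layer a direct path to the cost.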
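
The step rule also moves from AdaDelta to Momentum(learning_rate=0.1, momentum=0.9). For reference, a plain-NumPy sketch of one heavy-ball step; this assumes the classic formulation, and the exact point at which Blocks applies the learning rate inside its Momentum rule may differ:

    import numpy as np

    def momentum_step(param, grad, velocity, learning_rate=0.1, momentum=0.9):
        # Accumulate a decaying sum of past gradients, then step along it.
        velocity = momentum * velocity + learning_rate * grad
        return param - velocity, velocity

    p, v = np.zeros(3), np.zeros(3)
    p, v = momentum_step(p, np.array([1.0, -2.0, 0.5]), v)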
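
Finally, sample_len and sample_temperature are introduced here, but the sampling loop that consumes them is not part of this diff. A common implementation divides the output logits by the temperature before the softmax, so a low value such as 0.3 sharpens the distribution toward the argmax. A hypothetical sketch (sample_char is not a name from the commit):

    import numpy as np

    def sample_char(logits, temperature=0.3, rng=None):
        rng = rng if rng is not None else np.random.default_rng()
        z = logits / temperature
        z = z - z.max()              # subtract the max for numerical stability
        p = np.exp(z)
        return rng.choice(len(p), p=p / p.sum())  # index of the sampled character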