summaryrefslogtreecommitdiff
path: root/model/lstm.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2016-03-08 13:26:28 +0100
committerAlex Auvolat <alex@adnab.me>2016-03-08 13:26:28 +0100
commit2f479926c16d2911d0dd878c21de082abfc5b237 (patch)
treeb399e9ad9af04a9449334dff1a47449808b7ca13 /model/lstm.py
parent23093608e0edc43477c3a2ed804ae1016790f7e4 (diff)
downloadtext-rnn-2f479926c16d2911d0dd878c21de082abfc5b237.tar.gz
text-rnn-2f479926c16d2911d0dd878c21de082abfc5b237.zip
Revive project
Diffstat (limited to 'model/lstm.py')
-rw-r--r--model/lstm.py119
1 files changed, 119 insertions, 0 deletions
diff --git a/model/lstm.py b/model/lstm.py
new file mode 100644
index 0000000..abd44e0
--- /dev/null
+++ b/model/lstm.py
@@ -0,0 +1,119 @@
+import theano
+from theano import tensor
+import numpy
+
+from blocks.bricks import Softmax, Linear
+from blocks.bricks.recurrent import LSTM
+from blocks.initialization import IsotropicGaussian, Constant
+
+from blocks.filter import VariableFilter
+from blocks.roles import WEIGHT
+from blocks.graph import ComputationGraph, apply_noise, apply_dropout
+
+
+class Model():
+ def __init__(self, config):
+ inp = tensor.imatrix('bytes')
+
+ in_onehot = tensor.eq(tensor.arange(config.io_dim, dtype='int16').reshape((1, 1, config.io_dim)),
+ inp[:, :, None])
+ in_onehot.name = 'in_onehot'
+
+ # Construct hidden states
+ dims = [config.io_dim] + config.hidden_dims
+ hidden = [in_onehot.dimshuffle(1, 0, 2)]
+ bricks = []
+ states = []
+ for i in xrange(1, len(dims)):
+ init_state = theano.shared(numpy.zeros((config.num_seqs, dims[i])).astype(theano.config.floatX),
+ name='st0_%d'%i)
+ init_cell = theano.shared(numpy.zeros((config.num_seqs, dims[i])).astype(theano.config.floatX),
+ name='cell0_%d'%i)
+
+ linear = Linear(input_dim=dims[i-1], output_dim=4*dims[i],
+ name="lstm_in_%d"%i)
+ bricks.append(linear)
+ inter = linear.apply(hidden[-1])
+
+ if config.i2h_all and i > 1:
+ linear2 = Linear(input_dim=dims[0], output_dim=4*dims[i],
+ name="lstm_in0_%d"%i)
+ bricks.append(linear2)
+ inter = inter + linear2.apply(hidden[0])
+ inter.name = 'inter_bis_%d'%i
+
+ lstm = LSTM(dim=dims[i], activation=config.activation_function,
+ name="lstm_rec_%d"%i)
+ bricks.append(lstm)
+
+ new_hidden, new_cells = lstm.apply(inter,
+ states=init_state,
+ cells=init_cell)
+ states.append((init_state, new_hidden[-1, :, :]))
+ states.append((init_cell, new_cells[-1, :, :]))
+
+ hidden.append(new_hidden)
+
+ hidden = [s.dimshuffle(1, 0, 2) for s in hidden]
+
+ # Construct output from hidden states
+ out = None
+ layers = zip(dims, hidden)[1:]
+ if not config.h2o_all:
+ layers = [layers[-1]]
+ for i, (dim, state) in enumerate(layers):
+ top_linear = Linear(input_dim=dim, output_dim=config.io_dim,
+ name='top_linear_%d'%i)
+ bricks.append(top_linear)
+ out_i = top_linear.apply(state)
+ out = out_i if out is None else out + out_i
+ out.name = 'out_part_%d'%i
+
+ # Do prediction and calculate cost
+ pred = out.argmax(axis=2)
+
+ cost = Softmax().categorical_cross_entropy(inp[:, 1:].flatten(),
+ out[:, :-1, :].reshape((inp.shape[0]*(inp.shape[1]-1),
+ config.io_dim))).mean()
+ error_rate = tensor.neq(inp[:, 1:].flatten(), pred[:, :-1].flatten()).mean()
+
+ # Initialize all bricks
+ for brick in bricks:
+ brick.weights_init = IsotropicGaussian(0.1)
+ brick.biases_init = Constant(0.)
+ brick.initialize()
+
+ # Apply noise and dropout
+ cg = ComputationGraph([cost, error_rate])
+ if config.w_noise_std > 0:
+ noise_vars = VariableFilter(roles=[WEIGHT])(cg)
+ cg = apply_noise(cg, noise_vars, config.w_noise_std)
+ if config.i_dropout > 0:
+ cg = apply_dropout(cg, hidden[1:], config.i_dropout)
+ [cost_reg, error_rate_reg] = cg.outputs
+
+ # add l1 regularization
+ if config.l1_reg > 0:
+ l1pen = sum(abs(st).mean() for st in hidden[1:])
+ cost_reg = cost_reg + config.l1_reg * l1pen
+
+ cost_reg += 1e-10 # so that it is not the same Theano variable
+ error_rate_reg += 1e-10
+
+ # put stuff into self that is usefull for training or extensions
+ self.sgd_cost = cost_reg
+
+ cost.name = 'cost'
+ cost_reg.name = 'cost_reg'
+ error_rate.name = 'error_rate'
+ error_rate_reg.name = 'error_rate_reg'
+ self.monitor_vars = [[cost, cost_reg],
+ [error_rate, error_rate_reg]]
+
+ self.out = out
+ self.pred = pred
+
+ self.states = states
+
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :