diff options
author | Alex Auvolat <alex@adnab.me> | 2016-03-09 09:55:43 +0100 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2016-03-09 09:55:43 +0100 |
commit | 62c05c06013e7204c1e7681a7e2ac7541f2acbcb (patch) | |
tree | d5f6825e4747ecf0eebba0e11d0bd3c0ec31b764 /model | |
parent | bb9ebdeee88409a209d1bcc04a374e4b7d7e13d2 (diff) | |
download | text-rnn-62c05c06013e7204c1e7681a7e2ac7541f2acbcb.tar.gz text-rnn-62c05c06013e7204c1e7681a7e2ac7541f2acbcb.zip |
Very nice model
Diffstat (limited to 'model')
-rw-r--r-- | model/hpc_lstm.py | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/model/hpc_lstm.py b/model/hpc_lstm.py index 5bad8af..395646c 100644 --- a/model/hpc_lstm.py +++ b/model/hpc_lstm.py @@ -17,9 +17,10 @@ class Model(): def __init__(self, config): inp = tensor.imatrix('bytes') - in_onehot = tensor.eq(tensor.arange(config.io_dim, dtype='int16').reshape((1, 1, config.io_dim)), - inp[:, :, None]) - in_onehot.name = 'in_onehot' + embed = theano.shared(config.embedding_matrix.astype(theano.config.floatX), + name='embedding_matrix') + in_repr = embed[inp.flatten(), :].reshape((inp.shape[0], inp.shape[1], config.repr_dim)) + in_repr.name = 'in_repr' bricks = [] states = [] @@ -27,21 +28,20 @@ class Model(): # Construct predictive LSTM hierarchy hidden = [] costs = [] - next_target = in_onehot.dimshuffle(1, 0, 2) - for i, (hdim, cf, q, esf) in enumerate(zip(config.hidden_dims, + next_target = in_repr.dimshuffle(1, 0, 2) + for i, (hdim, cf, q) in enumerate(zip(config.hidden_dims, config.cost_factors, - config.hidden_q, - config.error_scale_factor)): + config.hidden_q)): init_state = theano.shared(numpy.zeros((config.num_seqs, hdim)).astype(theano.config.floatX), name='st0_%d'%i) init_cell = theano.shared(numpy.zeros((config.num_seqs, hdim)).astype(theano.config.floatX), name='cell0_%d'%i) - linear = Linear(input_dim=config.io_dim, output_dim=4*hdim, + linear = Linear(input_dim=config.repr_dim, output_dim=4*hdim, name="lstm_in_%d"%i) lstm = LSTM(dim=hdim, activation=config.activation_function, name="lstm_rec_%d"%i) - linear2 = Linear(input_dim=hdim, output_dim=config.io_dim, name='lstm_out_%d'%i) + linear2 = Linear(input_dim=hdim, output_dim=config.repr_dim, name='lstm_out_%d'%i) tanh = Tanh('lstm_out_tanh_%d'%i) bricks += [linear, lstm, linear2, tanh] if i > 0: |