summaryrefslogtreecommitdiff
path: root/model/hpc_lstm.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2016-03-09 09:55:43 +0100
committerAlex Auvolat <alex@adnab.me>2016-03-09 09:55:43 +0100
commit62c05c06013e7204c1e7681a7e2ac7541f2acbcb (patch)
treed5f6825e4747ecf0eebba0e11d0bd3c0ec31b764 /model/hpc_lstm.py
parentbb9ebdeee88409a209d1bcc04a374e4b7d7e13d2 (diff)
downloadtext-rnn-62c05c06013e7204c1e7681a7e2ac7541f2acbcb.tar.gz
text-rnn-62c05c06013e7204c1e7681a7e2ac7541f2acbcb.zip
Very nice model
Diffstat (limited to 'model/hpc_lstm.py')
-rw-r--r--model/hpc_lstm.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/model/hpc_lstm.py b/model/hpc_lstm.py
index 5bad8af..395646c 100644
--- a/model/hpc_lstm.py
+++ b/model/hpc_lstm.py
@@ -17,9 +17,10 @@ class Model():
def __init__(self, config):
inp = tensor.imatrix('bytes')
- in_onehot = tensor.eq(tensor.arange(config.io_dim, dtype='int16').reshape((1, 1, config.io_dim)),
- inp[:, :, None])
- in_onehot.name = 'in_onehot'
+ embed = theano.shared(config.embedding_matrix.astype(theano.config.floatX),
+ name='embedding_matrix')
+ in_repr = embed[inp.flatten(), :].reshape((inp.shape[0], inp.shape[1], config.repr_dim))
+ in_repr.name = 'in_repr'
bricks = []
states = []
@@ -27,21 +28,20 @@ class Model():
# Construct predictive LSTM hierarchy
hidden = []
costs = []
- next_target = in_onehot.dimshuffle(1, 0, 2)
- for i, (hdim, cf, q, esf) in enumerate(zip(config.hidden_dims,
+ next_target = in_repr.dimshuffle(1, 0, 2)
+ for i, (hdim, cf, q) in enumerate(zip(config.hidden_dims,
config.cost_factors,
- config.hidden_q,
- config.error_scale_factor)):
+ config.hidden_q)):
init_state = theano.shared(numpy.zeros((config.num_seqs, hdim)).astype(theano.config.floatX),
name='st0_%d'%i)
init_cell = theano.shared(numpy.zeros((config.num_seqs, hdim)).astype(theano.config.floatX),
name='cell0_%d'%i)
- linear = Linear(input_dim=config.io_dim, output_dim=4*hdim,
+ linear = Linear(input_dim=config.repr_dim, output_dim=4*hdim,
name="lstm_in_%d"%i)
lstm = LSTM(dim=hdim, activation=config.activation_function,
name="lstm_rec_%d"%i)
- linear2 = Linear(input_dim=hdim, output_dim=config.io_dim, name='lstm_out_%d'%i)
+ linear2 = Linear(input_dim=hdim, output_dim=config.repr_dim, name='lstm_out_%d'%i)
tanh = Tanh('lstm_out_tanh_%d'%i)
bricks += [linear, lstm, linear2, tanh]
if i > 0: