diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-17 17:25:19 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-17 17:25:19 -0400 |
commit | 8b9f95399e7b23aed493c7a67a9b56c5193ad53a (patch) | |
tree | f58a1cc0f97ff1785192972549ef6a8129fcc01a /gfgru.py | |
parent | 0ba1bd24fd2375fc4de5d355e434f747c03de202 (diff) | |
download | text-rnn-8b9f95399e7b23aed493c7a67a9b56c5193ad53a.tar.gz text-rnn-8b9f95399e7b23aed493c7a67a9b56c5193ad53a.zip |
xoxo
Diffstat (limited to 'gfgru.py')
-rw-r--r-- | gfgru.py | 25 |
1 files changed, 12 insertions, 13 deletions
@@ -15,24 +15,23 @@ from blocks.graph import ComputationGraph, apply_noise, apply_dropout # An epoch will be composed of 'num_seqs' sequences of len 'seq_len' # divided in chunks of lengh 'seq_div_size' -num_seqs = 2 -seq_len = 2 -seq_div_size = 2 +num_seqs = 10 +seq_len = 2000 +seq_div_size = 200 io_dim = 256 recurrent_blocks = [ # (256, Tanh(), [2048], [Rectifier()]), - (384, Tanh(), [], []), - (384, Tanh(), [], []), - (384, Tanh(), [1024], [Rectifier()]), -# (384, Tanh(), [1024], [Rectifier()]), + (512, Tanh(), [], []), + (512, Tanh(), [1024], [Rectifier()]), + (512, Tanh(), [], []), # (2, Tanh(), [2], [Rectifier()]), # (2, Tanh(), [], []), ] -control_hidden = [1024] -control_hidden_activations = [Rectifier()] +control_hidden = [512] +control_hidden_activations = [Tanh()] output_hidden = [1024] output_hidden_activations = [Rectifier()] @@ -55,13 +54,13 @@ param_desc = '%s,c%s,o%s-n%s-d%s,%s-%dx%d(%d)-%s' % ( step_rule ) -save_freq = 1 +save_freq = 5 on_irc = False # parameters for sample generation sample_len = 100 sample_temperature = 0.7 #0.5 -sample_freq = None +sample_freq = 1 if step_rule == 'rmsprop': step_rule = RMSProp() @@ -95,7 +94,7 @@ class GFGRU(BaseRecurrent, Initializable): self.hidden_total_dim = sum(x for (x, _, _, _) in self.recurrent_blocks) # control block - self.cblocklen = len(self.recurrent_blocks) + 2 + self.cblocklen = len(self.recurrent_blocks) + 3 control_idim = self.hidden_total_dim + self.input_dim control_odim = len(self.recurrent_blocks) * self.cblocklen @@ -170,7 +169,7 @@ class GFGRU(BaseRecurrent, Initializable): zgate_v = zgate.apply(inter_v) nstate_v = nstate.apply(inter_v) - zctl = zgate_v * controls[:, -2][:, None] + zctl = zgate_v * controls[:, -2][:, None] + controls[:, -3][:, None] nstate_v = zctl * nstate_v + (1 - zctl) * states[i] new_states.append(nstate_v) |