From 8b9f95399e7b23aed493c7a67a9b56c5193ad53a Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 17 Jun 2015 17:25:19 -0400 Subject: xoxo --- gfgru.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) (limited to 'gfgru.py') diff --git a/gfgru.py b/gfgru.py index c4b4b48..8f05d46 100644 --- a/gfgru.py +++ b/gfgru.py @@ -15,24 +15,23 @@ from blocks.graph import ComputationGraph, apply_noise, apply_dropout # An epoch will be composed of 'num_seqs' sequences of len 'seq_len' # divided in chunks of lengh 'seq_div_size' -num_seqs = 2 -seq_len = 2 -seq_div_size = 2 +num_seqs = 10 +seq_len = 2000 +seq_div_size = 200 io_dim = 256 recurrent_blocks = [ # (256, Tanh(), [2048], [Rectifier()]), - (384, Tanh(), [], []), - (384, Tanh(), [], []), - (384, Tanh(), [1024], [Rectifier()]), -# (384, Tanh(), [1024], [Rectifier()]), + (512, Tanh(), [], []), + (512, Tanh(), [1024], [Rectifier()]), + (512, Tanh(), [], []), # (2, Tanh(), [2], [Rectifier()]), # (2, Tanh(), [], []), ] -control_hidden = [1024] -control_hidden_activations = [Rectifier()] +control_hidden = [512] +control_hidden_activations = [Tanh()] output_hidden = [1024] output_hidden_activations = [Rectifier()] @@ -55,13 +54,13 @@ param_desc = '%s,c%s,o%s-n%s-d%s,%s-%dx%d(%d)-%s' % ( step_rule ) -save_freq = 1 +save_freq = 5 on_irc = False # parameters for sample generation sample_len = 100 sample_temperature = 0.7 #0.5 -sample_freq = None +sample_freq = 1 if step_rule == 'rmsprop': step_rule = RMSProp() @@ -95,7 +94,7 @@ class GFGRU(BaseRecurrent, Initializable): self.hidden_total_dim = sum(x for (x, _, _, _) in self.recurrent_blocks) # control block - self.cblocklen = len(self.recurrent_blocks) + 2 + self.cblocklen = len(self.recurrent_blocks) + 3 control_idim = self.hidden_total_dim + self.input_dim control_odim = len(self.recurrent_blocks) * self.cblocklen @@ -170,7 +169,7 @@ class GFGRU(BaseRecurrent, Initializable): zgate_v = zgate.apply(inter_v) nstate_v = nstate.apply(inter_v) - zctl = zgate_v * controls[:, -2][:, None] + zctl = zgate_v * controls[:, -2][:, None] + controls[:, -3][:, None] nstate_v = zctl * nstate_v + (1 - zctl) * states[i] new_states.append(nstate_v) -- cgit v1.2.3