summaryrefslogtreecommitdiff
path: root/gfgru.py
diff options
context:
space:
mode:
Diffstat (limited to 'gfgru.py')
-rw-r--r--gfgru.py25
1 files changed, 12 insertions, 13 deletions
diff --git a/gfgru.py b/gfgru.py
index c4b4b48..8f05d46 100644
--- a/gfgru.py
+++ b/gfgru.py
@@ -15,24 +15,23 @@ from blocks.graph import ComputationGraph, apply_noise, apply_dropout
# An epoch will be composed of 'num_seqs' sequences of len 'seq_len'
# divided in chunks of lengh 'seq_div_size'
-num_seqs = 2
-seq_len = 2
-seq_div_size = 2
+num_seqs = 10
+seq_len = 2000
+seq_div_size = 200
io_dim = 256
recurrent_blocks = [
# (256, Tanh(), [2048], [Rectifier()]),
- (384, Tanh(), [], []),
- (384, Tanh(), [], []),
- (384, Tanh(), [1024], [Rectifier()]),
-# (384, Tanh(), [1024], [Rectifier()]),
+ (512, Tanh(), [], []),
+ (512, Tanh(), [1024], [Rectifier()]),
+ (512, Tanh(), [], []),
# (2, Tanh(), [2], [Rectifier()]),
# (2, Tanh(), [], []),
]
-control_hidden = [1024]
-control_hidden_activations = [Rectifier()]
+control_hidden = [512]
+control_hidden_activations = [Tanh()]
output_hidden = [1024]
output_hidden_activations = [Rectifier()]
@@ -55,13 +54,13 @@ param_desc = '%s,c%s,o%s-n%s-d%s,%s-%dx%d(%d)-%s' % (
step_rule
)
-save_freq = 1
+save_freq = 5
on_irc = False
# parameters for sample generation
sample_len = 100
sample_temperature = 0.7 #0.5
-sample_freq = None
+sample_freq = 1
if step_rule == 'rmsprop':
step_rule = RMSProp()
@@ -95,7 +94,7 @@ class GFGRU(BaseRecurrent, Initializable):
self.hidden_total_dim = sum(x for (x, _, _, _) in self.recurrent_blocks)
# control block
- self.cblocklen = len(self.recurrent_blocks) + 2
+ self.cblocklen = len(self.recurrent_blocks) + 3
control_idim = self.hidden_total_dim + self.input_dim
control_odim = len(self.recurrent_blocks) * self.cblocklen
@@ -170,7 +169,7 @@ class GFGRU(BaseRecurrent, Initializable):
zgate_v = zgate.apply(inter_v)
nstate_v = nstate.apply(inter_v)
- zctl = zgate_v * controls[:, -2][:, None]
+ zctl = zgate_v * controls[:, -2][:, None] + controls[:, -3][:, None]
nstate_v = zctl * nstate_v + (1 - zctl) * states[i]
new_states.append(nstate_v)