-rw-r--r--  gfgru.py          | 25
-rw-r--r--  paramsaveload.py  |  4
-rwxr-xr-x  train.py          | 15
3 files changed, 22 insertions(+), 22 deletions(-)
diff --git a/gfgru.py b/gfgru.py
index c4b4b48..8f05d46 100644
--- a/gfgru.py
+++ b/gfgru.py
@@ -15,24 +15,23 @@ from blocks.graph import ComputationGraph, apply_noise, apply_dropout
# An epoch will be composed of 'num_seqs' sequences of length 'seq_len'
# divided into chunks of length 'seq_div_size'
-num_seqs = 2
-seq_len = 2
-seq_div_size = 2
+num_seqs = 10
+seq_len = 2000
+seq_div_size = 200
io_dim = 256
recurrent_blocks = [
# (256, Tanh(), [2048], [Rectifier()]),
- (384, Tanh(), [], []),
- (384, Tanh(), [], []),
- (384, Tanh(), [1024], [Rectifier()]),
-# (384, Tanh(), [1024], [Rectifier()]),
+ (512, Tanh(), [], []),
+ (512, Tanh(), [1024], [Rectifier()]),
+ (512, Tanh(), [], []),
# (2, Tanh(), [2], [Rectifier()]),
# (2, Tanh(), [], []),
]
-control_hidden = [1024]
-control_hidden_activations = [Rectifier()]
+control_hidden = [512]
+control_hidden_activations = [Tanh()]
output_hidden = [1024]
output_hidden_activations = [Rectifier()]
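For a quick sanity check of the new chunking numbers in this hunk, a small standalone Python sketch; this is my reading of the comment at the top of gfgru.py, not the actual data-stream code in train.py.

# Sanity check of the new chunking parameters (illustrative, not code from the repo).
num_seqs, seq_len, seq_div_size = 10, 2000, 200

chunks_per_seq = seq_len // seq_div_size   # 2000 / 200 = 10 chunks per sequence
chars_per_epoch = num_seqs * seq_len       # 10 * 2000 = 20000 characters per epoch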
@@ -55,13 +54,13 @@ param_desc = '%s,c%s,o%s-n%s-d%s,%s-%dx%d(%d)-%s' % (
step_rule
)
-save_freq = 1
+save_freq = 5
on_irc = False
# parameters for sample generation
sample_len = 100
sample_temperature = 0.7 #0.5
-sample_freq = None
+sample_freq = 1
if step_rule == 'rmsprop':
step_rule = RMSProp()
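Since sample generation is now run every epoch (sample_freq = 1), here is a hedged numpy sketch of what sample_temperature does; this assumes the standard temperature-scaled softmax, as the sampler itself is not shown in this diff.

import numpy as np

def sample_char(logits, temperature=0.7):
    # Temperature-scaled softmax: T < 1 sharpens the distribution, T > 1 flattens it.
    z = (logits - logits.max()) / temperature   # shift by max for numerical stability
    p = np.exp(z)
    p /= p.sum()
    return np.random.choice(len(p), p=p)

# e.g. sample_char(np.random.randn(256))  # io_dim = 256 output classes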
@@ -95,7 +94,7 @@ class GFGRU(BaseRecurrent, Initializable):
self.hidden_total_dim = sum(x for (x, _, _, _) in self.recurrent_blocks)
# control block
- self.cblocklen = len(self.recurrent_blocks) + 2
+ self.cblocklen = len(self.recurrent_blocks) + 3
control_idim = self.hidden_total_dim + self.input_dim
control_odim = len(self.recurrent_blocks) * self.cblocklen
@@ -170,7 +169,7 @@ class GFGRU(BaseRecurrent, Initializable):
zgate_v = zgate.apply(inter_v)
nstate_v = nstate.apply(inter_v)
- zctl = zgate_v * controls[:, -2][:, None]
+ zctl = zgate_v * controls[:, -2][:, None] + controls[:, -3][:, None]
nstate_v = zctl * nstate_v + (1 - zctl) * states[i]
new_states.append(nstate_v)
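To make the gating change above easier to read, a minimal numpy sketch of the new arithmetic (shapes and names are illustrative assumptions, not the Blocks/Theano code itself): the control network now emits one extra output per recurrent block, used as an additive offset on the update gate, which is why cblocklen grows from len(self.recurrent_blocks) + 2 to + 3.

import numpy as np

batch, dim = 4, 512                      # illustrative sizes (512 matches the new blocks)
zgate_v  = np.random.rand(batch, dim)    # update gate proposed by this block
nstate_v = np.random.rand(batch, dim)    # candidate new state for this block
state_i  = np.random.rand(batch, dim)    # previous hidden state of block i
controls = np.random.rand(batch, 3)      # last three control outputs for this block

# controls[:, -2] scales the gate (as before); controls[:, -3] is the new additive term.
zctl = zgate_v * controls[:, -2][:, None] + controls[:, -3][:, None]
new_state = zctl * nstate_v + (1 - zctl) * state_i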
diff --git a/paramsaveload.py b/paramsaveload.py
index 7181e9a..e44889d 100644
--- a/paramsaveload.py
+++ b/paramsaveload.py
@@ -19,13 +19,13 @@ class SaveLoadParams(SimpleExtension):
def do_save(self):
with open(self.path, 'w') as f:
logger.info('Saving parameters to %s...'%self.path)
- cPickle.dump(model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+ cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
def do_load(self):
try:
with open(self.path, 'r') as f:
logger.info('Loading parameters from %s...'%self.path)
- model.set_parma_values(cPickle.load(f))
+ self.model.set_param_values(cPickle.load(f))
except IOError:
pass
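For context, a sketch of the corrected extension as a whole; the constructor and the do() dispatch are assumptions inferred from the attributes used here (self.path, self.model) and from how train.py instantiates it, not code shown in this diff.

import logging
import cPickle

from blocks.extensions import SimpleExtension

logger = logging.getLogger(__name__)


class SaveLoadParams(SimpleExtension):
    def __init__(self, path, model, **kwargs):
        super(SaveLoadParams, self).__init__(**kwargs)
        self.path = path
        self.model = model

    def do_save(self):
        with open(self.path, 'w') as f:
            logger.info('Saving parameters to %s...' % self.path)
            cPickle.dump(self.model.get_param_values(), f,
                         protocol=cPickle.HIGHEST_PROTOCOL)

    def do_load(self):
        try:
            with open(self.path, 'r') as f:
                logger.info('Loading parameters from %s...' % self.path)
                self.model.set_param_values(cPickle.load(f))
        except IOError:
            pass

    def do(self, which_callback, *args):
        # Assumed dispatch: load once before training, save on the other triggers.
        if which_callback == 'before_training':
            self.do_load()
        else:
            self.do_save()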
diff --git a/train.py b/train.py
index a188541..3ac24e7 100755
--- a/train.py
+++ b/train.py
@@ -14,7 +14,7 @@ from theano.tensor.shared_randomstreams import RandomStreams
from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER
from blocks.extensions import Printing, SimpleExtension
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-# from blocks.extras.extensions.plot import Plot
+from blocks.extras.extensions.plot import Plot
from blocks.extensions.saveload import Checkpoint, Load
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
@@ -64,9 +64,10 @@ def train_model(m, train_stream, dump_path=None):
extensions = []
if config.save_freq is not None and dump_path is not None:
extensions.append(
- SaveLoadParams(path=dump_path,
+ SaveLoadParams(path=dump_path+'.pkl',
model=model,
before_training=True,
+ after_training=True,
after_epoch=False,
every_n_epochs=config.save_freq)
)
@@ -96,11 +97,11 @@ def train_model(m, train_stream, dump_path=None):
[m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
prefix='train', every_n_epochs=1),
Printing(every_n_epochs=1, after_epoch=False),
- # Plot(document='text_'+model_name+'_'+config.param_desc,
- # channels=[['train_cost', 'train_cost_reg'],
- # ['train_error_rate', 'train_error_rate_reg']],
- # server_url='http://eos21:4201/',
- # every_n_epochs=1, after_epoch=False),
+ Plot(document='text_'+model_name+'_'+config.param_desc,
+ channels=[['train_cost', 'train_cost_reg'],
+ ['train_error_rate', 'train_error_rate_reg']],
+ server_url='http://eos21:4201/',
+ every_n_epochs=1, after_epoch=False),
ResetStates([v for v, _ in m.states], after_epoch=True)
]
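Finally, a hedged sketch of how these extensions come together in train_model(); the MainLoop wiring and the algorithm argument are assumptions based on the imports at the top of train.py, not lines shown in this diff. The re-enabled Plot extension additionally assumes a Bokeh plotting server reachable at server_url.

from blocks.main_loop import MainLoop

def run_main_loop(model, algorithm, train_stream, extensions):
    # `algorithm` would be e.g. a GradientDescent instance built elsewhere
    # in train_model(); it is not part of the lines shown in this diff.
    main_loop = MainLoop(model=model,
                         data_stream=train_stream,
                         algorithm=algorithm,
                         extensions=extensions)
    main_loop.run()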