diff options
-rw-r--r-- | gfgru.py | 25 | ||||
-rw-r--r-- | paramsaveload.py | 4 | ||||
-rwxr-xr-x | train.py | 15 |
3 files changed, 22 insertions, 22 deletions
@@ -15,24 +15,23 @@ from blocks.graph import ComputationGraph, apply_noise, apply_dropout # An epoch will be composed of 'num_seqs' sequences of len 'seq_len' # divided in chunks of lengh 'seq_div_size' -num_seqs = 2 -seq_len = 2 -seq_div_size = 2 +num_seqs = 10 +seq_len = 2000 +seq_div_size = 200 io_dim = 256 recurrent_blocks = [ # (256, Tanh(), [2048], [Rectifier()]), - (384, Tanh(), [], []), - (384, Tanh(), [], []), - (384, Tanh(), [1024], [Rectifier()]), -# (384, Tanh(), [1024], [Rectifier()]), + (512, Tanh(), [], []), + (512, Tanh(), [1024], [Rectifier()]), + (512, Tanh(), [], []), # (2, Tanh(), [2], [Rectifier()]), # (2, Tanh(), [], []), ] -control_hidden = [1024] -control_hidden_activations = [Rectifier()] +control_hidden = [512] +control_hidden_activations = [Tanh()] output_hidden = [1024] output_hidden_activations = [Rectifier()] @@ -55,13 +54,13 @@ param_desc = '%s,c%s,o%s-n%s-d%s,%s-%dx%d(%d)-%s' % ( step_rule ) -save_freq = 1 +save_freq = 5 on_irc = False # parameters for sample generation sample_len = 100 sample_temperature = 0.7 #0.5 -sample_freq = None +sample_freq = 1 if step_rule == 'rmsprop': step_rule = RMSProp() @@ -95,7 +94,7 @@ class GFGRU(BaseRecurrent, Initializable): self.hidden_total_dim = sum(x for (x, _, _, _) in self.recurrent_blocks) # control block - self.cblocklen = len(self.recurrent_blocks) + 2 + self.cblocklen = len(self.recurrent_blocks) + 3 control_idim = self.hidden_total_dim + self.input_dim control_odim = len(self.recurrent_blocks) * self.cblocklen @@ -170,7 +169,7 @@ class GFGRU(BaseRecurrent, Initializable): zgate_v = zgate.apply(inter_v) nstate_v = nstate.apply(inter_v) - zctl = zgate_v * controls[:, -2][:, None] + zctl = zgate_v * controls[:, -2][:, None] + controls[:, -3][:, None] nstate_v = zctl * nstate_v + (1 - zctl) * states[i] new_states.append(nstate_v) diff --git a/paramsaveload.py b/paramsaveload.py index 7181e9a..e44889d 100644 --- a/paramsaveload.py +++ b/paramsaveload.py @@ -19,13 +19,13 @@ class SaveLoadParams(SimpleExtension): def do_save(self): with open(self.path, 'w') as f: logger.info('Saving parameters to %s...'%self.path) - cPickle.dump(model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) + cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) def do_load(self): try: with open(self.path, 'r') as f: logger.info('Loading parameters from %s...'%self.path) - model.set_parma_values(cPickle.load(f)) + self.model.set_param_values(cPickle.load(f)) except IOError: pass @@ -14,7 +14,7 @@ from theano.tensor.shared_randomstreams import RandomStreams from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER from blocks.extensions import Printing, SimpleExtension from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring -# from blocks.extras.extensions.plot import Plot +from blocks.extras.extensions.plot import Plot from blocks.extensions.saveload import Checkpoint, Load from blocks.graph import ComputationGraph from blocks.main_loop import MainLoop @@ -64,9 +64,10 @@ def train_model(m, train_stream, dump_path=None): extensions = [] if config.save_freq is not None and dump_path is not None: extensions.append( - SaveLoadParams(path=dump_path, + SaveLoadParams(path=dump_path+'.pkl', model=model, before_training=True, + after_training=True, after_epoch=False, every_n_epochs=config.save_freq) ) @@ -96,11 +97,11 @@ def train_model(m, train_stream, dump_path=None): [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate], prefix='train', every_n_epochs=1), Printing(every_n_epochs=1, after_epoch=False), - # Plot(document='text_'+model_name+'_'+config.param_desc, - # channels=[['train_cost', 'train_cost_reg'], - # ['train_error_rate', 'train_error_rate_reg']], - # server_url='http://eos21:4201/', - # every_n_epochs=1, after_epoch=False), + Plot(document='text_'+model_name+'_'+config.param_desc, + channels=[['train_cost', 'train_cost_reg'], + ['train_error_rate', 'train_error_rate_reg']], + server_url='http://eos21:4201/', + every_n_epochs=1, after_epoch=False), ResetStates([v for v, _ in m.states], after_epoch=True) ] |