diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-17 17:25:19 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-17 17:25:19 -0400 |
commit | 8b9f95399e7b23aed493c7a67a9b56c5193ad53a (patch) | |
tree | f58a1cc0f97ff1785192972549ef6a8129fcc01a /train.py | |
parent | 0ba1bd24fd2375fc4de5d355e434f747c03de202 (diff) | |
download | text-rnn-8b9f95399e7b23aed493c7a67a9b56c5193ad53a.tar.gz text-rnn-8b9f95399e7b23aed493c7a67a9b56c5193ad53a.zip |
xoxo
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 15 |
1 files changed, 8 insertions, 7 deletions
@@ -14,7 +14,7 @@ from theano.tensor.shared_randomstreams import RandomStreams from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER from blocks.extensions import Printing, SimpleExtension from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring -# from blocks.extras.extensions.plot import Plot +from blocks.extras.extensions.plot import Plot from blocks.extensions.saveload import Checkpoint, Load from blocks.graph import ComputationGraph from blocks.main_loop import MainLoop @@ -64,9 +64,10 @@ def train_model(m, train_stream, dump_path=None): extensions = [] if config.save_freq is not None and dump_path is not None: extensions.append( - SaveLoadParams(path=dump_path, + SaveLoadParams(path=dump_path+'.pkl', model=model, before_training=True, + after_training=True, after_epoch=False, every_n_epochs=config.save_freq) ) @@ -96,11 +97,11 @@ def train_model(m, train_stream, dump_path=None): [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate], prefix='train', every_n_epochs=1), Printing(every_n_epochs=1, after_epoch=False), - # Plot(document='text_'+model_name+'_'+config.param_desc, - # channels=[['train_cost', 'train_cost_reg'], - # ['train_error_rate', 'train_error_rate_reg']], - # server_url='http://eos21:4201/', - # every_n_epochs=1, after_epoch=False), + Plot(document='text_'+model_name+'_'+config.param_desc, + channels=[['train_cost', 'train_cost_reg'], + ['train_error_rate', 'train_error_rate_reg']], + server_url='http://eos21:4201/', + every_n_epochs=1, after_epoch=False), ResetStates([v for v, _ in m.states], after_epoch=True) ] |