summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-06-17 17:25:19 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-06-17 17:25:19 -0400
commit8b9f95399e7b23aed493c7a67a9b56c5193ad53a (patch)
treef58a1cc0f97ff1785192972549ef6a8129fcc01a /train.py
parent0ba1bd24fd2375fc4de5d355e434f747c03de202 (diff)
downloadtext-rnn-8b9f95399e7b23aed493c7a67a9b56c5193ad53a.tar.gz
text-rnn-8b9f95399e7b23aed493c7a67a9b56c5193ad53a.zip
xoxo
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/train.py b/train.py
index a188541..3ac24e7 100755
--- a/train.py
+++ b/train.py
@@ -14,7 +14,7 @@ from theano.tensor.shared_randomstreams import RandomStreams
from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER
from blocks.extensions import Printing, SimpleExtension
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-# from blocks.extras.extensions.plot import Plot
+from blocks.extras.extensions.plot import Plot
from blocks.extensions.saveload import Checkpoint, Load
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
@@ -64,9 +64,10 @@ def train_model(m, train_stream, dump_path=None):
extensions = []
if config.save_freq is not None and dump_path is not None:
extensions.append(
- SaveLoadParams(path=dump_path,
+ SaveLoadParams(path=dump_path+'.pkl',
model=model,
before_training=True,
+ after_training=True,
after_epoch=False,
every_n_epochs=config.save_freq)
)
@@ -96,11 +97,11 @@ def train_model(m, train_stream, dump_path=None):
[m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
prefix='train', every_n_epochs=1),
Printing(every_n_epochs=1, after_epoch=False),
- # Plot(document='text_'+model_name+'_'+config.param_desc,
- # channels=[['train_cost', 'train_cost_reg'],
- # ['train_error_rate', 'train_error_rate_reg']],
- # server_url='http://eos21:4201/',
- # every_n_epochs=1, after_epoch=False),
+ Plot(document='text_'+model_name+'_'+config.param_desc,
+ channels=[['train_cost', 'train_cost_reg'],
+ ['train_error_rate', 'train_error_rate_reg']],
+ server_url='http://eos21:4201/',
+ every_n_epochs=1, after_epoch=False),
ResetStates([v for v, _ in m.states], after_epoch=True)
]