summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2015-06-17 14:58:38 -0400
committerAlex Auvolat <alex@adnab.me>2015-06-17 14:58:52 -0400
commit12304944033d20bbc5c1b3f5cb90cf8dedebcdff (patch)
tree1097585f948d7040416eef58344b4bd194f10b9f
parent701c407b8c87a9270b31d34ac54e683341be661e (diff)
downloadtext-rnn-12304944033d20bbc5c1b3f5cb90cf8dedebcdff.tar.gz
text-rnn-12304944033d20bbc5c1b3f5cb90cf8dedebcdff.zip
paramsaveload
-rw-r--r--paramsaveload.py37
-rwxr-xr-xtrain.py49
2 files changed, 69 insertions, 17 deletions
diff --git a/paramsaveload.py b/paramsaveload.py
new file mode 100644
index 0000000..7181e9a
--- /dev/null
+++ b/paramsaveload.py
@@ -0,0 +1,37 @@
+import logging
+
+import numpy
+
+import cPickle
+
+from blocks.extensions import SimpleExtension
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger('extensions.SaveLoadParams')
+
+class SaveLoadParams(SimpleExtension):
+ def __init__(self, path, model, **kwargs):
+ super(SaveLoadParams, self).__init__(**kwargs)
+
+ self.path = path
+ self.model = model
+
+ def do_save(self):
+ with open(self.path, 'w') as f:
+ logger.info('Saving parameters to %s...'%self.path)
+ cPickle.dump(model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+
+ def do_load(self):
+ try:
+ with open(self.path, 'r') as f:
+ logger.info('Loading parameters from %s...'%self.path)
+ model.set_parma_values(cPickle.load(f))
+ except IOError:
+ pass
+
+ def do(self, which_callback, *args):
+ if which_callback == 'before_training':
+ self.do_load()
+ else:
+ self.do_save()
+
diff --git a/train.py b/train.py
index 79b2116..525724f 100755
--- a/train.py
+++ b/train.py
@@ -14,7 +14,7 @@ from theano.tensor.shared_randomstreams import RandomStreams
from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER
from blocks.extensions import Printing, SimpleExtension
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-from blocks.extras.extensions.plot import Plot
+# from blocks.extras.extensions.plot import Plot
from blocks.extensions.saveload import Checkpoint, Load
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
@@ -22,6 +22,7 @@ from blocks.model import Model
from blocks.algorithms import GradientDescent
import datastream
+import paramsaveload
import gentext
import ircext
@@ -60,17 +61,31 @@ def train_model(m, train_stream, dump_path=None):
algorithm.add_updates(m.states)
- # Load the parameters from a dumped model
- if dump_path is not None:
- try:
- with closing(numpy.load(dump_path)) as source:
- logger.info('Loading parameters...')
- param_values = {'/' + name.replace(BRICK_DELIMITER, '/'): source[name]
- for name in source.keys()
- if name != 'pkl' and not 'None' in name}
- model.set_param_values(param_values)
- except IOError:
- pass
+ extensions = []
+ if config.save_freq is not None and dump_path is not None:
+ extensions.append(
+ SaveLoadParams(path=dump_path,
+ model=model,
+ before_training=True,
+ after_epoch=False,
+ every_n_epochs=config.save_freq)
+ )
+ if config.sample_freq is not None:
+ extensions.append(
+ gentext.GenText(m, '\nalex\ttu crois ?\n',
+ config.sample_len, config.sample_temperature,
+ every_n_epochs=config.sample_freq,
+ after_epoch=False, before_training=True)
+ )
+ if config.on_irc:
+ extensions.append(
+ ircext.IRCClientExt(m, config.sample_temperature,
+ server='irc.ulminfo.fr',
+ port=6667,
+ nick='frigo',
+ channels=['#frigotest', '#courssysteme'],
+ after_batch=True)
+ )
extensions = []
if config.save_freq is not None:
@@ -106,11 +121,11 @@ def train_model(m, train_stream, dump_path=None):
[m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
prefix='train', every_n_epochs=1),
Printing(every_n_epochs=1, after_epoch=False),
- Plot(document='text_'+model_name+'_'+config.param_desc,
- channels=[['train_cost', 'train_cost_reg'],
- ['train_error_rate', 'train_error_rate_reg']],
- server_url='http://eos21:4201/',
- every_n_epochs=1, after_epoch=False),
+ # Plot(document='text_'+model_name+'_'+config.param_desc,
+ # channels=[['train_cost', 'train_cost_reg'],
+ # ['train_error_rate', 'train_error_rate_reg']],
+ # server_url='http://eos21:4201/',
+ # every_n_epochs=1, after_epoch=False),
ResetStates([v for v, _ in m.states], after_epoch=True)
]