diff options
author | Alex Auvolat <alex@adnab.me> | 2015-06-17 15:02:44 -0400 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2015-06-17 15:02:44 -0400 |
commit | 0ba1bd24fd2375fc4de5d355e434f747c03de202 (patch) | |
tree | 17715acb9f482a4068d2ff5b5df26ac0be8216e8 | |
parent | 12304944033d20bbc5c1b3f5cb90cf8dedebcdff (diff) | |
download | text-rnn-0ba1bd24fd2375fc4de5d355e434f747c03de202.tar.gz text-rnn-0ba1bd24fd2375fc4de5d355e434f747c03de202.zip |
SaveLoadParams
-rwxr-xr-x | train.py | 51 |
1 files changed, 13 insertions, 38 deletions
@@ -22,9 +22,9 @@ from blocks.model import Model from blocks.algorithms import GradientDescent import datastream -import paramsaveload -import gentext -import ircext +from paramsaveload import SaveLoadParams +from gentext import GenText +from ircext import IRCClientExt logging.basicConfig(level='INFO') logger = logging.getLogger(__name__) @@ -72,44 +72,19 @@ def train_model(m, train_stream, dump_path=None): ) if config.sample_freq is not None: extensions.append( - gentext.GenText(m, '\nalex\ttu crois ?\n', - config.sample_len, config.sample_temperature, - every_n_epochs=config.sample_freq, - after_epoch=False, before_training=True) + GenText(m, '\nalex\ttu crois ?\n', + config.sample_len, config.sample_temperature, + every_n_epochs=config.sample_freq, + after_epoch=False, before_training=True) ) if config.on_irc: extensions.append( - ircext.IRCClientExt(m, config.sample_temperature, - server='irc.ulminfo.fr', - port=6667, - nick='frigo', - channels=['#frigotest', '#courssysteme'], - after_batch=True) - ) - - extensions = [] - if config.save_freq is not None: - extensions.append( - Checkpoint(path=dump_path, - after_epoch=False, - use_cpickle=True, - every_n_epochs=config.save_freq), - ) - if config.sample_freq is not None: - extensions.append( - gentext.GenText(m, '\nalex\ttu crois ?\n', - config.sample_len, config.sample_temperature, - every_n_epochs=config.sample_freq, - after_epoch=False, before_training=True), - ) - if config.on_irc: - extensions.append( - ircext.IRCClientExt(m, config.sample_temperature, - server='irc.ulminfo.fr', - port=6667, - nick='frigo', - channels=['#frigotest', '#courssysteme'], - after_batch=True), + IRCClientExt(m, config.sample_temperature, + server='irc.ulminfo.fr', + port=6667, + nick='frigo', + channels=['#frigotest', '#courssysteme'], + after_batch=True) ) main_loop = MainLoop( |