diff options
-rwxr-xr-x | train.py | 51 |
1 files changed, 13 insertions, 38 deletions
@@ -22,9 +22,9 @@ from blocks.model import Model from blocks.algorithms import GradientDescent import datastream -import paramsaveload -import gentext -import ircext +from paramsaveload import SaveLoadParams +from gentext import GenText +from ircext import IRCClientExt logging.basicConfig(level='INFO') logger = logging.getLogger(__name__) @@ -72,44 +72,19 @@ def train_model(m, train_stream, dump_path=None): ) if config.sample_freq is not None: extensions.append( - gentext.GenText(m, '\nalex\ttu crois ?\n', - config.sample_len, config.sample_temperature, - every_n_epochs=config.sample_freq, - after_epoch=False, before_training=True) + GenText(m, '\nalex\ttu crois ?\n', + config.sample_len, config.sample_temperature, + every_n_epochs=config.sample_freq, + after_epoch=False, before_training=True) ) if config.on_irc: extensions.append( - ircext.IRCClientExt(m, config.sample_temperature, - server='irc.ulminfo.fr', - port=6667, - nick='frigo', - channels=['#frigotest', '#courssysteme'], - after_batch=True) - ) - - extensions = [] - if config.save_freq is not None: - extensions.append( - Checkpoint(path=dump_path, - after_epoch=False, - use_cpickle=True, - every_n_epochs=config.save_freq), - ) - if config.sample_freq is not None: - extensions.append( - gentext.GenText(m, '\nalex\ttu crois ?\n', - config.sample_len, config.sample_temperature, - every_n_epochs=config.sample_freq, - after_epoch=False, before_training=True), - ) - if config.on_irc: - extensions.append( - ircext.IRCClientExt(m, config.sample_temperature, - server='irc.ulminfo.fr', - port=6667, - nick='frigo', - channels=['#frigotest', '#courssysteme'], - after_batch=True), + IRCClientExt(m, config.sample_temperature, + server='irc.ulminfo.fr', + port=6667, + nick='frigo', + channels=['#frigotest', '#courssysteme'], + after_batch=True) ) main_loop = MainLoop( |