From 0ba1bd24fd2375fc4de5d355e434f747c03de202 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 17 Jun 2015 15:02:44 -0400 Subject: SaveLoadParams --- train.py | 51 +++++++++++++-------------------------------------- 1 file changed, 13 insertions(+), 38 deletions(-) diff --git a/train.py b/train.py index 525724f..a188541 100755 --- a/train.py +++ b/train.py @@ -22,9 +22,9 @@ from blocks.model import Model from blocks.algorithms import GradientDescent import datastream -import paramsaveload -import gentext -import ircext +from paramsaveload import SaveLoadParams +from gentext import GenText +from ircext import IRCClientExt logging.basicConfig(level='INFO') logger = logging.getLogger(__name__) @@ -72,44 +72,19 @@ def train_model(m, train_stream, dump_path=None): ) if config.sample_freq is not None: extensions.append( - gentext.GenText(m, '\nalex\ttu crois ?\n', - config.sample_len, config.sample_temperature, - every_n_epochs=config.sample_freq, - after_epoch=False, before_training=True) + GenText(m, '\nalex\ttu crois ?\n', + config.sample_len, config.sample_temperature, + every_n_epochs=config.sample_freq, + after_epoch=False, before_training=True) ) if config.on_irc: extensions.append( - ircext.IRCClientExt(m, config.sample_temperature, - server='irc.ulminfo.fr', - port=6667, - nick='frigo', - channels=['#frigotest', '#courssysteme'], - after_batch=True) - ) - - extensions = [] - if config.save_freq is not None: - extensions.append( - Checkpoint(path=dump_path, - after_epoch=False, - use_cpickle=True, - every_n_epochs=config.save_freq), - ) - if config.sample_freq is not None: - extensions.append( - gentext.GenText(m, '\nalex\ttu crois ?\n', - config.sample_len, config.sample_temperature, - every_n_epochs=config.sample_freq, - after_epoch=False, before_training=True), - ) - if config.on_irc: - extensions.append( - ircext.IRCClientExt(m, config.sample_temperature, - server='irc.ulminfo.fr', - port=6667, - nick='frigo', - channels=['#frigotest', '#courssysteme'], - after_batch=True), + IRCClientExt(m, config.sample_temperature, + server='irc.ulminfo.fr', + port=6667, + nick='frigo', + channels=['#frigotest', '#courssysteme'], + after_batch=True) ) main_loop = MainLoop( -- cgit v1.2.3