diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-17 14:58:20 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-17 14:58:20 -0400 |
commit | 701c407b8c87a9270b31d34ac54e683341be661e (patch) | |
tree | 1ef80f0859a7f592d44e9f1c4010da02ee68fb10 /train.py | |
parent | e91e14e894196642532c0b7be50b01c1354ad702 (diff) | |
download | text-rnn-701c407b8c87a9270b31d34ac54e683341be661e.tar.gz text-rnn-701c407b8c87a9270b31d34ac54e683341be661e.zip |
xoxo
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 42 |
1 files changed, 26 insertions, 16 deletions
@@ -72,16 +72,36 @@ def train_model(m, train_stream, dump_path=None): except IOError: pass + extensions = [] + if config.save_freq is not None: + extensions.append( + Checkpoint(path=dump_path, + after_epoch=False, + use_cpickle=True, + every_n_epochs=config.save_freq), + ) + if config.sample_freq is not None: + extensions.append( + gentext.GenText(m, '\nalex\ttu crois ?\n', + config.sample_len, config.sample_temperature, + every_n_epochs=config.sample_freq, + after_epoch=False, before_training=True), + ) + if config.on_irc: + extensions.append( + ircext.IRCClientExt(m, config.sample_temperature, + server='irc.ulminfo.fr', + port=6667, + nick='frigo', + channels=['#frigotest', '#courssysteme'], + after_batch=True), + ) + main_loop = MainLoop( model=model, data_stream=train_stream, algorithm=algorithm, - extensions=[ - #Checkpoint(path=dump_path, - # after_epoch=False, - # use_cpickle=True, - # every_n_epochs=config.save_freq), - + extensions=extensions + [ TrainingDataMonitoring( [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate], prefix='train', every_n_epochs=1), @@ -92,16 +112,6 @@ def train_model(m, train_stream, dump_path=None): server_url='http://eos21:4201/', every_n_epochs=1, after_epoch=False), - gentext.GenText(m, '\nalex\ttu crois ?\n', - config.sample_len, config.sample_temperature, - every_n_epochs=config.sample_freq, - after_epoch=False, before_training=True), - #ircext.IRCClientExt(m, config.sample_temperature, - # server='irc.ulminfo.fr', - # port=6667, - # nick='frigo', - # channels=['#frigotest', '#courssysteme'], - # after_batch=True), ResetStates([v for v, _ in m.states], after_epoch=True) ] ) |