summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-06-17 14:58:20 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-06-17 14:58:20 -0400
commit701c407b8c87a9270b31d34ac54e683341be661e (patch)
tree1ef80f0859a7f592d44e9f1c4010da02ee68fb10 /train.py
parente91e14e894196642532c0b7be50b01c1354ad702 (diff)
downloadtext-rnn-701c407b8c87a9270b31d34ac54e683341be661e.tar.gz
text-rnn-701c407b8c87a9270b31d34ac54e683341be661e.zip
xoxo
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py42
1 files changed, 26 insertions, 16 deletions
diff --git a/train.py b/train.py
index b59cd8e..79b2116 100755
--- a/train.py
+++ b/train.py
@@ -72,16 +72,36 @@ def train_model(m, train_stream, dump_path=None):
except IOError:
pass
+ extensions = []
+ if config.save_freq is not None:
+ extensions.append(
+ Checkpoint(path=dump_path,
+ after_epoch=False,
+ use_cpickle=True,
+ every_n_epochs=config.save_freq),
+ )
+ if config.sample_freq is not None:
+ extensions.append(
+ gentext.GenText(m, '\nalex\ttu crois ?\n',
+ config.sample_len, config.sample_temperature,
+ every_n_epochs=config.sample_freq,
+ after_epoch=False, before_training=True),
+ )
+ if config.on_irc:
+ extensions.append(
+ ircext.IRCClientExt(m, config.sample_temperature,
+ server='irc.ulminfo.fr',
+ port=6667,
+ nick='frigo',
+ channels=['#frigotest', '#courssysteme'],
+ after_batch=True),
+ )
+
main_loop = MainLoop(
model=model,
data_stream=train_stream,
algorithm=algorithm,
- extensions=[
- #Checkpoint(path=dump_path,
- # after_epoch=False,
- # use_cpickle=True,
- # every_n_epochs=config.save_freq),
-
+ extensions=extensions + [
TrainingDataMonitoring(
[m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
prefix='train', every_n_epochs=1),
@@ -92,16 +112,6 @@ def train_model(m, train_stream, dump_path=None):
server_url='http://eos21:4201/',
every_n_epochs=1, after_epoch=False),
- gentext.GenText(m, '\nalex\ttu crois ?\n',
- config.sample_len, config.sample_temperature,
- every_n_epochs=config.sample_freq,
- after_epoch=False, before_training=True),
- #ircext.IRCClientExt(m, config.sample_temperature,
- # server='irc.ulminfo.fr',
- # port=6667,
- # nick='frigo',
- # channels=['#frigotest', '#courssysteme'],
- # after_batch=True),
ResetStates([v for v, _ in m.states], after_epoch=True)
]
)