summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py42
1 files changed, 26 insertions, 16 deletions
diff --git a/train.py b/train.py
index b59cd8e..79b2116 100755
--- a/train.py
+++ b/train.py
@@ -72,16 +72,36 @@ def train_model(m, train_stream, dump_path=None):
except IOError:
pass
+ extensions = []
+ if config.save_freq is not None:
+ extensions.append(
+ Checkpoint(path=dump_path,
+ after_epoch=False,
+ use_cpickle=True,
+ every_n_epochs=config.save_freq),
+ )
+ if config.sample_freq is not None:
+ extensions.append(
+ gentext.GenText(m, '\nalex\ttu crois ?\n',
+ config.sample_len, config.sample_temperature,
+ every_n_epochs=config.sample_freq,
+ after_epoch=False, before_training=True),
+ )
+ if config.on_irc:
+ extensions.append(
+ ircext.IRCClientExt(m, config.sample_temperature,
+ server='irc.ulminfo.fr',
+ port=6667,
+ nick='frigo',
+ channels=['#frigotest', '#courssysteme'],
+ after_batch=True),
+ )
+
main_loop = MainLoop(
model=model,
data_stream=train_stream,
algorithm=algorithm,
- extensions=[
- #Checkpoint(path=dump_path,
- # after_epoch=False,
- # use_cpickle=True,
- # every_n_epochs=config.save_freq),
-
+ extensions=extensions + [
TrainingDataMonitoring(
[m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
prefix='train', every_n_epochs=1),
@@ -92,16 +112,6 @@ def train_model(m, train_stream, dump_path=None):
server_url='http://eos21:4201/',
every_n_epochs=1, after_epoch=False),
- gentext.GenText(m, '\nalex\ttu crois ?\n',
- config.sample_len, config.sample_temperature,
- every_n_epochs=config.sample_freq,
- after_epoch=False, before_training=True),
- #ircext.IRCClientExt(m, config.sample_temperature,
- # server='irc.ulminfo.fr',
- # port=6667,
- # nick='frigo',
- # channels=['#frigotest', '#courssysteme'],
- # after_batch=True),
ResetStates([v for v, _ in m.states], after_epoch=True)
]
)