summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py51
1 files changed, 13 insertions, 38 deletions
diff --git a/train.py b/train.py
index 525724f..a188541 100755
--- a/train.py
+++ b/train.py
@@ -22,9 +22,9 @@ from blocks.model import Model
from blocks.algorithms import GradientDescent
import datastream
-import paramsaveload
-import gentext
-import ircext
+from paramsaveload import SaveLoadParams
+from gentext import GenText
+from ircext import IRCClientExt
logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)
@@ -72,44 +72,19 @@ def train_model(m, train_stream, dump_path=None):
)
if config.sample_freq is not None:
extensions.append(
- gentext.GenText(m, '\nalex\ttu crois ?\n',
- config.sample_len, config.sample_temperature,
- every_n_epochs=config.sample_freq,
- after_epoch=False, before_training=True)
+ GenText(m, '\nalex\ttu crois ?\n',
+ config.sample_len, config.sample_temperature,
+ every_n_epochs=config.sample_freq,
+ after_epoch=False, before_training=True)
)
if config.on_irc:
extensions.append(
- ircext.IRCClientExt(m, config.sample_temperature,
- server='irc.ulminfo.fr',
- port=6667,
- nick='frigo',
- channels=['#frigotest', '#courssysteme'],
- after_batch=True)
- )
-
- extensions = []
- if config.save_freq is not None:
- extensions.append(
- Checkpoint(path=dump_path,
- after_epoch=False,
- use_cpickle=True,
- every_n_epochs=config.save_freq),
- )
- if config.sample_freq is not None:
- extensions.append(
- gentext.GenText(m, '\nalex\ttu crois ?\n',
- config.sample_len, config.sample_temperature,
- every_n_epochs=config.sample_freq,
- after_epoch=False, before_training=True),
- )
- if config.on_irc:
- extensions.append(
- ircext.IRCClientExt(m, config.sample_temperature,
- server='irc.ulminfo.fr',
- port=6667,
- nick='frigo',
- channels=['#frigotest', '#courssysteme'],
- after_batch=True),
+ IRCClientExt(m, config.sample_temperature,
+ server='irc.ulminfo.fr',
+ port=6667,
+ nick='frigo',
+ channels=['#frigotest', '#courssysteme'],
+ after_batch=True)
)
main_loop = MainLoop(