summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2015-06-17 15:02:44 -0400
committerAlex Auvolat <alex@adnab.me>2015-06-17 15:02:44 -0400
commit0ba1bd24fd2375fc4de5d355e434f747c03de202 (patch)
tree17715acb9f482a4068d2ff5b5df26ac0be8216e8
parent12304944033d20bbc5c1b3f5cb90cf8dedebcdff (diff)
downloadtext-rnn-0ba1bd24fd2375fc4de5d355e434f747c03de202.tar.gz
text-rnn-0ba1bd24fd2375fc4de5d355e434f747c03de202.zip
SaveLoadParams
-rwxr-xr-xtrain.py51
1 files changed, 13 insertions, 38 deletions
diff --git a/train.py b/train.py
index 525724f..a188541 100755
--- a/train.py
+++ b/train.py
@@ -22,9 +22,9 @@ from blocks.model import Model
from blocks.algorithms import GradientDescent
import datastream
-import paramsaveload
-import gentext
-import ircext
+from paramsaveload import SaveLoadParams
+from gentext import GenText
+from ircext import IRCClientExt
logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)
@@ -72,44 +72,19 @@ def train_model(m, train_stream, dump_path=None):
)
if config.sample_freq is not None:
extensions.append(
- gentext.GenText(m, '\nalex\ttu crois ?\n',
- config.sample_len, config.sample_temperature,
- every_n_epochs=config.sample_freq,
- after_epoch=False, before_training=True)
+ GenText(m, '\nalex\ttu crois ?\n',
+ config.sample_len, config.sample_temperature,
+ every_n_epochs=config.sample_freq,
+ after_epoch=False, before_training=True)
)
if config.on_irc:
extensions.append(
- ircext.IRCClientExt(m, config.sample_temperature,
- server='irc.ulminfo.fr',
- port=6667,
- nick='frigo',
- channels=['#frigotest', '#courssysteme'],
- after_batch=True)
- )
-
- extensions = []
- if config.save_freq is not None:
- extensions.append(
- Checkpoint(path=dump_path,
- after_epoch=False,
- use_cpickle=True,
- every_n_epochs=config.save_freq),
- )
- if config.sample_freq is not None:
- extensions.append(
- gentext.GenText(m, '\nalex\ttu crois ?\n',
- config.sample_len, config.sample_temperature,
- every_n_epochs=config.sample_freq,
- after_epoch=False, before_training=True),
- )
- if config.on_irc:
- extensions.append(
- ircext.IRCClientExt(m, config.sample_temperature,
- server='irc.ulminfo.fr',
- port=6667,
- nick='frigo',
- channels=['#frigotest', '#courssysteme'],
- after_batch=True),
+ IRCClientExt(m, config.sample_temperature,
+ server='irc.ulminfo.fr',
+ port=6667,
+ nick='frigo',
+ channels=['#frigotest', '#courssysteme'],
+ after_batch=True)
)
main_loop = MainLoop(