summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-06-17 09:23:47 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-06-17 09:23:47 -0400
commite91e14e894196642532c0b7be50b01c1354ad702 (patch)
treefa87266d998dceb1df8e37b30c71fbccc291d25f /train.py
parent211c2272c544ab0bbf7b87b374736a71c790ac8e (diff)
downloadtext-rnn-e91e14e894196642532c0b7be50b01c1354ad702.tar.gz
text-rnn-e91e14e894196642532c0b7be50b01c1354ad702.zip
Connect it to IRC ; add GFGRU model
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/train.py b/train.py
index ddd1d0c..b59cd8e 100755
--- a/train.py
+++ b/train.py
@@ -23,11 +23,12 @@ from blocks.algorithms import GradientDescent
import datastream
import gentext
+import ircext
logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)
-sys.setrecursionlimit(1500)
+sys.setrecursionlimit(500000)
if __name__ == "__main__":
if len(sys.argv) != 2:
@@ -76,10 +77,10 @@ def train_model(m, train_stream, dump_path=None):
data_stream=train_stream,
algorithm=algorithm,
extensions=[
- Checkpoint(path=dump_path,
- after_epoch=False,
- use_cpickle=True,
- every_n_epochs=config.save_freq),
+ #Checkpoint(path=dump_path,
+ # after_epoch=False,
+ # use_cpickle=True,
+ # every_n_epochs=config.save_freq),
TrainingDataMonitoring(
[m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
@@ -95,6 +96,12 @@ def train_model(m, train_stream, dump_path=None):
config.sample_len, config.sample_temperature,
every_n_epochs=config.sample_freq,
after_epoch=False, before_training=True),
+ #ircext.IRCClientExt(m, config.sample_temperature,
+ # server='irc.ulminfo.fr',
+ # port=6667,
+ # nick='frigo',
+ # channels=['#frigotest', '#courssysteme'],
+ # after_batch=True),
ResetStates([v for v, _ in m.states], after_epoch=True)
]
)