summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/train.py b/train.py
index ddd1d0c..b59cd8e 100755
--- a/train.py
+++ b/train.py
@@ -23,11 +23,12 @@ from blocks.algorithms import GradientDescent
import datastream
import gentext
+import ircext
logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)
-sys.setrecursionlimit(1500)
+sys.setrecursionlimit(500000)
if __name__ == "__main__":
if len(sys.argv) != 2:
@@ -76,10 +77,10 @@ def train_model(m, train_stream, dump_path=None):
data_stream=train_stream,
algorithm=algorithm,
extensions=[
- Checkpoint(path=dump_path,
- after_epoch=False,
- use_cpickle=True,
- every_n_epochs=config.save_freq),
+ #Checkpoint(path=dump_path,
+ # after_epoch=False,
+ # use_cpickle=True,
+ # every_n_epochs=config.save_freq),
TrainingDataMonitoring(
[m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
@@ -95,6 +96,12 @@ def train_model(m, train_stream, dump_path=None):
config.sample_len, config.sample_temperature,
every_n_epochs=config.sample_freq,
after_epoch=False, before_training=True),
+ #ircext.IRCClientExt(m, config.sample_temperature,
+ # server='irc.ulminfo.fr',
+ # port=6667,
+ # nick='frigo',
+ # channels=['#frigotest', '#courssysteme'],
+ # after_batch=True),
ResetStates([v for v, _ in m.states], after_epoch=True)
]
)