#!/usr/bin/env python2
import logging
import sys
import importlib
import theano
from blocks.extensions import Printing, SimpleExtension, FinishAfter
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
from blocks.algorithms import GradientDescent
# Optional live-plotting support from blocks-extras.
try:
    from blocks.extras.extensions.plot import Plot
    # NOTE(review): plot_avail is False in BOTH branches, so the Plot
    # extension is never attached even when blocks-extras is installed.
    # Presumably plotting was switched off on purpose (it needs a plotting
    # server at http://localhost:5006); set this to True to re-enable it.
    plot_avail = False
except ImportError:
    plot_avail = False

# Project-local modules: data pipeline, checkpointing, text sampling and the
# IRC client extension.
import datastream
from paramsaveload import SaveLoadParams
from gentext import GenText
from ircext import IRCClientExt

logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)

# Very high recursion limit; presumably needed when Theano/pickle walks large
# recurrent computation graphs -- TODO confirm the required value.
sys.setrecursionlimit(500000)
class ResetStates(SimpleExtension):
    """Blocks extension that sets a group of Theano shared variables to zero.

    When it fires is controlled by the ``SimpleExtension`` trigger keywords
    passed through ``kwargs`` (e.g. ``after_epoch=True``).

    :param state_vars: iterable of Theano shared variables to reset (Theano
        only accepts shared variables as update targets).
    """

    def __init__(self, state_vars, **kwargs):
        super(ResetStates, self).__init__(**kwargs)
        # Pair each variable with a zero tensor built by zeros_like(), then
        # compile a function with no inputs/outputs that applies the updates.
        zeroing = []
        for state in state_vars:
            zeroing.append((state, state.zeros_like()))
        self.f = theano.function(inputs=[], outputs=[], updates=zeroing)

    def do(self, which_callback, *args):
        """Apply the reset; which callback triggered it does not matter."""
        reset = self.f
        reset()
def train_model(m, train_stream, config, model_name, dump_path=None):
    """Assemble the Blocks main loop for ``m`` and run it.

    Two modes, selected by the ``--irc`` command-line flag:

    * default: train ``m`` on ``train_stream`` using the step rule and the
      monitoring / checkpointing / sampling / IRC settings from ``config``;
    * ``--irc``: only load the parameters and serve the model on IRC. The
      main loop is stopped right after its ``before_training`` callbacks
      (``SaveLoadParams`` is registered with ``before_training=True``), then
      control passes to the IRC client until interrupted with Ctrl-C.

    :param m: model from ``config.Model()``; must expose ``sgd_cost``,
        ``states`` (pairs whose first item is a state variable) and
        ``monitor_vars`` (list of lists of named variables).
    :param train_stream: data stream from ``datastream.setup_datastream``.
    :param config: the imported configuration module.
    :param model_name: configuration module name, used to name the plot.
    :param dump_path: checkpoint path prefix (``.pkl`` is appended);
        checkpointing is disabled when ``None``.
    """
    irc_mode = '--irc' in sys.argv
    if irc_mode and not config.on_irc:
        # IRC mode needs the IRC client extension, which is only created
        # when the config enables it; fail early instead of with a NameError.
        sys.stderr.write('--irc requires on_irc to be enabled in the config\n')
        sys.exit(1)

    # Define the model
    model = Model(m.sgd_cost)

    cg = ComputationGraph(m.sgd_cost)
    algorithm = GradientDescent(cost=m.sgd_cost,
                                step_rule=config.step_rule,
                                parameters=cg.parameters)
    # Apply the model's own state updates on every step as well; the
    # ResetStates extension below zeroes those states after each epoch.
    algorithm.add_updates(m.states)

    monitor_vars = [v for p in m.monitor_vars for v in p]
    extensions = [
        TrainingDataMonitoring(
            monitor_vars,
            prefix='train', every_n_epochs=1),
        Printing(every_n_epochs=1, after_epoch=False),
        ResetStates([v for v, _ in m.states], after_epoch=True),
    ]

    if plot_avail:
        plot_channels = [['train_' + v.name for v in p] for p in m.monitor_vars]
        extensions.append(
            Plot(document='text_' + model_name,
                 channels=plot_channels,
                 server_url='http://localhost:5006',
                 every_n_epochs=1, after_epoch=False)
        )

    if config.save_freq is not None and dump_path is not None:
        extensions.append(
            SaveLoadParams(path=dump_path + '.pkl',
                           model=model,
                           before_training=True,
                           after_training=True,
                           after_epoch=False,
                           every_n_epochs=config.save_freq)
        )

    if config.sample_freq is not None:
        extensions.append(
            GenText(m, '\nalex\ttu crois ?\n',
                    config.sample_len, config.sample_temperature,
                    every_n_epochs=config.sample_freq,
                    after_epoch=False, before_training=True)
        )

    # A single IRC client; creating two with the same nick would collide.
    irc = None
    if config.on_irc:
        irc = IRCClientExt(m, config.sample_temperature,
                           server='clipper.ens.fr',
                           port=6667,
                           nick='frigo',
                           channels=['#frigotest', '#courssysteme'],
                           after_batch=True)
        irc.do('before_training')
        extensions.append(irc)

    if irc_mode:
        # Stop the loop right after before_training (at most one batch), so
        # it only initialises/loads parameters. Added before MainLoop is
        # built so it is part of the extension list from the start.
        extensions.append(FinishAfter(before_training=True, after_n_batches=1))

    main_loop = MainLoop(
        model=model,
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=extensions
    )

    if not irc_mode:
        main_loop.run()
        return

    # IRC mode : just load the parameters and run an IRC server
    try:
        print("Initializing main loop")
        main_loop.run()
        print("Jumping into IRC")
        irc.run_forever()
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.stderr.write('Usage: %s [options] config\n' % sys.argv[0])
        sys.exit(1)
    model_name = sys.argv[-1]
    config = importlib.import_module(model_name)

    # Build datastream
    train_stream = datastream.setup_datastream('data/logcompil.txt',
                                               config.num_seqs,
                                               config.seq_len,
                                               config.seq_div_size)

    # Build model
    m = config.Model()
    m.pred.name = 'pred'

    # Train the model (or, with --irc, load it and serve it on IRC)
    saveloc = 'model_data/%s-%s' % (model_name, config.param_desc)
    train_model(m, train_stream, config, model_name, dump_path=saveloc)
    sys.exit(0)
# vim: set sts=4 ts=4 sw=4 tw=0 et :