author     Alex Auvolat <alex@adnab.me>    2016-03-08 13:26:28 +0100
committer  Alex Auvolat <alex@adnab.me>    2016-03-08 13:26:28 +0100
commit     2f479926c16d2911d0dd878c21de082abfc5b237 (patch)
tree       b399e9ad9af04a9449334dff1a47449808b7ca13 /train.py
parent     23093608e0edc43477c3a2ed804ae1016790f7e4 (diff)
download   text-rnn-2f479926c16d2911d0dd878c21de082abfc5b237.tar.gz
           text-rnn-2f479926c16d2911d0dd878c21de082abfc5b237.zip
Revive project
Diffstat (limited to 'train.py')
-rwxr-xr-x  train.py  120
1 file changed, 45 insertions(+), 75 deletions(-)
diff --git a/train.py b/train.py
index 58bff1e..09555f2 100755
--- a/train.py
+++ b/train.py
@@ -1,58 +1,37 @@
-#!/usr/bin/env python
+#!/usr/bin/env python2
import logging
-import numpy
import sys
import importlib
-from contextlib import closing
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
import theano
-from theano import tensor
-from theano.tensor.shared_randomstreams import RandomStreams
-from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER
-from blocks.extensions import Printing, SimpleExtension
+from blocks.extensions import Printing, SimpleExtension, FinishAfter, ProgressBar
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-from blocks.extensions.saveload import Checkpoint, Load
+
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
-from blocks.algorithms import GradientDescent, StepRule, CompositeRule
+from blocks.algorithms import GradientDescent
try:
from blocks.extras.extensions.plot import Plot
plot_avail = True
except ImportError:
plot_avail = False
+ logger.warning('Plotting extension not available')
+
import datastream
from paramsaveload import SaveLoadParams
from gentext import GenText
-from ircext import IRCClientExt
-logging.basicConfig(level='INFO')
-logger = logging.getLogger(__name__)
sys.setrecursionlimit(500000)
-if __name__ == "__main__":
- if len(sys.argv) != 2:
- print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
- sys.exit(1)
- model_name = sys.argv[1]
- config = importlib.import_module('%s' % model_name)
-
-
-class ElementwiseRemoveNotFinite(StepRule):
- def __init__(self, scaler=0.1):
- self.scaler = scaler
-
- def compute_step(self, param, previous_step):
- not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
- step = tensor.switch(not_finite, self.scaler * param, previous_step)
-
- return step, []
class ResetStates(SimpleExtension):
def __init__(self, state_vars, **kwargs):
@@ -65,84 +44,75 @@ class ResetStates(SimpleExtension):
def do(self, which_callback, *args):
self.f()
-def train_model(m, train_stream, dump_path=None):
+if __name__ == "__main__":
+ if len(sys.argv) < 2:
+ print >> sys.stderr, 'Usage: %s [options] config' % sys.argv[0]
+ sys.exit(1)
+ model_name = sys.argv[-1]
+ config = importlib.import_module('.%s' % model_name, 'config')
+
+ # Build datastream
+ train_stream = datastream.setup_datastream(config.dataset,
+ config.num_seqs,
+ config.seq_len,
+ config.seq_div_size)
- # Define the model
- model = Model(m.sgd_cost)
+ # Build model
+ m = config.Model(config)
- cg = ComputationGraph(m.sgd_cost)
+ # Train the model
+ cg = Model(m.sgd_cost)
algorithm = GradientDescent(cost=m.sgd_cost,
- step_rule=CompositeRule([
- ElementwiseRemoveNotFinite(),
- config.step_rule]),
+ step_rule=config.step_rule,
parameters=cg.parameters)
algorithm.add_updates(m.states)
- monitor_vars = [v for p in m.monitor_vars for v in p]
+ monitor_vars = list(set(v for p in m.monitor_vars for v in p))
extensions = [
TrainingDataMonitoring(
monitor_vars,
- prefix='train', every_n_epochs=1),
- Printing(every_n_epochs=1, after_epoch=False),
+ prefix='train', every_n_batches=config.monitor_freq),
+ Printing(every_n_batches=config.monitor_freq, after_epoch=False),
+ ProgressBar(),
ResetStates([v for v, _ in m.states], after_epoch=True)
]
if plot_avail:
plot_channels = [['train_' + v.name for v in p] for p in m.monitor_vars]
extensions.append(
- Plot(document='text_'+model_name+'_'+config.param_desc,
+ Plot(document='text_'+model_name,
channels=plot_channels,
- server_url='http://eos6:5006/',
- every_n_epochs=1, after_epoch=False)
+ # server_url='http://localhost:5006',
+ every_n_batches=config.monitor_freq)
)
- if config.save_freq is not None and dump_path is not None:
+
+ if config.save_freq is not None and not '--nosave' in sys.argv:
extensions.append(
- SaveLoadParams(path=dump_path+'.pkl',
- model=model,
- before_training=True,
+ SaveLoadParams(path='params/%s.pkl'%model_name,
+ model=cg,
+ before_training=(not '--noload' in sys.argv),
after_training=True,
- after_epoch=False,
- every_n_epochs=config.save_freq)
+ every_n_batches=config.save_freq)
)
+
if config.sample_freq is not None:
extensions.append(
- GenText(m, '\nalex\ttu crois ?\n',
+ GenText(m, config.sample_init,
config.sample_len, config.sample_temperature,
- every_n_epochs=config.sample_freq,
- after_epoch=False, before_training=True)
- )
- if config.on_irc:
- extensions.append(
- IRCClientExt(m, config.sample_temperature,
- server='irc.ulminfo.fr',
- port=6667,
- nick='frigo',
- channels=['#frigotest', '#courssysteme'],
- after_batch=True)
+ before_training=True,
+ every_n_batches=config.sample_freq)
)
main_loop = MainLoop(
- model=model,
+ model=cg,
data_stream=train_stream,
algorithm=algorithm,
extensions=extensions
)
main_loop.run()
+ main_loop.profile.report()
-if __name__ == "__main__":
- # Build datastream
- train_stream = datastream.setup_datastream('data/logcompil.txt',
- config.num_seqs,
- config.seq_len,
- config.seq_div_size)
-
- # Build model
- m = config.Model()
- m.pred.name = 'pred'
-
- # Train the model
- saveloc = 'model_data/%s-%s' % (model_name, config.param_desc)
- train_model(m, train_stream, dump_path=saveloc)
+# vim: set sts=4 ts=4 sw=4 tw=0 et :
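
The revived train.py moves every model-specific knob into a module under the
config/ package (note the importlib.import_module('.%s' % model_name, 'config')
call in the diff). For reference, here is a minimal sketch of such a config
module, inferred from the attributes the script reads. The file name, the
field values, and the choice of RMSProp are illustrative assumptions; only the
attribute names and the Model interface (sgd_cost, states, monitor_vars) come
from the diff above.

# config/example.py -- hypothetical sketch; only the attribute names below
# are what train.py actually reads, all values are illustrative assumptions.
from blocks.algorithms import RMSProp

# Data: the path and batching parameters passed to datastream.setup_datastream
dataset = 'data/logcompil.txt'   # the path the old script hard-coded
num_seqs = 20
seq_len = 2000
seq_div_size = 100

# Optimisation: any blocks step rule instance works here
step_rule = RMSProp()

# Extension frequencies, now counted in batches rather than epochs
monitor_freq = 100               # TrainingDataMonitoring / Printing / Plot
save_freq = 100                  # SaveLoadParams; None disables checkpointing
sample_freq = 100                # GenText; None disables sampling

# Text generation settings (sample_init was hard-coded in the old script)
sample_init = '\nalex\ttu crois ?\n'
sample_len = 1000
sample_temperature = 0.7

class Model(object):
    """Placeholder: train.py instantiates config.Model(config) and expects
    .sgd_cost (the Theano cost to minimise), .states (a list of
    (state_variable, update) pairs) and .monitor_vars (a list of lists of
    variables, one sub-list per plot channel)."""
    def __init__(self, config):
        raise NotImplementedError

Invocation, as implied by the new argument handling: ./train.py [--noload]
[--nosave] example, where the last argument names the config module. --noload
skips restoring params/example.pkl before training starts, and --nosave
disables checkpointing altogether.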