diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-09 13:42:08 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-09 13:42:08 -0400 |
commit | 6a5ed2e43a5885eeb3c5e202ed5bb473f6065401 (patch) | |
tree | f5541c64fe0fc4eaf3567d716eb71c2dc5fb1ea8 /train.py | |
download | text-rnn-6a5ed2e43a5885eeb3c5e202ed5bb473f6065401.tar.gz text-rnn-6a5ed2e43a5885eeb3c5e202ed5bb473f6065401.zip |
First commit
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/train.py b/train.py new file mode 100755 index 0000000..ab973a1 --- /dev/null +++ b/train.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python + +import logging +import numpy +import sys +import importlib + +from blocks.dump import load_parameter_values +from blocks.dump import MainLoopDumpManager +from blocks.extensions import Printing +from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring +from blocks.extensions.plot import Plot +from blocks.graph import ComputationGraph +from blocks.main_loop import MainLoop +from blocks.model import Model +from blocks.algorithms import GradientDescent +from theano import tensor + +import datastream +# from apply_model import Apply + +logging.basicConfig(level='INFO') +logger = logging.getLogger(__name__) + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, 'Usage: %s config' % sys.argv[0] + sys.exit(1) + model_name = sys.argv[1] + config = importlib.import_module('%s' % model_name) + + +def train_model(m, train_stream, load_location=None, save_location=None): + + # Define the model + model = Model(m.cost) + + # Load the parameters from a dumped model + if load_location is not None: + logger.info('Loading parameters...') + model.set_param_values(load_parameter_values(load_location)) + + cg = ComputationGraph(m.cost_reg) + algorithm = GradientDescent(cost=m.cost_reg, + step_rule=config.step_rule, + params=cg.parameters) + main_loop = MainLoop( + model=model, + data_stream=train_stream, + algorithm=algorithm, + extensions=[ + TrainingDataMonitoring( + [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate], + prefix='train', every_n_epochs=1*config.pt_freq), + Printing(every_n_epochs=1*config.pt_freq, after_epoch=False), + Plot(document='tr_'+model_name+'_'+config.param_desc, + channels=[['train_cost', 'train_cost_reg'], + ['train_error_rate', 'train_error_rate_reg']], + every_n_epochs=1*config.pt_freq, after_epoch=False) + ] + ) + main_loop.run() + + # Save the main loop + if save_location is not None: + logger.info('Saving the main loop...') + dump_manager = MainLoopDumpManager(save_location) + dump_manager.dump(main_loop) + logger.info('Saved') + + +if __name__ == "__main__": + # Build datastream + train_stream = datastream.setup_datastream('data/logcompil.txt', + config.chars_per_seq, + config.seqs_per_epoch) + + # Build model + m = config.Model() + m.cost.name = 'cost' + m.cost_reg.name = 'cost_reg' + m.error_rate.name = 'error_rate' + m.error_rate_reg.name = 'error_rate_reg' + m.pred.name = 'pred' + + # Train the model + saveloc = 'model_data/%s' % model_name + train_model(m, train_stream, + load_location=None, + save_location=None) + |