diff options
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 44 |
1 files changed, 37 insertions, 7 deletions
@@ -5,16 +5,18 @@ import numpy
 import sys
 import importlib
 
+import theano
+from theano import tensor
+
 from blocks.dump import load_parameter_values
 from blocks.dump import MainLoopDumpManager
-from blocks.extensions import Printing
+from blocks.extensions import Printing, SimpleExtension
 from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
 from blocks.extensions.plot import Plot
 from blocks.graph import ComputationGraph
 from blocks.main_loop import MainLoop
 from blocks.model import Model
 from blocks.algorithms import GradientDescent
-from theano import tensor
 
 import datastream
 # from apply_model import Apply
@@ -30,6 +32,29 @@ if __name__ == "__main__":
     config = importlib.import_module('%s' % model_name)
 
 
+class GenText(SimpleExtension):
+    def __init__(self, model, init_text, max_bytes, **kwargs):
+        self.init_text = init_text
+        self.max_bytes = max_bytes
+
+        cg = ComputationGraph([model.pred])
+        assert(len(cg.inputs) == 1)
+        assert(cg.inputs[0].name == 'bytes')
+        self.f = theano.function(inputs=cg.inputs, outputs=[model.pred])
+
+        super(GenText, self).__init__(**kwargs)
+
+    def do(self, which_callback, *args):
+        v = numpy.array([ord(i) for i in self.init_text],
+                        dtype='int16')[None, :].repeat(axis=0, repeats=config.num_seqs)
+
+        while v.shape[1] < self.max_bytes:
+            pred, = self.f(v)
+            v = numpy.concatenate([v, pred[:, -1:]], axis=1)
+
+        for i in range(v.shape[0]):
+            print "Sample:", ''.join([chr(int(v[i, j])) for j in range(v.shape[1])])
+
 def train_model(m, train_stream, load_location=None, save_location=None):
 
     # Define the model
@@ -44,6 +69,9 @@ def train_model(m, train_stream, load_location=None, save_location=None):
     algorithm = GradientDescent(cost=m.cost_reg,
                                 step_rule=config.step_rule,
                                 params=cg.parameters)
+
+    algorithm.add_updates(m.updates)
+
     main_loop = MainLoop(
         model=model,
         data_stream=train_stream,
@@ -51,12 +79,13 @@ def train_model(m, train_stream, load_location=None, save_location=None):
         extensions=[
             TrainingDataMonitoring(
                 [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
-                prefix='train', every_n_epochs=1*config.pt_freq),
-            Printing(every_n_epochs=1*config.pt_freq, after_epoch=False),
+                prefix='train', every_n_epochs=1),
+            Printing(every_n_epochs=1, after_epoch=False),
             Plot(document='tr_'+model_name+'_'+config.param_desc,
                  channels=[['train_cost', 'train_cost_reg'],
                            ['train_error_rate', 'train_error_rate_reg']],
-                 every_n_epochs=1*config.pt_freq, after_epoch=False)
+                 every_n_epochs=1, after_epoch=False),
+            GenText(m, '\t', 20, every_n_epochs=1, after_epoch=False)
         ]
     )
     main_loop.run()
@@ -72,8 +101,9 @@ def train_model(m, train_stream, load_location=None, save_location=None):
 if __name__ == "__main__":
     # Build datastream
     train_stream = datastream.setup_datastream('data/logcompil.txt',
-                                               config.chars_per_seq,
-                                               config.seqs_per_epoch)
+                                               config.num_seqs,
+                                               config.seq_len,
+                                               config.seq_div_size)
 
     # Build model
     m = config.Model()