Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 51
1 file changed, 28 insertions, 23 deletions
@@ -14,13 +14,18 @@ from theano.tensor.shared_randomstreams import RandomStreams
 from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER
 from blocks.extensions import Printing, SimpleExtension
 from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-from blocks.extras.extensions.plot import Plot
 from blocks.extensions.saveload import Checkpoint, Load
 from blocks.graph import ComputationGraph
 from blocks.main_loop import MainLoop
 from blocks.model import Model
 from blocks.algorithms import GradientDescent, StepRule, CompositeRule
 
+try:
+    from blocks.extras.extensions.plot import Plot
+    plot_avail = True
+except ImportError:
+    plot_avail = False
+
 import datastream
 from paramsaveload import SaveLoadParams
 from gentext import GenText
@@ -63,18 +68,34 @@ class ResetStates(SimpleExtension):
 
 def train_model(m, train_stream, dump_path=None):
     # Define the model
-    model = Model(m.cost)
+    model = Model(m.sgd_cost)
 
-    cg = ComputationGraph(m.cost_reg)
-    algorithm = GradientDescent(cost=m.cost_reg,
+    cg = ComputationGraph(m.sgd_cost)
+    algorithm = GradientDescent(cost=m.sgd_cost,
                                 step_rule=CompositeRule([
                                     ElementwiseRemoveNotFinite(),
                                     config.step_rule]),
-                                params=cg.parameters)
+                                parameters=cg.parameters)
     algorithm.add_updates(m.states)
 
-    extensions = []
+    monitor_vars = [v for p in m.monitor_vars for v in p]
+    extensions = [
+        TrainingDataMonitoring(
+            monitor_vars,
+            prefix='train', every_n_epochs=1),
+        Printing(every_n_epochs=1, after_epoch=False),
+
+        ResetStates([v for v, _ in m.states], after_epoch=True)
+    ]
+    if plot_avail:
+        plot_channels = [['train_' + v.name for v in p] for p in m.monitor_vars]
+        extensions.append(
+            Plot(document='text_'+model_name+'_'+config.param_desc,
+                 channels=plot_channels,
+                 server_url='http://eos6:5006/',
+                 every_n_epochs=1, after_epoch=False)
+        )
 
     if config.save_freq is not None and dump_path is not None:
         extensions.append(
             SaveLoadParams(path=dump_path+'.pkl',
@@ -105,19 +126,7 @@ def train_model(m, train_stream, dump_path=None):
         model=model,
         data_stream=train_stream,
         algorithm=algorithm,
-        extensions=extensions + [
-            TrainingDataMonitoring(
-                [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
-                prefix='train', every_n_epochs=1),
-            Printing(every_n_epochs=1, after_epoch=False),
-            Plot(document='text_'+model_name+'_'+config.param_desc,
-                 channels=[['train_cost', 'train_cost_reg'],
-                           ['train_error_rate', 'train_error_rate_reg']],
-                 server_url='http://eos21:4201/',
-                 every_n_epochs=1, after_epoch=False),
-
-            ResetStates([v for v, _ in m.states], after_epoch=True)
-        ]
+        extensions=extensions
     )
 
     main_loop.run()
@@ -131,10 +140,6 @@ if __name__ == "__main__":
 
     # Build model
     m = config.Model()
-    m.cost.name = 'cost'
-    m.cost_reg.name = 'cost_reg'
-    m.error_rate.name = 'error_rate'
-    m.error_rate_reg.name = 'error_rate_reg'
    m.pred.name = 'pred'
 
     # Train the model
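Note on the change: the Plot extension from blocks-extras is now imported inside a try/except block, so training still runs when that optional package is missing, and the monitored channels are derived from m.monitor_vars instead of being hard-coded per cost variable. A minimal, self-contained sketch of the same optional-dependency pattern follows; the helper build_extensions and its arguments are illustrative only, not part of this commit.

# Sketch of the optional-import pattern used in this commit. blocks-extras may
# or may not be installed; build_extensions and its arguments are hypothetical.
try:
    from blocks.extras.extensions.plot import Plot
    plot_avail = True
except ImportError:
    plot_avail = False


def build_extensions(monitor_vars, base_extensions, document_name, server_url):
    """Append a live-plotting extension only when blocks-extras is importable.

    monitor_vars is a list of groups of monitored variables (as in
    m.monitor_vars); each group becomes one plot, with channel names carrying
    the 'train_' prefix that TrainingDataMonitoring adds.
    """
    extensions = list(base_extensions)
    if plot_avail:
        channels = [['train_' + v.name for v in group] for group in monitor_vars]
        extensions.append(Plot(document=document_name,
                               channels=channels,
                               server_url=server_url,
                               every_n_epochs=1, after_epoch=False))
    return extensions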