From 99c37ecef356c20673b4ecd5030749e3a6abcf7a Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 20 Jul 2015 13:57:49 -0400 Subject: New model CCHLSTM ; other models broken. Please enter the commit message for your changes. Lines starting --- train.py | 51 ++++++++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 23 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index 61f6663..58bff1e 100755 --- a/train.py +++ b/train.py @@ -14,13 +14,18 @@ from theano.tensor.shared_randomstreams import RandomStreams from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER from blocks.extensions import Printing, SimpleExtension from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring -from blocks.extras.extensions.plot import Plot from blocks.extensions.saveload import Checkpoint, Load from blocks.graph import ComputationGraph from blocks.main_loop import MainLoop from blocks.model import Model from blocks.algorithms import GradientDescent, StepRule, CompositeRule +try: + from blocks.extras.extensions.plot import Plot + plot_avail = True +except ImportError: + plot_avail = False + import datastream from paramsaveload import SaveLoadParams from gentext import GenText @@ -63,18 +68,34 @@ class ResetStates(SimpleExtension): def train_model(m, train_stream, dump_path=None): # Define the model - model = Model(m.cost) + model = Model(m.sgd_cost) - cg = ComputationGraph(m.cost_reg) - algorithm = GradientDescent(cost=m.cost_reg, + cg = ComputationGraph(m.sgd_cost) + algorithm = GradientDescent(cost=m.sgd_cost, step_rule=CompositeRule([ ElementwiseRemoveNotFinite(), config.step_rule]), - params=cg.parameters) + parameters=cg.parameters) algorithm.add_updates(m.states) - extensions = [] + monitor_vars = [v for p in m.monitor_vars for v in p] + extensions = [ + TrainingDataMonitoring( + monitor_vars, + prefix='train', every_n_epochs=1), + Printing(every_n_epochs=1, after_epoch=False), + + ResetStates([v for v, _ in m.states], after_epoch=True) + ] + if plot_avail: + plot_channels = [['train_' + v.name for v in p] for p in m.monitor_vars] + extensions.append( + Plot(document='text_'+model_name+'_'+config.param_desc, + channels=plot_channels, + server_url='http://eos6:5006/', + every_n_epochs=1, after_epoch=False) + ) if config.save_freq is not None and dump_path is not None: extensions.append( SaveLoadParams(path=dump_path+'.pkl', @@ -105,19 +126,7 @@ def train_model(m, train_stream, dump_path=None): model=model, data_stream=train_stream, algorithm=algorithm, - extensions=extensions + [ - TrainingDataMonitoring( - [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate], - prefix='train', every_n_epochs=1), - Printing(every_n_epochs=1, after_epoch=False), - Plot(document='text_'+model_name+'_'+config.param_desc, - channels=[['train_cost', 'train_cost_reg'], - ['train_error_rate', 'train_error_rate_reg']], - server_url='http://eos21:4201/', - every_n_epochs=1, after_epoch=False), - - ResetStates([v for v, _ in m.states], after_epoch=True) - ] + extensions=extensions ) main_loop.run() @@ -131,10 +140,6 @@ if __name__ == "__main__": # Build model m = config.Model() - m.cost.name = 'cost' - m.cost_reg.name = 'cost_reg' - m.error_rate.name = 'error_rate' - m.error_rate_reg.name = 'error_rate_reg' m.pred.name = 'pred' # Train the model -- cgit v1.2.3