From 8470f64c9373308d7f85de0f7de3bdcbaf46ca0a Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 13 May 2015 16:26:41 -0400 Subject: Add NaN protection --- train.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/train.py b/train.py index a70bb90..96dd798 100755 --- a/train.py +++ b/train.py @@ -6,20 +6,24 @@ import importlib import csv +from picklable_itertools.extras import equizip + from blocks.model import Model from fuel.transformers import Batch from fuel.streams import DataStream from fuel.schemes import ConstantScheme, ShuffledExampleScheme -from blocks.algorithms import GradientDescent, AdaDelta, Momentum -from blocks.graph import ComputationGraph +from blocks.algorithms import CompositeRule, RemoveNotFinite, GradientDescent, AdaDelta, Momentum +from blocks.graph import ComputationGraph, apply_dropout from blocks.main_loop import MainLoop from blocks.extensions import Printing, FinishAfter from blocks.extensions.saveload import Dump, LoadFromDump, Checkpoint from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring from blocks.extensions.plot import Plot +from theano import tensor + from data import transformers from data.hdf5 import TaxiDataset, TaxiStream import apply_model @@ -89,27 +93,31 @@ def main(): # Training cg = ComputationGraph(cost) + params = cg.parameters + algorithm = GradientDescent( cost=cost, - # step_rule=AdaDelta(decay_rate=0.5), - step_rule=Momentum(learning_rate=config.learning_rate, momentum=config.momentum), + step_rule=CompositeRule([ + RemoveNotFinite(), + #AdaDelta(decay_rate=0.95), + Momentum(learning_rate=config.learning_rate, momentum=config.momentum), + ]), params=params) plot_vars = [['valid_' + x.name for x in model.monitor]] - # plot_vars = ['valid_cost'] print "Plot: ", plot_vars extensions=[TrainingDataMonitoring(model.monitor, prefix='train', every_n_batches=1000), DataStreamMonitoring(model.monitor, valid_stream, prefix='valid', - every_n_batches=1000), - Printing(every_n_batches=1000), - Plot(model_name, channels=plot_vars, every_n_batches=1000), + every_n_batches=500), + Printing(every_n_batches=500), + Plot(model_name, channels=plot_vars, every_n_batches=500), # Checkpoint('model.pkl', every_n_batches=100), - Dump('model_data/' + model_name, every_n_batches=1000), + Dump('model_data/' + model_name, every_n_batches=500), LoadFromDump('model_data/' + model_name), - # FinishAfter(after_epoch=42), + # FinishAfter(after_epoch=4), ] main_loop = MainLoop( -- cgit v1.2.3