diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-13 16:26:41 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-13 16:30:11 -0400 |
commit | 8470f64c9373308d7f85de0f7de3bdcbaf46ca0a (patch) | |
tree | 747f72e1fc9f521a3cfc08158ed373fa1f81b46b /train.py | |
parent | 1eca8867751df644a62752fbbfbc6a6de849de74 (diff) | |
download | taxi-8470f64c9373308d7f85de0f7de3bdcbaf46ca0a.tar.gz taxi-8470f64c9373308d7f85de0f7de3bdcbaf46ca0a.zip |
Add NaN protection
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 28 |
1 files changed, 18 insertions, 10 deletions
@@ -6,20 +6,24 @@ import importlib import csv +from picklable_itertools.extras import equizip + from blocks.model import Model from fuel.transformers import Batch from fuel.streams import DataStream from fuel.schemes import ConstantScheme, ShuffledExampleScheme -from blocks.algorithms import GradientDescent, AdaDelta, Momentum -from blocks.graph import ComputationGraph +from blocks.algorithms import CompositeRule, RemoveNotFinite, GradientDescent, AdaDelta, Momentum +from blocks.graph import ComputationGraph, apply_dropout from blocks.main_loop import MainLoop from blocks.extensions import Printing, FinishAfter from blocks.extensions.saveload import Dump, LoadFromDump, Checkpoint from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring from blocks.extensions.plot import Plot +from theano import tensor + from data import transformers from data.hdf5 import TaxiDataset, TaxiStream import apply_model @@ -89,27 +93,31 @@ def main(): # Training cg = ComputationGraph(cost) + params = cg.parameters + algorithm = GradientDescent( cost=cost, - # step_rule=AdaDelta(decay_rate=0.5), - step_rule=Momentum(learning_rate=config.learning_rate, momentum=config.momentum), + step_rule=CompositeRule([ + RemoveNotFinite(), + #AdaDelta(decay_rate=0.95), + Momentum(learning_rate=config.learning_rate, momentum=config.momentum), + ]), params=params) plot_vars = [['valid_' + x.name for x in model.monitor]] - # plot_vars = ['valid_cost'] print "Plot: ", plot_vars extensions=[TrainingDataMonitoring(model.monitor, prefix='train', every_n_batches=1000), DataStreamMonitoring(model.monitor, valid_stream, prefix='valid', - every_n_batches=1000), - Printing(every_n_batches=1000), - Plot(model_name, channels=plot_vars, every_n_batches=1000), + every_n_batches=500), + Printing(every_n_batches=500), + Plot(model_name, channels=plot_vars, every_n_batches=500), # Checkpoint('model.pkl', every_n_batches=100), - Dump('model_data/' + model_name, every_n_batches=1000), + Dump('model_data/' + model_name, every_n_batches=500), LoadFromDump('model_data/' + model_name), - # FinishAfter(after_epoch=42), + # FinishAfter(after_epoch=4), ] main_loop = MainLoop( |