aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-13 16:26:41 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-13 16:30:11 -0400
commit8470f64c9373308d7f85de0f7de3bdcbaf46ca0a (patch)
tree747f72e1fc9f521a3cfc08158ed373fa1f81b46b
parent1eca8867751df644a62752fbbfbc6a6de849de74 (diff)
downloadtaxi-8470f64c9373308d7f85de0f7de3bdcbaf46ca0a.tar.gz
taxi-8470f64c9373308d7f85de0f7de3bdcbaf46ca0a.zip
Add NaN protection
-rwxr-xr-xtrain.py28
1 files changed, 18 insertions, 10 deletions
diff --git a/train.py b/train.py
index a70bb90..96dd798 100755
--- a/train.py
+++ b/train.py
@@ -6,20 +6,24 @@ import importlib
import csv
+from picklable_itertools.extras import equizip
+
from blocks.model import Model
from fuel.transformers import Batch
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme
-from blocks.algorithms import GradientDescent, AdaDelta, Momentum
-from blocks.graph import ComputationGraph
+from blocks.algorithms import CompositeRule, RemoveNotFinite, GradientDescent, AdaDelta, Momentum
+from blocks.graph import ComputationGraph, apply_dropout
from blocks.main_loop import MainLoop
from blocks.extensions import Printing, FinishAfter
from blocks.extensions.saveload import Dump, LoadFromDump, Checkpoint
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
from blocks.extensions.plot import Plot
+from theano import tensor
+
from data import transformers
from data.hdf5 import TaxiDataset, TaxiStream
import apply_model
@@ -89,27 +93,31 @@ def main():
# Training
cg = ComputationGraph(cost)
+
params = cg.parameters
+
algorithm = GradientDescent(
cost=cost,
- # step_rule=AdaDelta(decay_rate=0.5),
- step_rule=Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
+ step_rule=CompositeRule([
+ RemoveNotFinite(),
+ #AdaDelta(decay_rate=0.95),
+ Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
+ ]),
params=params)
plot_vars = [['valid_' + x.name for x in model.monitor]]
- # plot_vars = ['valid_cost']
print "Plot: ", plot_vars
extensions=[TrainingDataMonitoring(model.monitor, prefix='train', every_n_batches=1000),
DataStreamMonitoring(model.monitor, valid_stream,
prefix='valid',
- every_n_batches=1000),
- Printing(every_n_batches=1000),
- Plot(model_name, channels=plot_vars, every_n_batches=1000),
+ every_n_batches=500),
+ Printing(every_n_batches=500),
+ Plot(model_name, channels=plot_vars, every_n_batches=500),
# Checkpoint('model.pkl', every_n_batches=100),
- Dump('model_data/' + model_name, every_n_batches=1000),
+ Dump('model_data/' + model_name, every_n_batches=500),
LoadFromDump('model_data/' + model_name),
- # FinishAfter(after_epoch=42),
+ # FinishAfter(after_epoch=4),
]
main_loop = MainLoop(