From 5096e0cdae167122d07b09cd207a04f28ea5c3f5 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 2 Jul 2015 13:23:28 -0400 Subject: Add random seed for TaxiGenerateSplits and for fuel --- train.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index 6d3f37b..77dca53 100755 --- a/train.py +++ b/train.py @@ -9,12 +9,16 @@ from functools import reduce from theano import tensor +import blocks +import fuel + from blocks import roles from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum from blocks.extensions import Printing, FinishAfter from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring -import blocks + blocks.config.default_seed = 123 +fuel.config.default_seed = 123 try: from blocks.extras.extensions.plotting import Plot @@ -104,12 +108,12 @@ if __name__ == "__main__": every_n_batches=1000), Printing(every_n_batches=1000), - SaveLoadParams(dump_path, cg, - before_training=True, # before training -> load params - every_n_batches=1000, # every N batches -> save params - after_epoch=True, # after epoch -> save params - after_training=True, # after training -> save params - ), + # SaveLoadParams(dump_path, cg, + # before_training=True, # before training -> load params + # every_n_batches=1000, # every N batches -> save params + # after_epoch=True, # after epoch -> save params + # after_training=True, # after training -> save params + # ), RunOnTest(model_name, model, -- cgit v1.2.3