From 5096e0cdae167122d07b09cd207a04f28ea5c3f5 Mon Sep 17 00:00:00 2001
From: Alex Auvolat
Date: Thu, 2 Jul 2015 13:23:28 -0400
Subject: Add random seed for TaxiGenerateSplits and for fuel

---
 data/transformers.py |  8 ++++++--
 train.py             | 18 +++++++++++-------
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/data/transformers.py b/data/transformers.py
index e6806cc..239d957 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -1,8 +1,10 @@
 import datetime
-import random
 
 import numpy
 import theano
+
+import fuel
+
 from fuel.schemes import ConstantScheme
 from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack
 
@@ -66,13 +68,15 @@ class TaxiGenerateSplits(Transformer):
         self.id_latitude = data_stream.sources.index('latitude')
         self.id_longitude = data_stream.sources.index('longitude')
 
+        self.rng = numpy.random.RandomState(fuel.config.default_seed)
+
     def get_data(self, request=None):
         if request is not None:
             raise ValueError
         while self.isplit >= len(self.splits):
             self.data = next(self.child_epoch_iterator)
             self.splits = range(len(self.data[self.id_longitude]))
-            random.shuffle(self.splits)
+            self.rng.shuffle(self.splits)
             if self.max_splits != -1 and len(self.splits) > self.max_splits:
                 self.splits = self.splits[:self.max_splits]
             self.isplit = 0
diff --git a/train.py b/train.py
index 6d3f37b..77dca53 100755
--- a/train.py
+++ b/train.py
@@ -9,12 +9,16 @@ from functools import reduce
 
 from theano import tensor
 
+import blocks
+import fuel
+
 from blocks import roles
 from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum
 from blocks.extensions import Printing, FinishAfter
 from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-import blocks
+
 blocks.config.default_seed = 123
+fuel.config.default_seed = 123
 
 try:
     from blocks.extras.extensions.plotting import Plot
@@ -104,12 +108,12 @@ if __name__ == "__main__":
                               every_n_batches=1000),
         Printing(every_n_batches=1000),
 
-        SaveLoadParams(dump_path, cg,
-                       before_training=True,    # before training -> load params
-                       every_n_batches=1000,    # every N batches -> save params
-                       after_epoch=True,        # after epoch -> save params
-                       after_training=True,     # after training -> save params
-                       ),
+        # SaveLoadParams(dump_path, cg,
+        #                before_training=True,    # before training -> load params
+        #                every_n_batches=1000,    # every N batches -> save params
+        #                after_epoch=True,        # after epoch -> save params
+        #                after_training=True,     # after training -> save params
+        #                ),
 
         RunOnTest(model_name,
                   model,
--
cgit v1.2.3
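
Below, a minimal standalone sketch (not part of the patch) of the seeding
pattern it introduces: feeding fuel's global default_seed into a
numpy.random.RandomState so that split shuffling is reproducible across runs.
The ten-element split list is made up for illustration, and the explicit
list() is only needed on Python 3; the patched code targets Python 2, where
range() returns a list.

    import numpy
    import fuel

    # Seed fuel's global default once, as the patched train.py does.
    fuel.config.default_seed = 123

    # Same construction as TaxiGenerateSplits.__init__ in the patch above.
    rng = numpy.random.RandomState(fuel.config.default_seed)

    # shuffle() permutes in place; with a fixed seed the permutation is
    # identical on every run, unlike the module-level random.shuffle()
    # the patch removes.
    splits = list(range(10))
    rng.shuffle(splits)
    print(splits)  # same ordering on every run with the same seed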