diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 13:23:28 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 13:25:33 -0400 |
commit | 5096e0cdae167122d07b09cd207a04f28ea5c3f5 (patch) | |
tree | ba15ca59dce8b301330b8ef2f282099e5f6991a2 /data | |
parent | 98139f573eb179c8f5a06ba6c8d8883376814ccf (diff) | |
download | taxi-5096e0cdae167122d07b09cd207a04f28ea5c3f5.tar.gz taxi-5096e0cdae167122d07b09cd207a04f28ea5c3f5.zip |
Add random seed for TaxiGenerateSplits and for fuel
Diffstat (limited to 'data')
-rw-r--r-- | data/transformers.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/data/transformers.py b/data/transformers.py index e6806cc..239d957 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -1,8 +1,10 @@ import datetime -import random import numpy import theano + +import fuel + from fuel.schemes import ConstantScheme from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack @@ -66,13 +68,15 @@ class TaxiGenerateSplits(Transformer): self.id_latitude = data_stream.sources.index('latitude') self.id_longitude = data_stream.sources.index('longitude') + self.rng = numpy.random.RandomState(fuel.config.default_seed) + def get_data(self, request=None): if request is not None: raise ValueError while self.isplit >= len(self.splits): self.data = next(self.child_epoch_iterator) self.splits = range(len(self.data[self.id_longitude])) - random.shuffle(self.splits) + self.rng.shuffle(self.splits) if self.max_splits != -1 and len(self.splits) > self.max_splits: self.splits = self.splits[:self.max_splits] self.isplit = 0 |