diff options
Diffstat (limited to 'data/transformers.py')
-rw-r--r-- | data/transformers.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/data/transformers.py b/data/transformers.py index e6806cc..239d957 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -1,8 +1,10 @@ import datetime -import random import numpy import theano + +import fuel + from fuel.schemes import ConstantScheme from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack @@ -66,13 +68,15 @@ class TaxiGenerateSplits(Transformer): self.id_latitude = data_stream.sources.index('latitude') self.id_longitude = data_stream.sources.index('longitude') + self.rng = numpy.random.RandomState(fuel.config.default_seed) + def get_data(self, request=None): if request is not None: raise ValueError while self.isplit >= len(self.splits): self.data = next(self.child_epoch_iterator) self.splits = range(len(self.data[self.id_longitude])) - random.shuffle(self.splits) + self.rng.shuffle(self.splits) if self.max_splits != -1 and len(self.splits) > self.max_splits: self.splits = self.splits[:self.max_splits] self.isplit = 0 |