aboutsummaryrefslogtreecommitdiff
path: root/data/transformers.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-02 13:23:28 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-02 13:25:33 -0400
commit5096e0cdae167122d07b09cd207a04f28ea5c3f5 (patch)
treeba15ca59dce8b301330b8ef2f282099e5f6991a2 /data/transformers.py
parent98139f573eb179c8f5a06ba6c8d8883376814ccf (diff)
downloadtaxi-5096e0cdae167122d07b09cd207a04f28ea5c3f5.tar.gz
taxi-5096e0cdae167122d07b09cd207a04f28ea5c3f5.zip
Add random seed for TaxiGenerateSplits and for fuel
Diffstat (limited to 'data/transformers.py')
-rw-r--r--data/transformers.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/data/transformers.py b/data/transformers.py
index e6806cc..239d957 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -1,8 +1,10 @@
import datetime
-import random
import numpy
import theano
+
+import fuel
+
from fuel.schemes import ConstantScheme
from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack
@@ -66,13 +68,15 @@ class TaxiGenerateSplits(Transformer):
self.id_latitude = data_stream.sources.index('latitude')
self.id_longitude = data_stream.sources.index('longitude')
+ self.rng = numpy.random.RandomState(fuel.config.default_seed)
+
def get_data(self, request=None):
if request is not None:
raise ValueError
while self.isplit >= len(self.splits):
self.data = next(self.child_epoch_iterator)
self.splits = range(len(self.data[self.id_longitude]))
- random.shuffle(self.splits)
+ self.rng.shuffle(self.splits)
if self.max_splits != -1 and len(self.splits) > self.max_splits:
self.splits = self.splits[:self.max_splits]
self.isplit = 0