about summary refs log tree commit diff
path: root/model
diff options
context:
space:
mode:
author Alex Auvolat <alex.auvolat@ens.fr> 2015-07-06 10:40:41 -0400
committer Alex Auvolat <alex.auvolat@ens.fr> 2015-07-06 10:40:41 -0400
commit 793be7b049cecba43072858341dc7006fef352e7 (patch)
tree 3bfa24eef625b5f22368a40a446c91a738e0e5a0 /model
parent 389d8001be77e6cacb35804236fe9d3f0930282b (diff)
download taxi-793be7b049cecba43072858341dc7006fef352e7.tar.gz
taxi-793be7b049cecba43072858341dc7006fef352e7.zip
Add batch shuffle preprocessing step
Diffstat (limited to 'model')
-rw-r--r-- model/mlp.py 9
1 files changed, 8 insertions, 1 deletions
diff --git a/model/mlp.py b/model/mlp.py
index 7e7d092..1f53e8c 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -1,9 +1,10 @@
from theano import tensor
+import numpy
import fuel
import blocks
-from fuel.transformers import Batch, MultiProcessing
+from fuel.transformers import Batch, MultiProcessing, Mapping, SortMapping, Unpack
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme
from blocks.bricks import application, MLP, Rectifier, Initializable
@@ -73,6 +74,12 @@ class Stream(object):
stream = transformers.taxi_add_datetime(stream)
stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
stream = transformers.Select(stream, tuple(req_vars))
+
+ if hasattr(self.config, 'shuffle_batch_size'):
+ stream = transformers.Batch(stream, iteration_scheme=ConstantScheme(self.config.shuffle_batch_size))
+ rng = numpy.random.RandomState(123)
+ stream = Mapping(stream, SortMapping(lambda x: float(rng.uniform())))
+ stream = Unpack(stream)
stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))