aboutsummaryrefslogtreecommitdiff
path: root/model/mlp.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/mlp.py')
-rw-r--r--model/mlp.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/model/mlp.py b/model/mlp.py
index 576b45b..05898a5 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -1,6 +1,6 @@
from theano import tensor
-from fuel.transformers import Batch
+from fuel.transformers import Batch, MultiProcessing
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme
from blocks.bricks import application, MLP, Rectifier, Initializable
@@ -63,7 +63,10 @@ class Stream(object):
stream = transformers.TaxiAddDateTime(stream)
stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
stream = transformers.Select(stream, tuple(req_vars))
- return Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
+
+ stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
+
+ return stream
def valid(self, req_vars):
stream = TaxiStream(self.config.valid_set, 'valid.hdf5')