diff options
Diffstat (limited to 'model/mlp.py')
-rw-r--r-- | model/mlp.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/model/mlp.py b/model/mlp.py index 576b45b..05898a5 100644 --- a/model/mlp.py +++ b/model/mlp.py @@ -1,6 +1,6 @@ from theano import tensor -from fuel.transformers import Batch +from fuel.transformers import Batch, MultiProcessing from fuel.streams import DataStream from fuel.schemes import ConstantScheme, ShuffledExampleScheme from blocks.bricks import application, MLP, Rectifier, Initializable @@ -63,7 +63,10 @@ class Stream(object): stream = transformers.TaxiAddDateTime(stream) stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) stream = transformers.Select(stream, tuple(req_vars)) - return Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) + + stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) + + return stream def valid(self, req_vars): stream = TaxiStream(self.config.valid_set, 'valid.hdf5') |