aboutsummaryrefslogtreecommitdiff
path: root/model/mlp.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/mlp.py')
-rw-r--r--model/mlp.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/model/mlp.py b/model/mlp.py
index 7d04c82..d24b2cc 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -106,7 +106,7 @@ class Stream(object):
stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
stream = transformers.taxi_remove_test_only_clients(stream)
- return Batch(stream, iteration_scheme=ConstantScheme(1))
+ return Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
def inputs(self):
return {'call_type': tensor.bvector('call_type'),