aboutsummaryrefslogtreecommitdiff
path: root/model/memory_network.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/memory_network.py')
-rw-r--r--model/memory_network.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/model/memory_network.py b/model/memory_network.py
index 92e83e2..5acbfe3 100644
--- a/model/memory_network.py
+++ b/model/memory_network.py
@@ -88,7 +88,7 @@ class Stream(object):
dataset = TaxiDataset('train')
- prefix_stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme())
+ prefix_stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme(self.config.num_cuts))
prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, valid_trips_ids)
prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits)
prefix_stream = transformers.taxi_add_datetime(prefix_stream)