aboutsummaryrefslogtreecommitdiff
path: root/model
diff options
context:
space:
mode:
Diffstat (limited to 'model')
-rw-r--r--model/memory_network.py27
-rw-r--r--model/stream.py10
2 files changed, 23 insertions, 14 deletions
diff --git a/model/memory_network.py b/model/memory_network.py
index 7ced8c0..f47a12d 100644
--- a/model/memory_network.py
+++ b/model/memory_network.py
@@ -215,7 +215,7 @@ class StreamRecurrent(StreamBase):
]
self.candidate_inputs = self.prefix_inputs
- def candidate_stream(self, n_candidates):
+ def candidate_stream(self, n_candidates, sortmap=True):
candidate_stream = DataStream(self.train_dataset,
iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
if not data.tvt:
@@ -226,8 +226,14 @@ class StreamRecurrent(StreamBase):
if not data.tvt:
candidate_stream = transformers.add_destination(candidate_stream)
- candidate_stream = Batch(candidate_stream,
- iteration_scheme=ConstantScheme(n_candidates))
+ if sortmap:
+ candidate_stream = transformers.balanced_batch(candidate_stream,
+ key='latitude',
+ batch_size=n_candidates,
+ batch_sort_size=self.config.batch_sort_size)
+ else:
+ candidate_stream = Batch(candidate_stream,
+ iteration_scheme=ConstantScheme(n_candidates))
candidate_stream = Padding(candidate_stream,
mask_sources=['latitude', 'longitude'])
@@ -247,9 +253,9 @@ class StreamRecurrent(StreamBase):
prefix_stream = transformers.taxi_add_datetime(prefix_stream)
prefix_stream = transformers.balanced_batch(prefix_stream,
- key='latitude',
- batch_size=self.config.batch_size,
- batch_sort_size=self.config.batch_sort_size)
+ key='latitude',
+ batch_size=self.config.batch_size,
+ batch_sort_size=self.config.batch_sort_size)
prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude'])
@@ -271,8 +277,11 @@ class StreamRecurrent(StreamBase):
prefix_stream = transformers.taxi_add_datetime(prefix_stream)
- prefix_stream = Batch(prefix_stream,
- iteration_scheme=ConstantScheme(self.config.batch_size))
+ prefix_stream = transformers.balanced_batch(prefix_stream,
+ key='latitude',
+ batch_size=self.config.batch_size,
+ batch_sort_size=self.config.batch_sort_size)
+
prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude'])
candidate_stream = self.candidate_stream(self.config.valid_candidate_size)
@@ -298,7 +307,7 @@ class StreamRecurrent(StreamBase):
iteration_scheme=ConstantScheme(self.config.batch_size))
prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude'])
- candidate_stream = self.candidate_stream(self.config.test_candidate_size)
+ candidate_stream = self.candidate_stream(self.config.test_candidate_size, False)
sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
stream = Merge((prefix_stream, candidate_stream), sources)
diff --git a/model/stream.py b/model/stream.py
index 88b1d7f..61ff1c3 100644
--- a/model/stream.py
+++ b/model/stream.py
@@ -33,17 +33,13 @@ class StreamRec(object):
stream = transformers.TaxiExcludeEmptyTrips(stream)
stream = transformers.taxi_add_datetime(stream)
-
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
stream = transformers.balanced_batch(stream, key='latitude',
batch_size=self.config.batch_size,
batch_sort_size=self.config.batch_sort_size)
-
stream = Padding(stream, mask_sources=['latitude', 'longitude'])
-
stream = transformers.Select(stream, req_vars)
-
stream = MultiProcessing(stream)
return stream
@@ -54,9 +50,13 @@ class StreamRec(object):
stream = transformers.taxi_add_datetime(stream)
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
- stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
+ stream = transformers.balanced_batch(stream, key='latitude',
+ batch_size=self.config.batch_size,
+ batch_sort_size=self.config.batch_sort_size)
stream = Padding(stream, mask_sources=['latitude', 'longitude'])
stream = transformers.Select(stream, req_vars)
+ stream = MultiProcessing(stream)
+
return stream
def test(self, req_vars):