diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-25 17:31:17 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-25 17:31:17 -0400 |
commit | 2a20bc827a8c1c9b6e74ef4e1234788207be45b8 (patch) | |
tree | 2051f9f81324bbf202a3bc3fa8c59a68288d2566 /model/memory_network.py | |
parent | d3dbbe68642320b9a225b1d3515f0181916df8ad (diff) | |
download | taxi-2a20bc827a8c1c9b6e74ef4e1234788207be45b8.tar.gz taxi-2a20bc827a8c1c9b6e74ef4e1234788207be45b8.zip |
Add batch sorting
Diffstat (limited to 'model/memory_network.py')
-rw-r--r-- | model/memory_network.py | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/model/memory_network.py b/model/memory_network.py index 7ced8c0..f47a12d 100644 --- a/model/memory_network.py +++ b/model/memory_network.py @@ -215,7 +215,7 @@ class StreamRecurrent(StreamBase): ] self.candidate_inputs = self.prefix_inputs - def candidate_stream(self, n_candidates): + def candidate_stream(self, n_candidates, sortmap=True): candidate_stream = DataStream(self.train_dataset, iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples)) if not data.tvt: @@ -226,8 +226,14 @@ class StreamRecurrent(StreamBase): if not data.tvt: candidate_stream = transformers.add_destination(candidate_stream) - candidate_stream = Batch(candidate_stream, - iteration_scheme=ConstantScheme(n_candidates)) + if sortmap: + candidate_stream = transformers.balanced_batch(candidate_stream, + key='latitude', + batch_size=n_candidates, + batch_sort_size=self.config.batch_sort_size) + else: + candidate_stream = Batch(candidate_stream, + iteration_scheme=ConstantScheme(n_candidates)) candidate_stream = Padding(candidate_stream, mask_sources=['latitude', 'longitude']) @@ -247,9 +253,9 @@ class StreamRecurrent(StreamBase): prefix_stream = transformers.taxi_add_datetime(prefix_stream) prefix_stream = transformers.balanced_batch(prefix_stream, - key='latitude', - batch_size=self.config.batch_size, - batch_sort_size=self.config.batch_sort_size) + key='latitude', + batch_size=self.config.batch_size, + batch_sort_size=self.config.batch_sort_size) prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude']) @@ -271,8 +277,11 @@ class StreamRecurrent(StreamBase): prefix_stream = transformers.taxi_add_datetime(prefix_stream) - prefix_stream = Batch(prefix_stream, - iteration_scheme=ConstantScheme(self.config.batch_size)) + prefix_stream = transformers.balanced_batch(prefix_stream, + key='latitude', + batch_size=self.config.batch_size, + batch_sort_size=self.config.batch_sort_size) + prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude']) candidate_stream = self.candidate_stream(self.config.valid_candidate_size) @@ -298,7 +307,7 @@ class StreamRecurrent(StreamBase): iteration_scheme=ConstantScheme(self.config.batch_size)) prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude']) - candidate_stream = self.candidate_stream(self.config.test_candidate_size) + candidate_stream = self.candidate_stream(self.config.test_candidate_size, False) sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources) stream = Merge((prefix_stream, candidate_stream), sources) |