diff options
Diffstat (limited to 'model/memory_network.py')
-rw-r--r-- | model/memory_network.py | 58 |
1 files changed, 44 insertions, 14 deletions
diff --git a/model/memory_network.py b/model/memory_network.py index 84a8edf..7ced8c0 100644 --- a/model/memory_network.py +++ b/model/memory_network.py @@ -25,12 +25,13 @@ class MemoryNetworkBase(Initializable): self.children = [ self.softmax, prefix_encoder, candidate_encoder ] self.inputs = self.prefix_encoder.apply.inputs \ - + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs] + + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs] \ + + ['candidate_destination_latitude', 'candidate_destination_longitude'] - def candidate_destination(**kwargs): + def candidate_destination(self, **kwargs): return tensor.concatenate( - (tensor.shape_padright(kwargs['candidate_last_k_latitude'][:,-1]), - tensor.shape_padright(kwargs['candidate_last_k_longitude'][:,-1])), + (tensor.shape_padright(kwargs['candidate_destination_latitude']), + tensor.shape_padright(kwargs['candidate_destination_longitude'])), axis=1) @application(outputs=['cost']) @@ -43,10 +44,8 @@ class MemoryNetworkBase(Initializable): @application(outputs=['destination']) def predict(self, **kwargs): - prefix_representation = self.prefix_encoder.apply( - { x: kwargs[x] for x in self.prefix_encoder.apply.inputs }) - candidate_representatin = self.candidate_encoder.apply( - { x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs }) + prefix_representation = self.prefix_encoder.apply(**{ x: kwargs[x] for x in self.prefix_encoder.apply.inputs }) + candidate_representation = self.candidate_encoder.apply(**{ x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs }) if self.config.normalize_representation: prefix_representation = prefix_representation \ @@ -130,12 +129,16 @@ class StreamSimple(StreamBase): def candidate_stream(self, n_candidates): candidate_stream = DataStream(self.train_dataset, - iteration_scheme=ShuffledExampleScheme(dataset.num_examples)) - candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids) + iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples)) + if not data.tvt: + candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids) candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream) candidate_stream = transformers.taxi_add_datetime(candidate_stream) candidate_stream = transformers.taxi_add_first_last_len(candidate_stream, self.config.n_begin_end_pts) + if not data.tvt: + candidate_stream = transformers.add_destination(candidate_stream) + return Batch(candidate_stream, iteration_scheme=ConstantScheme(n_candidates)) @@ -180,6 +183,27 @@ class StreamSimple(StreamBase): stream = MultiProcessing(stream) return stream + def test(self, req_vars): + prefix_stream = DataStream( + self.test_dataset, + iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples)) + prefix_stream = transformers.taxi_add_datetime(prefix_stream) + prefix_stream = transformers.taxi_add_first_last_len(prefix_stream, + self.config.n_begin_end_pts) + + if not data.tvt: + prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream) + + prefix_stream = Batch(prefix_stream, + iteration_scheme=ConstantScheme(self.config.batch_size)) + + candidate_stream = self.candidate_stream(self.config.test_candidate_size) + + sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources) + stream = Merge((prefix_stream, candidate_stream), sources) + stream = transformers.Select(stream, tuple(req_vars)) + stream = MultiProcessing(stream) + return stream class StreamRecurrent(StreamBase): def __init__(self, config): @@ -194,10 +218,14 @@ class StreamRecurrent(StreamBase): def candidate_stream(self, n_candidates): candidate_stream = DataStream(self.train_dataset, iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples)) - candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids) + if not data.tvt: + candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids) candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream) candidate_stream = transformers.taxi_add_datetime(candidate_stream) + if not data.tvt: + candidate_stream = transformers.add_destination(candidate_stream) + candidate_stream = Batch(candidate_stream, iteration_scheme=ConstantScheme(n_candidates)) @@ -210,7 +238,8 @@ class StreamRecurrent(StreamBase): prefix_stream = DataStream(self.train_dataset, iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples)) - prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) + if not data.tvt: + prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream) prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits) @@ -238,7 +267,7 @@ class StreamRecurrent(StreamBase): self.valid_dataset, iteration_scheme=SequentialExampleScheme(self.valid_dataset.num_examples)) - prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream) + #prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream) prefix_stream = transformers.taxi_add_datetime(prefix_stream) @@ -262,7 +291,8 @@ class StreamRecurrent(StreamBase): iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples)) prefix_stream = transformers.taxi_add_datetime(prefix_stream) - prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream) + if not data.tvt: + prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream) prefix_stream = Batch(prefix_stream, iteration_scheme=ConstantScheme(self.config.batch_size)) |