diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-07-24 16:09:48 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-07-24 16:09:48 -0400 |
commit | 7dab7e47ce0e8c5ae996821794450a9ad3186cd3 (patch) | |
tree | e0babcc305696a6e6a67a52acecd300bfdf22cf0 /model | |
parent | 60e6bc64d8e3c6679a6e2a960513c656d481f0ed (diff) | |
download | taxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.tar.gz taxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.zip |
Fix memory network
Diffstat (limited to 'model')
-rw-r--r-- | model/memory_network.py | 58 | ||||
-rw-r--r-- | model/memory_network_bidir.py | 17 | ||||
-rw-r--r-- | model/memory_network_mlp.py | 28 |
3 files changed, 68 insertions, 35 deletions
diff --git a/model/memory_network.py b/model/memory_network.py index 84a8edf..7ced8c0 100644 --- a/model/memory_network.py +++ b/model/memory_network.py @@ -25,12 +25,13 @@ class MemoryNetworkBase(Initializable): self.children = [ self.softmax, prefix_encoder, candidate_encoder ] self.inputs = self.prefix_encoder.apply.inputs \ - + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs] + + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs] \ + + ['candidate_destination_latitude', 'candidate_destination_longitude'] - def candidate_destination(**kwargs): + def candidate_destination(self, **kwargs): return tensor.concatenate( - (tensor.shape_padright(kwargs['candidate_last_k_latitude'][:,-1]), - tensor.shape_padright(kwargs['candidate_last_k_longitude'][:,-1])), + (tensor.shape_padright(kwargs['candidate_destination_latitude']), + tensor.shape_padright(kwargs['candidate_destination_longitude'])), axis=1) @application(outputs=['cost']) @@ -43,10 +44,8 @@ class MemoryNetworkBase(Initializable): @application(outputs=['destination']) def predict(self, **kwargs): - prefix_representation = self.prefix_encoder.apply( - { x: kwargs[x] for x in self.prefix_encoder.apply.inputs }) - candidate_representatin = self.candidate_encoder.apply( - { x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs }) + prefix_representation = self.prefix_encoder.apply(**{ x: kwargs[x] for x in self.prefix_encoder.apply.inputs }) + candidate_representation = self.candidate_encoder.apply(**{ x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs }) if self.config.normalize_representation: prefix_representation = prefix_representation \ @@ -130,12 +129,16 @@ class StreamSimple(StreamBase): def candidate_stream(self, n_candidates): candidate_stream = DataStream(self.train_dataset, - iteration_scheme=ShuffledExampleScheme(dataset.num_examples)) - candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids) + iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples)) + if not data.tvt: + candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids) candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream) candidate_stream = transformers.taxi_add_datetime(candidate_stream) candidate_stream = transformers.taxi_add_first_last_len(candidate_stream, self.config.n_begin_end_pts) + if not data.tvt: + candidate_stream = transformers.add_destination(candidate_stream) + return Batch(candidate_stream, iteration_scheme=ConstantScheme(n_candidates)) @@ -180,6 +183,27 @@ class StreamSimple(StreamBase): stream = MultiProcessing(stream) return stream + def test(self, req_vars): + prefix_stream = DataStream( + self.test_dataset, + iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples)) + prefix_stream = transformers.taxi_add_datetime(prefix_stream) + prefix_stream = transformers.taxi_add_first_last_len(prefix_stream, + self.config.n_begin_end_pts) + + if not data.tvt: + prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream) + + prefix_stream = Batch(prefix_stream, + iteration_scheme=ConstantScheme(self.config.batch_size)) + + candidate_stream = self.candidate_stream(self.config.test_candidate_size) + + sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources) + stream = Merge((prefix_stream, candidate_stream), sources) + stream = transformers.Select(stream, tuple(req_vars)) + stream = MultiProcessing(stream) + return stream class StreamRecurrent(StreamBase): def __init__(self, config): @@ -194,10 +218,14 @@ class StreamRecurrent(StreamBase): def candidate_stream(self, n_candidates): candidate_stream = DataStream(self.train_dataset, iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples)) - candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids) + if not data.tvt: + candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids) candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream) candidate_stream = transformers.taxi_add_datetime(candidate_stream) + if not data.tvt: + candidate_stream = transformers.add_destination(candidate_stream) + candidate_stream = Batch(candidate_stream, iteration_scheme=ConstantScheme(n_candidates)) @@ -210,7 +238,8 @@ class StreamRecurrent(StreamBase): prefix_stream = DataStream(self.train_dataset, iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples)) - prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) + if not data.tvt: + prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids) prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream) prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits) @@ -238,7 +267,7 @@ class StreamRecurrent(StreamBase): self.valid_dataset, iteration_scheme=SequentialExampleScheme(self.valid_dataset.num_examples)) - prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream) + #prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream) prefix_stream = transformers.taxi_add_datetime(prefix_stream) @@ -262,7 +291,8 @@ class StreamRecurrent(StreamBase): iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples)) prefix_stream = transformers.taxi_add_datetime(prefix_stream) - prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream) + if not data.tvt: + prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream) prefix_stream = Batch(prefix_stream, iteration_scheme=ConstantScheme(self.config.batch_size)) diff --git a/model/memory_network_bidir.py b/model/memory_network_bidir.py index cc99312..81e6440 100644 --- a/model/memory_network_bidir.py +++ b/model/memory_network_bidir.py @@ -72,22 +72,25 @@ class RecurrentEncoder(Initializable): return outputs + @apply.property('inputs') + def apply_inputs(self): + return self.inputs + class Model(MemoryNetworkBase): def __init__(self, config, **kwargs): # Build prefix encoder : recurrent then MLP - prefix_encoder = RecurrentEncoder(self.config.prefix_encoder, - self.config.representation_size, - self.config.representation_activation(), + prefix_encoder = RecurrentEncoder(config.prefix_encoder, + config.representation_size, + config.representation_activation(), name='prefix_encoder') # Build candidate encoder - candidate_encoder = RecurrentEncoder(self.config.candidate_encoder, - self.config.representation_size, - self.config.representation_activation(), + candidate_encoder = RecurrentEncoder(config.candidate_encoder, + config.representation_size, + config.representation_activation(), name='candidate_encoder') # And... that's it! super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs) - diff --git a/model/memory_network_mlp.py b/model/memory_network_mlp.py index de07e60..fc897d5 100644 --- a/model/memory_network_mlp.py +++ b/model/memory_network_mlp.py @@ -18,17 +18,17 @@ from memory_network import MemoryNetworkBase class MLPEncoder(Initializable): def __init__(self, config, output_dim, activation, **kwargs): - super(RecurrentEncoder, self).__init__(**kwargs) + super(MLPEncoder, self).__init__(**kwargs) self.config = config self.context_embedder = ContextEmbedder(self.config) - self.encoder_mlp = MLP(activations=[Rectifier() for _ in config.prefix_encoder.dim_hidden] - + [config.representation_activation()], - dims=[config.prefix_encoder.dim_input] - + config.prefix_encoder.dim_hidden - + [config.representation_size], - name='prefix_encoder') + self.encoder_mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + + [activation()], + dims=[config.dim_input] + + config.dim_hidden + + [output_dim], + name='encoder') self.extremities = {'%s_k_%s' % (side, ['latitude', 'longitude'][axis]): axis for side in ['first', 'last'] for axis in [0, 1]} @@ -37,7 +37,7 @@ class MLPEncoder(Initializable): self.encoder_mlp ] def _push_initialization_config(self): - for brick in [self.contex_encoder, self.encoder_mlp]: + for brick in [self.context_embedder, self.encoder_mlp]: brick.weights_init = self.config.weights_init brick.biases_init = self.config.biases_init @@ -46,7 +46,7 @@ class MLPEncoder(Initializable): embeddings = tuple(self.context_embedder.apply( **{k: kwargs[k] for k in self.context_embedder.inputs })) extremities = tuple((kwargs[k] - data.train_gps_mean[v]) / data.train_gps_std[v] - for k, v in self.prefix_extremities.items()) + for k, v in self.extremities.items()) inputs = tensor.concatenate(extremities + embeddings, axis=1) return self.encoder_mlp.apply(inputs) @@ -60,12 +60,12 @@ class Model(MemoryNetworkBase): def __init__(self, config, **kwargs): prefix_encoder = MLPEncoder(config.prefix_encoder, config.representation_size, - config.representation_activation()) + config.representation_activation, + name='prefix_encoder') - candidate_encoer = MLPEncoder(config.candidate_encoder, + candidate_encoder = MLPEncoder(config.candidate_encoder, config.representation_size, - config.representation_activation()) + config.representation_activation, + name='candidate_encoder') super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs) - - |