about | summary | refs | log | tree | commit | diff
path: root/model
diff options
context:
space:
mode:
Diffstat (limited to 'model')
-rw-r--r-- model/memory_network.py 58
-rw-r--r-- model/memory_network_bidir.py 17
-rw-r--r-- model/memory_network_mlp.py 28
3 files changed, 68 insertions, 35 deletions
diff --git a/model/memory_network.py b/model/memory_network.py
index 84a8edf..7ced8c0 100644
--- a/model/memory_network.py
+++ b/model/memory_network.py
@@ -25,12 +25,13 @@ class MemoryNetworkBase(Initializable):
self.children = [ self.softmax, prefix_encoder, candidate_encoder ]
self.inputs = self.prefix_encoder.apply.inputs \
- + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs]
+ + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs] \
+ + ['candidate_destination_latitude', 'candidate_destination_longitude']
- def candidate_destination(**kwargs):
+ def candidate_destination(self, **kwargs):
return tensor.concatenate(
- (tensor.shape_padright(kwargs['candidate_last_k_latitude'][:,-1]),
- tensor.shape_padright(kwargs['candidate_last_k_longitude'][:,-1])),
+ (tensor.shape_padright(kwargs['candidate_destination_latitude']),
+ tensor.shape_padright(kwargs['candidate_destination_longitude'])),
axis=1)
@application(outputs=['cost'])
@@ -43,10 +44,8 @@ class MemoryNetworkBase(Initializable):
@application(outputs=['destination'])
def predict(self, **kwargs):
- prefix_representation = self.prefix_encoder.apply(
- { x: kwargs[x] for x in self.prefix_encoder.apply.inputs })
- candidate_representatin = self.candidate_encoder.apply(
- { x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs })
+ prefix_representation = self.prefix_encoder.apply(**{ x: kwargs[x] for x in self.prefix_encoder.apply.inputs })
+ candidate_representation = self.candidate_encoder.apply(**{ x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs })
if self.config.normalize_representation:
prefix_representation = prefix_representation \
@@ -130,12 +129,16 @@ class StreamSimple(StreamBase):
def candidate_stream(self, n_candidates):
candidate_stream = DataStream(self.train_dataset,
- iteration_scheme=ShuffledExampleScheme(dataset.num_examples))
- candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
+ iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
+ if not data.tvt:
+ candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
candidate_stream = transformers.taxi_add_datetime(candidate_stream)
candidate_stream = transformers.taxi_add_first_last_len(candidate_stream,
self.config.n_begin_end_pts)
+ if not data.tvt:
+ candidate_stream = transformers.add_destination(candidate_stream)
+
return Batch(candidate_stream,
iteration_scheme=ConstantScheme(n_candidates))
@@ -180,6 +183,27 @@ class StreamSimple(StreamBase):
stream = MultiProcessing(stream)
return stream
+ def test(self, req_vars):
+ prefix_stream = DataStream(
+ self.test_dataset,
+ iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples))
+ prefix_stream = transformers.taxi_add_datetime(prefix_stream)
+ prefix_stream = transformers.taxi_add_first_last_len(prefix_stream,
+ self.config.n_begin_end_pts)
+
+ if not data.tvt:
+ prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)
+
+ prefix_stream = Batch(prefix_stream,
+ iteration_scheme=ConstantScheme(self.config.batch_size))
+
+ candidate_stream = self.candidate_stream(self.config.test_candidate_size)
+
+ sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
+ stream = Merge((prefix_stream, candidate_stream), sources)
+ stream = transformers.Select(stream, tuple(req_vars))
+ stream = MultiProcessing(stream)
+ return stream
class StreamRecurrent(StreamBase):
def __init__(self, config):
@@ -194,10 +218,14 @@ class StreamRecurrent(StreamBase):
def candidate_stream(self, n_candidates):
candidate_stream = DataStream(self.train_dataset,
iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
- candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
+ if not data.tvt:
+ candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
candidate_stream = transformers.taxi_add_datetime(candidate_stream)
+ if not data.tvt:
+ candidate_stream = transformers.add_destination(candidate_stream)
+
candidate_stream = Batch(candidate_stream,
iteration_scheme=ConstantScheme(n_candidates))
@@ -210,7 +238,8 @@ class StreamRecurrent(StreamBase):
prefix_stream = DataStream(self.train_dataset,
iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
- prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
+ if not data.tvt:
+ prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
prefix_stream = transformers.TaxiGenerateSplits(prefix_stream,
max_splits=self.config.max_splits)
@@ -238,7 +267,7 @@ class StreamRecurrent(StreamBase):
self.valid_dataset,
iteration_scheme=SequentialExampleScheme(self.valid_dataset.num_examples))
- prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
+ #prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
prefix_stream = transformers.taxi_add_datetime(prefix_stream)
@@ -262,7 +291,8 @@ class StreamRecurrent(StreamBase):
iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples))
prefix_stream = transformers.taxi_add_datetime(prefix_stream)
- prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)
+ if not data.tvt:
+ prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)
prefix_stream = Batch(prefix_stream,
iteration_scheme=ConstantScheme(self.config.batch_size))
diff --git a/model/memory_network_bidir.py b/model/memory_network_bidir.py
index cc99312..81e6440 100644
--- a/model/memory_network_bidir.py
+++ b/model/memory_network_bidir.py
@@ -72,22 +72,25 @@ class RecurrentEncoder(Initializable):
return outputs
+ @apply.property('inputs')
+ def apply_inputs(self):
+ return self.inputs
+
class Model(MemoryNetworkBase):
def __init__(self, config, **kwargs):
# Build prefix encoder : recurrent then MLP
- prefix_encoder = RecurrentEncoder(self.config.prefix_encoder,
- self.config.representation_size,
- self.config.representation_activation(),
+ prefix_encoder = RecurrentEncoder(config.prefix_encoder,
+ config.representation_size,
+ config.representation_activation(),
name='prefix_encoder')
# Build candidate encoder
- candidate_encoder = RecurrentEncoder(self.config.candidate_encoder,
- self.config.representation_size,
- self.config.representation_activation(),
+ candidate_encoder = RecurrentEncoder(config.candidate_encoder,
+ config.representation_size,
+ config.representation_activation(),
name='candidate_encoder')
# And... that's it!
super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs)
-
diff --git a/model/memory_network_mlp.py b/model/memory_network_mlp.py
index de07e60..fc897d5 100644
--- a/model/memory_network_mlp.py
+++ b/model/memory_network_mlp.py
@@ -18,17 +18,17 @@ from memory_network import MemoryNetworkBase
class MLPEncoder(Initializable):
def __init__(self, config, output_dim, activation, **kwargs):
- super(RecurrentEncoder, self).__init__(**kwargs)
+ super(MLPEncoder, self).__init__(**kwargs)
self.config = config
self.context_embedder = ContextEmbedder(self.config)
- self.encoder_mlp = MLP(activations=[Rectifier() for _ in config.prefix_encoder.dim_hidden]
- + [config.representation_activation()],
- dims=[config.prefix_encoder.dim_input]
- + config.prefix_encoder.dim_hidden
- + [config.representation_size],
- name='prefix_encoder')
+ self.encoder_mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden]
+ + [activation()],
+ dims=[config.dim_input]
+ + config.dim_hidden
+ + [output_dim],
+ name='encoder')
self.extremities = {'%s_k_%s' % (side, ['latitude', 'longitude'][axis]): axis
for side in ['first', 'last'] for axis in [0, 1]}
@@ -37,7 +37,7 @@ class MLPEncoder(Initializable):
self.encoder_mlp ]
def _push_initialization_config(self):
- for brick in [self.contex_encoder, self.encoder_mlp]:
+ for brick in [self.context_embedder, self.encoder_mlp]:
brick.weights_init = self.config.weights_init
brick.biases_init = self.config.biases_init
@@ -46,7 +46,7 @@ class MLPEncoder(Initializable):
embeddings = tuple(self.context_embedder.apply(
**{k: kwargs[k] for k in self.context_embedder.inputs }))
extremities = tuple((kwargs[k] - data.train_gps_mean[v]) / data.train_gps_std[v]
- for k, v in self.prefix_extremities.items())
+ for k, v in self.extremities.items())
inputs = tensor.concatenate(extremities + embeddings, axis=1)
return self.encoder_mlp.apply(inputs)
@@ -60,12 +60,12 @@ class Model(MemoryNetworkBase):
def __init__(self, config, **kwargs):
prefix_encoder = MLPEncoder(config.prefix_encoder,
config.representation_size,
- config.representation_activation())
+ config.representation_activation,
+ name='prefix_encoder')
- candidate_encoer = MLPEncoder(config.candidate_encoder,
+ candidate_encoder = MLPEncoder(config.candidate_encoder,
config.representation_size,
- config.representation_activation())
+ config.representation_activation,
+ name='candidate_encoder')
super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs)
-
-