about summary refs log tree commit diff
path: root/model/memory_network.py
diff options
context:
space:
mode:
author	Étienne Simon <esimon@esimon.eu>	2015-07-24 16:09:48 -0400
committer	Étienne Simon <esimon@esimon.eu>	2015-07-24 16:09:48 -0400
commit7dab7e47ce0e8c5ae996821794450a9ad3186cd3 (patch)
treee0babcc305696a6e6a67a52acecd300bfdf22cf0 /model/memory_network.py
parent60e6bc64d8e3c6679a6e2a960513c656d481f0ed (diff)
download	taxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.tar.gz
taxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.zip
Fix memory network
Diffstat (limited to 'model/memory_network.py')
-rw-r--r--	model/memory_network.py	58
1 file changed, 44 insertions, 14 deletions
diff --git a/model/memory_network.py b/model/memory_network.py
index 84a8edf..7ced8c0 100644
--- a/model/memory_network.py
+++ b/model/memory_network.py
@@ -25,12 +25,13 @@ class MemoryNetworkBase(Initializable):
self.children = [ self.softmax, prefix_encoder, candidate_encoder ]
self.inputs = self.prefix_encoder.apply.inputs \
- + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs]
+ + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs] \
+ + ['candidate_destination_latitude', 'candidate_destination_longitude']
- def candidate_destination(**kwargs):
+ def candidate_destination(self, **kwargs):
return tensor.concatenate(
- (tensor.shape_padright(kwargs['candidate_last_k_latitude'][:,-1]),
- tensor.shape_padright(kwargs['candidate_last_k_longitude'][:,-1])),
+ (tensor.shape_padright(kwargs['candidate_destination_latitude']),
+ tensor.shape_padright(kwargs['candidate_destination_longitude'])),
axis=1)
@application(outputs=['cost'])
@@ -43,10 +44,8 @@ class MemoryNetworkBase(Initializable):
@application(outputs=['destination'])
def predict(self, **kwargs):
- prefix_representation = self.prefix_encoder.apply(
- { x: kwargs[x] for x in self.prefix_encoder.apply.inputs })
- candidate_representatin = self.candidate_encoder.apply(
- { x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs })
+ prefix_representation = self.prefix_encoder.apply(**{ x: kwargs[x] for x in self.prefix_encoder.apply.inputs })
+ candidate_representation = self.candidate_encoder.apply(**{ x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs })
if self.config.normalize_representation:
prefix_representation = prefix_representation \
@@ -130,12 +129,16 @@ class StreamSimple(StreamBase):
def candidate_stream(self, n_candidates):
candidate_stream = DataStream(self.train_dataset,
- iteration_scheme=ShuffledExampleScheme(dataset.num_examples))
- candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
+ iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
+ if not data.tvt:
+ candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
candidate_stream = transformers.taxi_add_datetime(candidate_stream)
candidate_stream = transformers.taxi_add_first_last_len(candidate_stream,
self.config.n_begin_end_pts)
+ if not data.tvt:
+ candidate_stream = transformers.add_destination(candidate_stream)
+
return Batch(candidate_stream,
iteration_scheme=ConstantScheme(n_candidates))
@@ -180,6 +183,27 @@ class StreamSimple(StreamBase):
stream = MultiProcessing(stream)
return stream
+ def test(self, req_vars):
+ prefix_stream = DataStream(
+ self.test_dataset,
+ iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples))
+ prefix_stream = transformers.taxi_add_datetime(prefix_stream)
+ prefix_stream = transformers.taxi_add_first_last_len(prefix_stream,
+ self.config.n_begin_end_pts)
+
+ if not data.tvt:
+ prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)
+
+ prefix_stream = Batch(prefix_stream,
+ iteration_scheme=ConstantScheme(self.config.batch_size))
+
+ candidate_stream = self.candidate_stream(self.config.test_candidate_size)
+
+ sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
+ stream = Merge((prefix_stream, candidate_stream), sources)
+ stream = transformers.Select(stream, tuple(req_vars))
+ stream = MultiProcessing(stream)
+ return stream
class StreamRecurrent(StreamBase):
def __init__(self, config):
@@ -194,10 +218,14 @@ class StreamRecurrent(StreamBase):
def candidate_stream(self, n_candidates):
candidate_stream = DataStream(self.train_dataset,
iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
- candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
+ if not data.tvt:
+ candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
candidate_stream = transformers.taxi_add_datetime(candidate_stream)
+ if not data.tvt:
+ candidate_stream = transformers.add_destination(candidate_stream)
+
candidate_stream = Batch(candidate_stream,
iteration_scheme=ConstantScheme(n_candidates))
@@ -210,7 +238,8 @@ class StreamRecurrent(StreamBase):
prefix_stream = DataStream(self.train_dataset,
iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
- prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
+ if not data.tvt:
+ prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
prefix_stream = transformers.TaxiGenerateSplits(prefix_stream,
max_splits=self.config.max_splits)
@@ -238,7 +267,7 @@ class StreamRecurrent(StreamBase):
self.valid_dataset,
iteration_scheme=SequentialExampleScheme(self.valid_dataset.num_examples))
- prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
+ #prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
prefix_stream = transformers.taxi_add_datetime(prefix_stream)
@@ -262,7 +291,8 @@ class StreamRecurrent(StreamBase):
iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples))
prefix_stream = transformers.taxi_add_datetime(prefix_stream)
- prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)
+ if not data.tvt:
+ prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)
prefix_stream = Batch(prefix_stream,
iteration_scheme=ConstantScheme(self.config.batch_size))